package burlap.behavior.singleagent.learning.lspi;

import burlap.debugtools.RandomFactory;
import burlap.oomdp.auxiliary.StateGenerator;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learning/lspi/SARSCollector.class */
/**
 * Base class for objects that gather (state, action, reward, next state) transitions
 * into a {@link SARSData} dataset.
 *
 * <p>Subclasses decide how a single rollout is produced by implementing
 * {@link #collectDataFrom(State, RewardFunction, int, TerminalFunction, SARSData)};
 * {@link #collectNInstances(StateGenerator, RewardFunction, int, int, TerminalFunction, SARSData)}
 * then repeats rollouts until a requested number of transitions has been added.
 */
public abstract class SARSCollector {
    /** The actions from which applicable grounded actions are drawn during collection. */
    protected List<Action> actions;

    /**
     * Collector that, at every step of a rollout, picks one of the currently applicable
     * grounded actions by drawing a random index into the list of applicable actions.
     */
    public static class UniformRandomSARSCollector extends SARSCollector {
        /**
         * Creates a collector that uses the actions returned by {@code domain.getActions()}.
         *
         * @param domain the domain supplying the action list
         */
        public UniformRandomSARSCollector(Domain domain) {
            super(domain);
        }

        /**
         * Creates a collector that uses the given action list.
         *
         * @param list the actions to select from; the reference is stored, not copied
         */
        public UniformRandomSARSCollector(List<Action> list) {
            super(list);
        }

        /**
         * Performs one rollout from {@code state}, recording every transition.
         *
         * <p>The rollout stops as soon as the current state is terminal according to
         * {@code terminalFunction} or {@code i} transitions have been recorded, whichever
         * comes first. A terminal start state or a non-positive {@code i} records nothing.
         *
         * @param state            the state the rollout starts from
         * @param rewardFunction   used to compute the reward of each transition
         * @param i                maximum number of transitions to record in this rollout
         * @param terminalFunction used to detect terminal states
         * @param sARSData         dataset to append to; a new one is created when {@code null}
         * @return the dataset that received the transitions
         * @throws IllegalArgumentException if a non-terminal state is reached in which no
         *         grounded action is applicable (raised by {@code Random.nextInt(0)})
         */
        @Override
        public SARSData collectDataFrom(State state, RewardFunction rewardFunction, int i, TerminalFunction terminalFunction, SARSData sARSData) {
            if (sARSData == null) {
                sARSData = new SARSData();
            }
            State current = state;
            int steps = 0;
            // The terminal check is evaluated before the step bound, as in a guarded rollout.
            while (!terminalFunction.isTerminal(current) && steps < i) {
                List<GroundedAction> applicable = Action.getAllApplicableGroundedActionsFromActionList(this.actions, current);
                // nextInt(size) yields an index in [0, size); it throws when size is 0.
                int choice = RandomFactory.getMapped(0).nextInt(applicable.size());
                GroundedAction selected = applicable.get(choice);
                State next = selected.executeIn(current);
                double reward = rewardFunction.reward(current, selected, next);
                sARSData.add(current, selected, reward, next);
                current = next;
                steps++;
            }
            return sARSData;
        }
    }

    /**
     * Creates a collector that uses the actions returned by {@code domain.getActions()}.
     *
     * @param domain the domain supplying the action list
     */
    public SARSCollector(Domain domain) {
        this.actions = domain.getActions();
    }

    /**
     * Creates a collector that uses the given action list.
     *
     * @param list the actions to select from; the reference is stored, not copied
     */
    public SARSCollector(List<Action> list) {
        this.actions = list;
    }

    /**
     * Collects transitions starting from {@code state} and adds them to {@code sARSData}.
     *
     * @param state            the state collection starts from
     * @param rewardFunction   used to compute the reward of each transition
     * @param i                maximum number of transitions to collect in this call
     * @param terminalFunction used to detect terminal states
     * @param sARSData         dataset to append to; may be {@code null}
     * @return the dataset holding the collected transitions
     */
    public abstract SARSData collectDataFrom(State state, RewardFunction rewardFunction, int i, TerminalFunction terminalFunction, SARSData sARSData);

    /**
     * Repeatedly calls
     * {@link #collectDataFrom(State, RewardFunction, int, TerminalFunction, SARSData)} from
     * freshly generated start states until {@code i} transitions have been added.
     *
     * <p>Each rollout is limited to {@code min(remaining, i2)} transitions, and the remaining
     * count is reduced by the growth of the dataset's size after each rollout. A rollout that
     * adds nothing (for example because the generated start state is terminal) makes no
     * progress, so the loop only ends if {@code stateGenerator} eventually yields states
     * from which transitions can be collected.
     *
     * @param stateGenerator   source of the start state for every rollout
     * @param rewardFunction   used to compute the reward of each transition
     * @param i                total number of transitions to collect; nothing is collected
     *                         when it is not positive
     * @param i2               maximum number of transitions per rollout; must be positive
     *                         whenever {@code i} is positive
     * @param terminalFunction used to detect terminal states
     * @param sARSData         dataset to append to; when {@code null} a new one is created
     *                         with {@code new SARSData(i)}
     * @return the dataset that received the transitions
     * @throws IllegalArgumentException if {@code i > 0} and {@code i2 <= 0}, since no rollout
     *         could ever add a transition and the collection loop would never terminate
     */
    public SARSData collectNInstances(StateGenerator stateGenerator, RewardFunction rewardFunction, int i, int i2, TerminalFunction terminalFunction, SARSData sARSData) {
        // A non-positive per-rollout bound can never make progress; fail fast instead of spinning.
        if (i > 0 && i2 <= 0) {
            throw new IllegalArgumentException(
                    "Maximum steps per rollout must be positive to collect " + i + " instances, but was " + i2);
        }
        if (sARSData == null) {
            sARSData = new SARSData(i);
        }
        int remaining = i;
        while (remaining > 0) {
            int rolloutLimit = Math.min(remaining, i2);
            int sizeBefore = sARSData.size();
            collectDataFrom(stateGenerator.generateState(), rewardFunction, rolloutLimit, terminalFunction, sARSData);
            // Count only what was actually added; a rollout may end early at a terminal state.
            int added = sARSData.size() - sizeBefore;
            remaining -= added;
        }
        return sARSData;
    }
}
