package burlap.behavior.singleagent.learning.lspi;

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.learning.lspi.SARSCollector;
import burlap.behavior.singleagent.learning.lspi.SARSData;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.EpsilonGreedy;
import burlap.behavior.singleagent.planning.commonpolicies.GreedyQPolicy;
import burlap.behavior.singleagent.vfa.ActionApproximationResult;
import burlap.behavior.singleagent.vfa.ActionFeaturesQuery;
import burlap.behavior.singleagent.vfa.FeatureDatabase;
import burlap.behavior.singleagent.vfa.StateFeature;
import burlap.behavior.singleagent.vfa.ValueFunctionApproximation;
import burlap.behavior.singleagent.vfa.common.LinearVFA;
import burlap.debugtools.DPrint;
import burlap.oomdp.auxiliary.common.ConstantStateGenerator;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import org.ejml.simple.SimpleMatrix;

/**
 * Least-squares policy iteration (LSPI) with a linear value function approximator.
 *
 * <p>The class can be used in two ways:
 * <ul>
 *   <li>as a planner ({@link #planFromState(State)}): samples are gathered with a
 *       {@link SARSCollector} and policy iteration is run on them;</li>
 *   <li>as a learning agent ({@link #runLearningEpisodeFrom(State)}): episodes are
 *       generated with the learning policy, appended to the dataset, and policy
 *       iteration is re-run once enough new steps have been observed.</li>
 * </ul>
 *
 * <p>Q-values are served from the current {@link ValueFunctionApproximation}, which
 * {@link #LSTDQ()} replaces with a fresh {@link LinearVFA} on every call.
 */
public class LSPI extends OOMDPPlanner implements QComputablePlanner, LearningAgent {
    /** Approximator used to answer Q-value queries; rebuilt by {@link #LSTDQ()}. */
    protected ValueFunctionApproximation vfa;
    /** SARS samples that {@link #LSTDQ()} fits against; {@code null} until set, planned or learned. */
    protected SARSData dataset;
    /** Source of state-action features. */
    protected FeatureDatabase featureDatabase;
    /** Weights returned by the most recent {@link #LSTDQ()} call made from {@link #runPolicyIteration(int, double)}. */
    protected SimpleMatrix lastWeights;
    /** Collector used by {@link #planFromState(State)}; a uniform random collector is created lazily when {@code null}. */
    protected SARSCollector planningCollector;
    /** Policy followed while running learning episodes. */
    protected Policy learningPolicy;
    /** Upper bound on the number of learning episodes kept in {@link #episodeHistory}. */
    protected int numEpisodesToStore;
    /** Scalar applied to the identity matrix that initialises the running matrix in {@link #LSTDQ()}. */
    protected double identityScalar = 100.0d;
    /** Number of SARS samples requested from the collector when planning. */
    protected int numSamplesForPlanning = 10000;
    /** Policy iteration stops once the Frobenius norm of the weight change is at most this value. */
    protected double maxChange = 1.0E-6d;
    /** Maximum number of policy iteration passes per run. */
    protected int maxNumPlanningIterations = 30;
    /** Step limit used by {@link #runLearningEpisodeFrom(State)}. */
    protected int maxLearningSteps = Integer.MAX_VALUE;
    /** Learning steps accumulated since policy iteration was last run from a learning episode. */
    protected int numStepsSinceLastLearningPI = 0;
    /** Threshold of new learning steps that must be exceeded before policy iteration is re-run. */
    protected int minNewStepsForLearningPI = 100;
    /** Most recent learning episodes, oldest first. */
    protected LinkedList<EpisodeAnalysis> episodeHistory = new LinkedList<>();

    /**
     * Pair of feature queries for one SARS sample: the features of the taken
     * action in the source state and the features of the chosen action in the
     * successor state.
     *
     * <p>NOTE: the decompiler reported that the access modifier of this class was
     * changed from {@code protected}; it is kept {@code public} here to match the
     * decompiled interface.
     */
    public class SSFeatures {
        /** Features for the (s, a) pair of the sample. */
        public List<ActionFeaturesQuery> sActionFeatures;
        /** Features for the (s', a') pair of the sample. */
        public List<ActionFeaturesQuery> sPrimeActionFeatures;

        /**
         * @param list  features for the source state and its action
         * @param list2 features for the successor state and its action
         */
        public SSFeatures(List<ActionFeaturesQuery> list, List<ActionFeaturesQuery> list2) {
            this.sActionFeatures = list;
            this.sPrimeActionFeatures = list2;
        }
    }

    /**
     * Initialises the planner state, builds a {@link LinearVFA} over the given
     * feature database, and installs an {@link EpsilonGreedy} learning policy
     * constructed with the value 0.1.
     *
     * @param domain           the domain to plan/learn in
     * @param rewardFunction   the reward function
     * @param terminalFunction the terminal function
     * @param d                value forwarded to {@code plannerInit}; presumably the discount factor
     *                         later read as {@code gamma} (confirm against OOMDPPlanner)
     * @param featureDatabase  the source of state-action features
     */
    public LSPI(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, FeatureDatabase featureDatabase) {
        plannerInit(domain, rewardFunction, terminalFunction, d, null);
        this.featureDatabase = featureDatabase;
        this.vfa = new LinearVFA(this.featureDatabase);
        this.learningPolicy = new EpsilonGreedy(this, 0.1d);
    }

    /** @param sARSData the SARS dataset that subsequent {@link #LSTDQ()} calls will use */
    public void setDataset(SARSData sARSData) {
        this.dataset = sARSData;
    }

    /** @return the current SARS dataset, or {@code null} if none has been set or collected */
    public SARSData getDataset() {
        return this.dataset;
    }

    /** @return the feature database in use */
    public FeatureDatabase getFeatureDatabase() {
        return this.featureDatabase;
    }

    /**
     * Replaces the feature database. The current approximator is not rebuilt here;
     * {@link #LSTDQ()} creates a new {@link LinearVFA} from the database on its next run.
     *
     * @param featureDatabase the new feature database
     */
    public void setFeatureDatabase(FeatureDatabase featureDatabase) {
        this.featureDatabase = featureDatabase;
    }

    /** @return the scalar applied to the identity matrix in {@link #LSTDQ()} */
    public double getIdentityScalar() {
        return this.identityScalar;
    }

    /** @param d the scalar applied to the identity matrix in {@link #LSTDQ()} */
    public void setIdentityScalar(double d) {
        this.identityScalar = d;
    }

    /** @return the number of samples requested from the collector when planning */
    public int getNumSamplesForPlanning() {
        return this.numSamplesForPlanning;
    }

    /** @param i the number of samples to request from the collector when planning */
    public void setNumSamplesForPlanning(int i) {
        this.numSamplesForPlanning = i;
    }

    /** @return the collector used for planning, or {@code null} if none has been set or created yet */
    public SARSCollector getPlanningCollector() {
        return this.planningCollector;
    }

    /** @param sARSCollector the collector that {@link #planFromState(State)} should use */
    public void setPlanningCollector(SARSCollector sARSCollector) {
        this.planningCollector = sARSCollector;
    }

    /** @return the maximum number of policy iteration passes per run */
    public int getMaxNumPlanningIterations() {
        return this.maxNumPlanningIterations;
    }

    /** @param i the maximum number of policy iteration passes per run */
    public void setMaxNumPlanningIterations(int i) {
        this.maxNumPlanningIterations = i;
    }

    /** @return the policy followed during learning episodes */
    public Policy getLearningPolicy() {
        return this.learningPolicy;
    }

    /** @param policy the policy to follow during learning episodes */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /** @return the step limit used by {@link #runLearningEpisodeFrom(State)} */
    public int getMaxLearningSteps() {
        return this.maxLearningSteps;
    }

    /** @param i the step limit used by {@link #runLearningEpisodeFrom(State)} */
    public void setMaxLearningSteps(int i) {
        this.maxLearningSteps = i;
    }

    /** @return the number of new learning steps that must be exceeded before policy iteration is re-run */
    public int getMinNewStepsForLearningPI() {
        return this.minNewStepsForLearningPI;
    }

    /** @param i the number of new learning steps that must be exceeded before policy iteration is re-run */
    public void setMinNewStepsForLearningPI(int i) {
        this.minNewStepsForLearningPI = i;
    }

    /** @return the weight-change threshold at which policy iteration stops */
    public double getMaxChange() {
        return this.maxChange;
    }

    /** @param d the weight-change threshold at which policy iteration stops */
    public void setMaxChange(double d) {
        this.maxChange = d;
    }

    /**
     * Runs one LSTDQ pass over the current dataset and installs the resulting
     * weights in a newly created {@link LinearVFA}.
     *
     * <p>For every sample the successor action is the one chosen by a
     * {@link GreedyQPolicy} over this object's current Q-values. All features are
     * gathered before the approximator is replaced, so the greedy choices reflect
     * the weights that were in place when the method was entered.
     *
     * <p>The running matrix starts as {@code identityScalar * I} and is updated per
     * sample with the rank-one correction
     * {@code B <- B - B*phi*(phi - gamma*phi')^T*B / (1 + (phi - gamma*phi')^T*B*phi)};
     * the vector accumulates {@code phi * r}. The returned weights are {@code B * b}.
     *
     * <p>The dataset must be non-null when this method is called.
     *
     * @return the new weight vector as a column matrix with one row per feature
     */
    public SimpleMatrix LSTDQ() {
        GreedyQPolicy greedyPolicy = new GreedyQPolicy(this);

        // Gather the features of every sample up front, while the old weights are still installed.
        List<SSFeatures> sampleFeatures = new ArrayList<>(this.dataset.size());
        for (SARSData.SARS sars : this.dataset.dataset) {
            sampleFeatures.add(new SSFeatures(
                    this.featureDatabase.getActionFeaturesSets(sars.s, gaListWrapper(sars.a)),
                    this.featureDatabase.getActionFeaturesSets(sars.sp, gaListWrapper(greedyPolicy.getAction(sars.sp)))));
        }

        int numFeatures = this.featureDatabase.numberOfFeatures();
        SimpleMatrix bMatrix = SimpleMatrix.identity(numFeatures).scale(this.identityScalar);
        SimpleMatrix bVector = new SimpleMatrix(numFeatures, 1);

        for (int i = 0; i < sampleFeatures.size(); i++) {
            SSFeatures features = sampleFeatures.get(i);
            SimpleMatrix phi = phiConstructor(features.sActionFeatures, numFeatures);
            SimpleMatrix phiPrime = phiConstructor(features.sPrimeActionFeatures, numFeatures);
            double reward = this.dataset.get(i).r;

            // (phi - gamma * phi')^T is needed for both numerator and denominator; compute it once.
            SimpleMatrix tdFeaturesT = phi.minus(phiPrime.scale(this.gamma)).transpose();
            SimpleMatrix numerator = bMatrix.mult(phi).mult(tdFeaturesT).mult(bMatrix);
            double denominator = tdFeaturesT.mult(bMatrix).mult(phi).get(0) + 1.0d;

            bMatrix = bMatrix.minus(numerator.scale(1.0d / denominator));
            bVector = bVector.plus(phi.scale(reward));
        }

        SimpleMatrix weights = bMatrix.mult(bVector);

        // Install the solution in a fresh approximator.
        this.vfa = new LinearVFA(this.featureDatabase);
        for (int f = 0; f < numFeatures; f++) {
            this.vfa.setWeight(f, weights.get(f, 0));
        }
        return weights;
    }

    /**
     * Repeats {@link #LSTDQ()} until the weights stop changing or the iteration
     * limit is reached. The change is the Frobenius norm of the difference between
     * the previous weights ({@link #lastWeights}) and the new ones; when no previous
     * weights exist it is reported as positive infinity.
     *
     * @param i the maximum number of LSTDQ passes to run
     * @param d passes stop once the weight change is at most this value
     */
    public void runPolicyIteration(int i, double d) {
        boolean converged = false;
        for (int iteration = 0; iteration < i && !converged; iteration++) {
            SimpleMatrix newWeights = LSTDQ();
            double change = Double.POSITIVE_INFINITY;
            if (this.lastWeights != null) {
                change = this.lastWeights.minus(newWeights).normF();
                if (change <= d) {
                    converged = true;
                }
            }
            this.lastWeights = newWeights;
            DPrint.cl(0, "Finished iteration: " + iteration + ". Weight change: " + change);
        }
        DPrint.cl(0, "Finished Policy Iteration.");
    }

    /**
     * Builds a column feature vector from a feature query list that must hold
     * exactly one action's features. Entries not named by a feature keep the value
     * a newly constructed matrix has.
     *
     * @param list the feature queries; must contain exactly one element
     * @param i    the number of rows of the resulting vector
     * @return an {@code i x 1} matrix with each feature's value stored at the feature's id
     * @throws RuntimeException if {@code list} does not contain exactly one element
     */
    protected SimpleMatrix phiConstructor(List<ActionFeaturesQuery> list, int i) {
        SimpleMatrix phi = new SimpleMatrix(i, 1);
        if (list.size() != 1) {
            throw new RuntimeException("Expected only one actions's set of features.");
        }
        ActionFeaturesQuery query = list.get(0);
        for (StateFeature feature : query.features) {
            phi.set(feature.id, feature.value);
        }
        return phi;
    }

    /**
     * Wraps a single action in a new one-element list.
     *
     * @param abstractGroundedAction the action to wrap; it is cast to {@link GroundedAction}
     * @return a new mutable list containing only the given action
     */
    protected List<GroundedAction> gaListWrapper(AbstractGroundedAction abstractGroundedAction) {
        List<GroundedAction> wrapped = new ArrayList<>(1);
        wrapped.add((GroundedAction) abstractGroundedAction);
        return wrapped;
    }

    /**
     * Returns the approximated Q-value of every grounded action available in the state.
     *
     * @param state the state to query
     * @return one {@link QValue} per grounded action, in the order the actions are enumerated
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public List<QValue> getQs(State state) {
        List<GroundedAction> groundedActions = getAllGroundedActions(state);
        List<ActionApproximationResult> approximations = this.vfa.getStateActionValues(state, groundedActions);
        List<QValue> qValues = new ArrayList<>(groundedActions.size());
        for (GroundedAction groundedAction : groundedActions) {
            qValues.add(getQFromFeaturesFor(approximations, state, groundedAction));
        }
        return qValues;
    }

    /**
     * Returns the approximated Q-value of one state-action pair.
     *
     * @param state                  the state to query
     * @param abstractGroundedAction the action to query; it is cast to {@link GroundedAction}
     * @return the approximated Q-value
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        List<GroundedAction> singleAction = gaListWrapper(abstractGroundedAction);
        List<ActionApproximationResult> approximations = this.vfa.getStateActionValues(state, singleAction);
        return getQFromFeaturesFor(approximations, state, (GroundedAction) abstractGroundedAction);
    }

    /**
     * Builds a {@link QValue} from the approximation result that belongs to the given action.
     *
     * @param list           approximation results to search
     * @param state          the state the results were computed for
     * @param groundedAction the action whose result is wanted
     * @return a Q-value carrying the predicted value for the action
     */
    protected QValue getQFromFeaturesFor(List<ActionApproximationResult> list, State state, GroundedAction groundedAction) {
        ActionApproximationResult result = ActionApproximationResult.extractApproximationForAction(list, groundedAction);
        return new QValue(state, groundedAction, result.approximationResult.predictedValue);
    }

    /**
     * Collects {@link #numSamplesForPlanning} samples starting from the given state
     * and runs policy iteration on them. If no collector has been set, a uniform
     * random collector over this planner's actions is created and kept. The current
     * dataset is passed to the collector and replaced by what the collector returns.
     *
     * @param state the state from which sample collection starts
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public void planFromState(State state) {
        if (this.planningCollector == null) {
            this.planningCollector = new SARSCollector.UniformRandomSARSCollector(this.actions);
        }
        // Integer.MAX_VALUE is presumably the per-episode step cap -- confirm against SARSCollector.
        this.dataset = this.planningCollector.collectNInstances(
                new ConstantStateGenerator(state),
                this.rf,
                this.numSamplesForPlanning,
                Integer.MAX_VALUE,
                this.tf,
                this.dataset);
        runPolicyIteration(this.maxNumPlanningIterations, this.maxChange);
    }

    /**
     * Discards the collected samples, resets the approximator's weights and
     * forgets the weights of the last policy iteration pass, so that the next run
     * does not compare against pre-reset weights. Safe to call before any dataset
     * has been created.
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public void resetPlannerResults() {
        // The dataset stays null until it is set, planned or learned; nothing to clear in that case.
        if (this.dataset != null) {
            this.dataset.clear();
        }
        this.vfa.resetWeights();
        this.lastWeights = null;
    }

    /**
     * Runs a learning episode limited to {@link #maxLearningSteps} steps.
     *
     * @param state the initial state of the episode
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis runLearningEpisodeFrom(State state) {
        return runLearningEpisodeFrom(state, this.maxLearningSteps);
    }

    /**
     * Runs one learning episode with the learning policy, appends its transitions
     * to the dataset, re-runs policy iteration when
     * {@link #shouldRereunPolicyIteration(EpisodeAnalysis)} says so, and records the
     * episode in the history.
     *
     * <p>The history is trimmed before the new episode is added so that it holds at
     * most {@link #numEpisodesToStore} episodes afterwards (and always at least the
     * newest one).
     *
     * @param state the initial state of the episode
     * @param i     the step limit passed to the learning policy
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis runLearningEpisodeFrom(State state, int i) {
        EpisodeAnalysis episode = this.learningPolicy.evaluateBehavior(state, this.rf, this.tf, i);
        updateDatasetWithLearningEpisode(episode);

        if (shouldRereunPolicyIteration(episode)) {
            runPolicyIteration(this.maxNumPlanningIterations, this.maxChange);
            this.numStepsSinceLastLearningPI = 0;
        } else {
            this.numStepsSinceLastLearningPI += episode.numTimeSteps() - 1;
        }

        // Loop rather than a single poll so the history also shrinks when the limit was lowered.
        while (!this.episodeHistory.isEmpty() && this.episodeHistory.size() >= this.numEpisodesToStore) {
            this.episodeHistory.poll();
        }
        this.episodeHistory.offer(episode);
        return episode;
    }

    /**
     * Appends every transition of the episode to the dataset, creating the dataset
     * first if none exists. Transition {@code t} is recorded as state {@code t},
     * action {@code t}, reward {@code t + 1} and state {@code t + 1}.
     *
     * @param episodeAnalysis the episode whose transitions are added
     */
    protected void updateDatasetWithLearningEpisode(EpisodeAnalysis episodeAnalysis) {
        int numTransitions = episodeAnalysis.numTimeSteps() - 1;
        if (this.dataset == null) {
            // The constructor argument is presumably an initial capacity -- confirm against SARSData.
            this.dataset = new SARSData(numTransitions);
        }
        for (int t = 0; t < episodeAnalysis.numTimeSteps() - 1; t++) {
            this.dataset.add(
                    episodeAnalysis.getState(t),
                    episodeAnalysis.getAction(t),
                    episodeAnalysis.getReward(t + 1),
                    episodeAnalysis.getState(t + 1));
        }
    }

    /**
     * Decides whether policy iteration should be re-run after a learning episode.
     *
     * @param episodeAnalysis the episode that was just run
     * @return {@code true} if the steps accumulated since the last run plus the
     *         transitions of this episode strictly exceed {@link #minNewStepsForLearningPI}
     */
    protected boolean shouldRereunPolicyIteration(EpisodeAnalysis episodeAnalysis) {
        int newSteps = this.numStepsSinceLastLearningPI + episodeAnalysis.numTimeSteps() - 1;
        return newSteps > this.minNewStepsForLearningPI;
    }

    /**
     * @return the most recently recorded learning episode
     * @throws java.util.NoSuchElementException if no learning episode has been recorded
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis getLastLearningEpisode() {
        return this.episodeHistory.getLast();
    }

    /** @param i the maximum number of learning episodes to keep in the history */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public void setNumEpisodesToStore(int i) {
        this.numEpisodesToStore = i;
    }

    /** @return the live list of stored learning episodes, oldest first (not a copy) */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public List<EpisodeAnalysis> getAllStoredLearningEpisodes() {
        return this.episodeHistory;
    }
}
