package burlap.behavior.singleagent.learning.tdmethods.vfa;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.EpsilonGreedy;
import burlap.behavior.singleagent.vfa.ActionApproximationResult;
import burlap.behavior.singleagent.vfa.FunctionWeight;
import burlap.behavior.singleagent.vfa.ValueFunctionApproximation;
import burlap.behavior.singleagent.vfa.WeightGradient;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/vfa/GradientDescentSarsaLam.class */
public class GradientDescentSarsaLam extends OOMDPPlanner implements QComputablePlanner, LearningAgent {
    // Function approximator queried for state-action value estimates and weight gradients.
    protected ValueFunctionApproximation vfa;
    // Source of the step size used in every weight update.
    protected LearningRate learningRate;
    // Policy that selects the action to execute at each step of a learning episode.
    protected Policy learningPolicy;
    // Trace-decay factor; eligibilities are multiplied by lambda times the step's discount.
    protected double lambda;
    // Step cap applied when an episode is started without an explicit limit.
    protected int maxEpisodeSize;
    // Number of steps taken in the most recent learning episode (reset at episode start).
    protected int eStepCounter;
    // Maximum number of episodes planFromState will run.
    protected int numEpisodesForPlanning;
    // planFromState stops once the last episode's maximum weight change is not greater than this.
    protected double maxWeightChangeForPlanningTermination;
    // Stored learning episodes.
    protected LinkedList<EpisodeAnalysis> episodeHistory;
    // Capacity bound checked against episodeHistory's size.
    protected int numEpisodesToStore;
    // Largest absolute difference observed in the last episode between a weight's new value and the
    // value it had when its trace was created; reset to 0 when an episode starts.
    protected double maxWeightChangeInLastEpisode = Double.POSITIVE_INFINITY;
    // When true the learning rate is polled per weight id; otherwise once per state-action pair.
    protected boolean useFeatureWiseLearningRate = true;
    // Traces whose eligibility falls below this value are dropped ("Eligibity" spelling kept for compatibility).
    protected double minEligibityForUpdate = 0.01d;
    // When true, traces of unselected actions are cleared and the selected action's traces are reset each step.
    protected boolean useReplacingTraces = false;
    // When true, an executed option's recorded results are merged into the episode instead of one transition.
    protected boolean shouldDecomposeOptions = true;
    // Flag forwarded to each Option via toggleShouldAnnotateResults.
    protected boolean shouldAnnotateOptions = true;
    // Number of learning-loop iterations across all episodes; passed to the learning rate when polled.
    protected int totalNumberOfSteps = 0;

    /* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/vfa/GradientDescentSarsaLam$EligibilityTraceVector.class */
    /**
     * Pairs a function weight with its current eligibility value and remembers the
     * value the weight held at the moment the trace was created.
     */
    public static class EligibilityTraceVector {
        /** The function weight this trace belongs to. */
        public FunctionWeight weight;
        /** Current eligibility of the weight. */
        public double eligibilityValue;
        /** Value of the weight when this trace was constructed. */
        public double initialWeightValue;

        /**
         * Creates a trace for the given weight.
         *
         * @param functionWeight the weight being traced; its current value is snapshotted
         * @param d the initial eligibility value
         */
        public EligibilityTraceVector(FunctionWeight functionWeight, double d) {
            final double snapshot = functionWeight.weightValue();
            this.initialWeightValue = snapshot;
            this.eligibilityValue = d;
            this.weight = functionWeight;
        }
    }

    public GradientDescentSarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, ValueFunctionApproximation valueFunctionApproximation, double d2, double d3) {
        GDSLInit(domain, rewardFunction, terminalFunction, d, valueFunctionApproximation, d2, new EpsilonGreedy(this, 0.1d), Integer.MAX_VALUE, d3);
    }

    public GradientDescentSarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, ValueFunctionApproximation valueFunctionApproximation, double d2, int i, double d3) {
        GDSLInit(domain, rewardFunction, terminalFunction, d, valueFunctionApproximation, d2, new EpsilonGreedy(this, 0.1d), i, d3);
    }

    /**
     * Creates a learner with a caller-supplied learning policy and episode step cap.
     * All arguments are handed unchanged to {@code GDSLInit}.
     *
     * @param domain the domain forwarded to the planner initialisation
     * @param rewardFunction the reward function
     * @param terminalFunction the terminal state function
     * @param d value forwarded to {@code plannerInit} (presumably the discount factor - confirm)
     * @param valueFunctionApproximation the value function approximator to train
     * @param d2 the constant learning rate
     * @param policy the policy followed while learning
     * @param i the maximum number of steps per learning episode
     * @param d3 the lambda (trace decay) value
     */
    public GradientDescentSarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, ValueFunctionApproximation valueFunctionApproximation, double d2, Policy policy, int i, double d3) {
        this.GDSLInit(
                domain,
                rewardFunction,
                terminalFunction,
                d,
                valueFunctionApproximation,
                d2,
                policy,
                i,
                d3);
    }

    /**
     * Shared initialisation used by every constructor: initialises the planner base and
     * then sets this learner's configuration fields.
     *
     * @param domain the domain forwarded to {@code plannerInit}
     * @param rewardFunction the reward function forwarded to {@code plannerInit}
     * @param terminalFunction the terminal function forwarded to {@code plannerInit}
     * @param d value forwarded to {@code plannerInit} (presumably the discount factor - confirm)
     * @param valueFunctionApproximation the value function approximator to train
     * @param d2 the learning rate, wrapped in a {@link ConstantLR}
     * @param policy the policy followed while learning
     * @param i the maximum number of steps per learning episode
     * @param d3 the lambda (trace decay) value
     */
    protected void GDSLInit(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, ValueFunctionApproximation valueFunctionApproximation, double d2, Policy policy, int i, double d3) {
        plannerInit(domain, rewardFunction, terminalFunction, d, null);

        // Learning components.
        this.vfa = valueFunctionApproximation;
        this.learningRate = new ConstantLR(Double.valueOf(d2));
        this.learningPolicy = policy;

        // Algorithm parameters.
        this.lambda = d3;
        this.maxEpisodeSize = i;

        // Planning defaults: a single episode and a zero weight-change threshold.
        this.numEpisodesForPlanning = 1;
        this.maxWeightChangeForPlanningTermination = 0.0d;

        // Episode bookkeeping: capacity of one stored episode.
        this.episodeHistory = new LinkedList<>();
        this.numEpisodesToStore = 1;
    }

    /**
     * Replaces the learning rate function polled during weight updates.
     *
     * @param learningRate the learning rate function to use from now on
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Chooses how the learning rate is polled during updates.
     *
     * @param z {@code true} to poll the learning rate per weight id; {@code false} to poll it
     *          once per step for the current state-action pair
     */
    public void setUseFeatureWiseLearningRate(boolean z) {
        this.useFeatureWiseLearningRate = z;
    }

    /**
     * Replaces the policy used to select actions during learning episodes.
     *
     * @param policy the policy to follow while learning
     */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /**
     * Sets the maximum number of episodes {@code planFromState} will run.
     * Non-positive values are replaced by 1.
     *
     * @param i the requested maximum number of planning episodes
     */
    public void setMaximumEpisodesForPlanning(int i) {
        // At least one episode is always run.
        this.numEpisodesForPlanning = (i > 0) ? i : 1;
    }

    /**
     * Sets the weight-change threshold used by {@code planFromState} as a stopping test.
     * Any value that is not strictly positive is stored as 0.
     *
     * @param d the requested threshold
     */
    public void setMaxVFAWeightChangeForPlanningTerminaiton(double d) {
        // A ternary on (d > 0) keeps the exact result for NaN and negative zero: both become 0.0.
        this.maxWeightChangeForPlanningTermination = (d > 0.0d) ? d : 0.0d;
    }

    /**
     * Returns the step counter of the most recent learning episode.
     *
     * @return the current value of the episode step counter
     */
    public int getLastNumSteps() {
        return this.eStepCounter;
    }

    /**
     * Enables or disables the replacing-traces step of the learning update.
     *
     * @param z {@code true} to use replacing traces
     */
    public void setUseReplaceTraces(boolean z) {
        this.useReplacingTraces = z;
    }

    /**
     * Records whether executed options should be decomposed in the returned episode and
     * forwards the flag to every {@link Option} among this planner's actions.
     *
     * @param z {@code true} to decompose options
     */
    public void toggleShouldDecomposeOption(boolean z) {
        this.shouldDecomposeOptions = z;
        for (Action candidate : this.actions) {
            if (!(candidate instanceof Option)) {
                continue;
            }
            Option option = (Option) candidate;
            option.toggleShouldRecordResults(z);
        }
    }

    /**
     * Records whether option decompositions should be annotated and forwards the flag to
     * every {@link Option} among this planner's actions.
     *
     * @param z {@code true} to annotate option results
     */
    public void toggleShouldAnnotateOptionDecomposition(boolean z) {
        this.shouldAnnotateOptions = z;
        for (Action candidate : this.actions) {
            if (!(candidate instanceof Option)) {
                continue;
            }
            Option option = (Option) candidate;
            option.toggleShouldAnnotateResults(z);
        }
    }

    /**
     * Runs one learning episode from the given state, capped at {@code maxEpisodeSize} steps.
     *
     * @param state the state in which the episode starts
     * @return the episode returned by the two-argument overload
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis runLearningEpisodeFrom(State state) {
        return runLearningEpisodeFrom(state, this.maxEpisodeSize);
    }

    /**
     * Runs one gradient-descent SARSA(lambda) learning episode starting in {@code state}.
     *
     * <p>Each iteration executes the current action, selects the next action with the learning
     * policy, computes the TD error, updates every weight that has an active eligibility trace,
     * and creates traces for weights of the current state-action pair that do not have one yet.
     * The loop ends when the current state is terminal or the step counter reaches {@code i}.
     * The finished episode is appended to the episode history, evicting the oldest entry first
     * when the history already holds {@code numEpisodesToStore} episodes.
     *
     * @param state the state in which the episode starts
     * @param i the maximum number of steps to take (option steps count individually)
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis runLearningEpisodeFrom(State state, int i) {
        EpisodeAnalysis episode = new EpisodeAnalysis(state);
        this.maxWeightChangeInLastEpisode = 0.0d;
        this.eStepCounter = 0;

        // Active eligibility traces, keyed by weight id.
        HashMap<Integer, EligibilityTraceVector> traces = new HashMap<>();

        State curState = state;
        GroundedAction action = (GroundedAction) this.learningPolicy.getAction(curState);
        List<ActionApproximationResult> allCurApprox = getAllActionApproximations(curState);
        ActionApproximationResult curApprox = ActionApproximationResult.extractApproximationForAction(allCurApprox, action);

        while (!this.tf.isTerminal(curState) && this.eStepCounter < i) {
            WeightGradient gradient = this.vfa.getWeightGradient(curApprox.approximationResult);

            State nextState = action.executeIn(curState);
            GroundedAction nextAction = (GroundedAction) this.learningPolicy.getAction(nextState);
            List<ActionApproximationResult> allNextApprox = getAllActionApproximations(nextState);
            ActionApproximationResult nextApprox = ActionApproximationResult.extractApproximationForAction(allNextApprox, nextAction);

            // A terminal successor contributes no future value.
            double nextQ = nextApprox.approximationResult.predictedValue;
            if (this.tf.isTerminal(nextState)) {
                nextQ = 0.0d;
            }

            double discount = this.gamma;
            double reward;
            if (action.action.isPrimitive()) {
                reward = this.rf.reward(curState, action, nextState);
                this.eStepCounter++;
                episode.recordTransitionTo(action, nextState, reward);
            } else {
                // Options span several steps: use their cumulative reward and gamma^steps.
                Option option = (Option) action.action;
                reward = option.getLastCumulativeReward();
                int optionSteps = option.getLastNumSteps();
                discount = Math.pow(this.gamma, optionSteps);
                this.eStepCounter += optionSteps;
                if (this.shouldDecomposeOptions) {
                    episode.appendAndMergeEpisodeAnalysis(option.getLastExecutionResults());
                } else {
                    episode.recordTransitionTo(action, nextState, reward);
                }
            }

            // TD error for the transition just taken.
            double delta = (reward + (discount * nextQ)) - curApprox.approximationResult.predictedValue;

            if (this.useReplacingTraces) {
                // Reset the traces of the selected action and drop those of every other action.
                for (ActionApproximationResult aar : allCurApprox) {
                    if (aar.ga.equals(action)) {
                        for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                            EligibilityTraceVector stored = traces.get(fw.weightId());
                            if (stored != null) {
                                stored.eligibilityValue = 0.0d;
                            }
                        }
                    } else {
                        for (FunctionWeight fw : aar.approximationResult.functionWeights) {
                            traces.remove(fw.weightId());
                        }
                    }
                }
            }

            // With a feature-wise rate the value is polled per weight below instead.
            double learningRateValue = this.useFeatureWiseLearningRate
                    ? 0.0d
                    : this.learningRate.pollLearningRate(this.totalNumberOfSteps, curState, action);

            // Update every weight that already has a trace. Decayed traces are only marked here
            // and removed after the new-trace pass, so that pass still sees them as "already
            // updated this step" and does not update the same weight a second time.
            HashSet<Integer> decayedTraceIds = new HashSet<>();
            for (EligibilityTraceVector trace : traces.values()) {
                int weightId = trace.weight.weightId();
                if (this.useFeatureWiseLearningRate) {
                    learningRateValue = this.learningRate.pollLearningRate(this.totalNumberOfSteps, weightId);
                }
                trace.eligibilityValue += gradient.getPartialDerivative(weightId);
                double newWeight = trace.weight.weightValue() + (learningRateValue * delta * trace.eligibilityValue);
                trace.weight.setWeight(newWeight);
                double change = Math.abs(trace.initialWeightValue - newWeight);
                if (change > this.maxWeightChangeInLastEpisode) {
                    this.maxWeightChangeInLastEpisode = change;
                }
                trace.eligibilityValue *= this.lambda * discount;
                // NOTE(review): signed comparison, so negative eligibilities are always pruned - confirm intended.
                if (trace.eligibilityValue < this.minEligibityForUpdate) {
                    decayedTraceIds.add(weightId);
                }
            }

            // Create traces for weights of the current state-action pair that have none yet.
            for (FunctionWeight fw : curApprox.approximationResult.functionWeights) {
                int weightId = fw.weightId();
                // The map is keyed by weight id, so the lookup must use the id, not the weight object.
                if (traces.containsKey(weightId)) {
                    continue;
                }
                if (this.useFeatureWiseLearningRate) {
                    learningRateValue = this.learningRate.pollLearningRate(this.totalNumberOfSteps, weightId);
                }
                EligibilityTraceVector trace = new EligibilityTraceVector(fw, gradient.getPartialDerivative(weightId));
                double newWeight = fw.weightValue() + (learningRateValue * delta * trace.eligibilityValue);
                fw.setWeight(newWeight);
                double change = Math.abs(trace.initialWeightValue - newWeight);
                if (change > this.maxWeightChangeInLastEpisode) {
                    this.maxWeightChangeInLastEpisode = change;
                }
                trace.eligibilityValue *= this.lambda * discount;
                if (trace.eligibilityValue >= this.minEligibityForUpdate) {
                    traces.put(weightId, trace);
                }
            }

            // Now drop the traces that decayed below the threshold.
            for (Integer decayedId : decayedTraceIds) {
                traces.remove(decayedId);
            }

            curState = nextState;
            action = nextAction;
            curApprox = nextApprox;
            allCurApprox = allNextApprox;
            this.totalNumberOfSteps++;
        }

        // Evict the oldest episode only when at capacity; the new episode is always stored.
        if (this.episodeHistory.size() >= this.numEpisodesToStore) {
            this.episodeHistory.poll();
        }
        this.episodeHistory.offer(episode);

        return episode;
    }

    /**
     * Returns the last element of the stored episode history.
     *
     * @return the most recently stored episode
     * @throws java.util.NoSuchElementException if the history is empty, as raised by
     *         {@link LinkedList#getLast()}
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis getLastLearningEpisode() {
        return this.episodeHistory.getLast();
    }

    /**
     * Sets the capacity of the episode history. Non-positive values are replaced by 1.
     *
     * @param i the requested number of episodes to keep
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public void setNumEpisodesToStore(int i) {
        // The history always has room for at least one episode.
        this.numEpisodesToStore = (i > 0) ? i : 1;
    }

    /**
     * Returns the stored episode history.
     *
     * @return the internal history list itself (not a copy), so changes made by the caller
     *         affect this learner
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public List<EpisodeAnalysis> getAllStoredLearningEpisodes() {
        return this.episodeHistory;
    }

    /**
     * Computes a Q-value for every grounded action returned by {@code getAllGroundedActions}
     * for the given state, using a single batch query to the value function approximator.
     *
     * @param state the state to evaluate
     * @return one {@link QValue} per grounded action, in the order the actions were returned
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public List<QValue> getQs(State state) {
        List<GroundedAction> applicable = getAllGroundedActions(state);
        List<QValue> qValues = new ArrayList<QValue>(applicable.size());
        List<ActionApproximationResult> approximations = this.vfa.getStateActionValues(state, applicable);
        for (GroundedAction ga : applicable) {
            qValues.add(getQFromFeaturesFor(approximations, state, ga));
        }
        return qValues;
    }

    /**
     * Computes the Q-value of a single state-action pair by querying the value function
     * approximator with a one-element action list.
     *
     * @param state the state to evaluate
     * @param abstractGroundedAction the action to evaluate; it is cast to {@link GroundedAction}
     * @return the Q-value for the pair
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        List<GroundedAction> single = new ArrayList<GroundedAction>(1);
        single.add((GroundedAction) abstractGroundedAction);
        List<ActionApproximationResult> approximations = this.vfa.getStateActionValues(state, single);
        return getQFromFeaturesFor(approximations, state, (GroundedAction) abstractGroundedAction);
    }

    /**
     * Builds a {@link QValue} for the given action from a list of approximation results.
     *
     * @param list approximation results to search for the action's entry
     * @param state the state the results were computed for
     * @param groundedAction the action whose predicted value is wanted
     * @return a Q-value carrying the predicted value of the matching approximation result
     */
    protected QValue getQFromFeaturesFor(List<ActionApproximationResult> list, State state, GroundedAction groundedAction) {
        ActionApproximationResult match = ActionApproximationResult.extractApproximationForAction(list, groundedAction);
        double predicted = match.approximationResult.predictedValue;
        return new QValue(state, groundedAction, predicted);
    }

    /**
     * Queries the value function approximator for every grounded action returned by
     * {@code getAllGroundedActions} for the given state.
     *
     * @param state the state to evaluate
     * @return the approximation results returned by the approximator
     */
    protected List<ActionApproximationResult> getAllActionApproximations(State state) {
        List<GroundedAction> applicable = getAllGroundedActions(state);
        return this.vfa.getStateActionValues(state, applicable);
    }

    /**
     * Queries the value function approximator for a single state-action pair.
     *
     * @param state the state to evaluate
     * @param groundedAction the action to evaluate
     * @return the approximation result extracted for {@code groundedAction}
     */
    protected ActionApproximationResult getActionApproximation(State state, GroundedAction groundedAction) {
        List<GroundedAction> single = new ArrayList<GroundedAction>(1);
        single.add(groundedAction);
        List<ActionApproximationResult> approximations = this.vfa.getStateActionValues(state, single);
        return ActionApproximationResult.extractApproximationForAction(approximations, groundedAction);
    }

    /**
     * Plans by repeatedly running learning episodes from the given state. At least one episode
     * is always run. Planning stops once {@code numEpisodesForPlanning} episodes have been run,
     * or as soon as the last episode's maximum weight change is not greater than
     * {@code maxWeightChangeForPlanningTermination}.
     *
     * @param state the state from which every planning episode starts
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public void planFromState(State state) {
        int episodesRun = 0;
        while (true) {
            runLearningEpisodeFrom(state);
            episodesRun++;

            // The episode budget is checked before the convergence test.
            if (episodesRun >= this.numEpisodesForPlanning) {
                return;
            }
            // Negated "greater than" keeps the exact comparison semantics, including for NaN.
            if (!(this.maxWeightChangeInLastEpisode > this.maxWeightChangeForPlanningTermination)) {
                return;
            }
        }
    }

    /**
     * Discards learned results: resets the approximator's weights, puts the maximum weight
     * change back to positive infinity, zeroes the episode step counter and clears the
     * stored episode history.
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public void resetPlannerResults() {
        this.vfa.resetWeights();

        // Per-episode statistics go back to their pre-learning values.
        this.maxWeightChangeInLastEpisode = Double.POSITIVE_INFINITY;
        this.eStepCounter = 0;

        this.episodeHistory.clear();
    }
}
