package burlap.behavior.singleagent.learning.tdmethods;

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.Iterator;
import java.util.LinkedList;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/SarsaLam.class */
/**
 * Temporal-difference learner that extends {@link QLearning} with eligibility traces.
 *
 * <p>Each step of an episode computes a TD error from the Q-value of the action the
 * learning policy actually selects in the next state, then applies that error to every
 * state-action pair visited so far in the episode, weighted by the pair's current
 * eligibility. Eligibilities decay by {@code lambda * discount} after every step.
 */
public class SarsaLam extends QLearning {
    /** Trace decay factor applied (together with the step discount) to every eligibility after each update. */
    protected double lambda;

    /* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/SarsaLam$EligibilityTrace.class */
    /**
     * Bookkeeping record for one state-action pair visited during the current episode.
     */
    public static class EligibilityTrace {
        /** Current weight with which TD errors are applied to {@link #q}. */
        public double eligibility;
        /** Hashed state this trace belongs to. */
        public StateHashTuple sh;
        /** The Q-value entry that is updated through this trace. */
        public QValue q;
        /** Snapshot of {@code q.q} taken when the trace was constructed. */
        public double initialQ;

        /**
         * Creates a trace and records the Q-value's numeric value at construction time.
         *
         * @param stateHashTuple hashed state of the traced pair
         * @param qValue         Q-value entry of the traced pair; its current {@code q} becomes {@link #initialQ}
         * @param d              initial eligibility
         */
        public EligibilityTrace(StateHashTuple stateHashTuple, QValue qValue, double d) {
            this.sh = stateHashTuple;
            this.q = qValue;
            this.eligibility = d;
            this.initialQ = qValue.q;
        }
    }

    /**
     * Forwards every argument except the last to the matching {@link QLearning} constructor.
     *
     * @param d4 the lambda value stored in {@link #lambda}
     */
    public SarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, double d4) {
        super(domain, rewardFunction, terminalFunction, d, stateHashFactory, d2, d3);
        sarsalamInit(d4);
    }

    /**
     * Forwards every argument except the last to the matching {@link QLearning} constructor.
     *
     * @param d4 the lambda value stored in {@link #lambda}
     */
    public SarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, int i, double d4) {
        super(domain, rewardFunction, terminalFunction, d, stateHashFactory, d2, d3, i);
        sarsalamInit(d4);
    }

    /**
     * Forwards every argument except the last to the matching {@link QLearning} constructor.
     *
     * @param d4 the lambda value stored in {@link #lambda}
     */
    public SarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, Policy policy, int i, double d4) {
        super(domain, rewardFunction, terminalFunction, d, stateHashFactory, d2, d3, policy, i);
        sarsalamInit(d4);
    }

    /**
     * Forwards every argument except the last to the matching {@link QLearning} constructor.
     *
     * @param d3 the lambda value stored in {@link #lambda}
     */
    public SarsaLam(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, ValueFunctionInitialization valueFunctionInitialization, double d2, Policy policy, int i, double d3) {
        super(domain, rewardFunction, terminalFunction, d, stateHashFactory, valueFunctionInitialization, d2, policy, i);
        sarsalamInit(d3);
    }

    /**
     * Stores the SarsaLam-specific configuration shared by all constructors.
     *
     * @param d the lambda value to store in {@link #lambda}
     */
    protected void sarsalamInit(double d) {
        this.lambda = d;
    }

    /**
     * Runs one learning episode, updating Q-values through eligibility traces.
     *
     * <p>The episode ends when the terminal function reports a terminal state or when the
     * step counter reaches {@code i}. Steps taken inside an option count as the number of
     * steps the option reports, so the counter may overshoot {@code i} on the last action.
     * The finished episode is appended to the episode history, evicting the oldest entry
     * when the history is already at {@code numEpisodesToStore}.
     *
     * @param state the state the episode starts from
     * @param i     upper bound on the step counter for this episode
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.tdmethods.QLearning, burlap.behavior.singleagent.learning.LearningAgent
    public EpisodeAnalysis runLearningEpisodeFrom(State state, int i) {
        EpisodeAnalysis episode = new EpisodeAnalysis(state);
        this.maxQChangeInLastEpisode = 0.0d;
        StateHashTuple curState = stateHash(state);
        this.eStepCounter = 0;
        LinkedList<EligibilityTrace> traces = new LinkedList<EligibilityTrace>();
        GroundedAction curAction = (GroundedAction) this.learningPolicy.getAction(curState.s);
        QValue curQ = getQ(curState, curAction);

        while (!this.tf.isTerminal(curState.s) && this.eStepCounter < i) {
            StateHashTuple nextState = stateHash(curAction.executeIn(curState.s));
            GroundedAction nextAction = (GroundedAction) this.learningPolicy.getAction(nextState.s);
            QValue nextQ = getQ(nextState, nextAction);
            // A terminal successor contributes no future value to the TD target.
            double nextQValue = this.tf.isTerminal(nextState.s) ? 0.0d : nextQ.q;

            double reward;
            double discount = this.gamma;
            if (curAction.action.isPrimitive()) {
                reward = this.rf.reward(curState.s, curAction, nextState.s);
                this.eStepCounter++;
                episode.recordTransitionTo(curAction, nextState.s, reward);
            } else {
                // Options report their own accumulated reward and duration; discount over the whole duration.
                Option option = (Option) curAction.action;
                reward = option.getLastCumulativeReward();
                int optionSteps = option.getLastNumSteps();
                discount = Math.pow(this.gamma, optionSteps);
                this.eStepCounter += optionSteps;
                if (this.shouldDecomposeOptions) {
                    episode.appendAndMergeEpisodeAnalysis(option.getLastExecutionResults());
                } else {
                    episode.recordTransitionTo(curAction, nextState.s, reward);
                }
            }

            double delta = reward + discount * nextQValue - curQ.q;

            boolean foundCurrentTrace = false;
            for (EligibilityTrace trace : traces) {
                if (trace.sh.equals(curState)) {
                    // Revisited state: the taken action's trace is reset to 1,
                    // traces for the other actions in this state are cleared.
                    if (trace.q.a.equals(curAction)) {
                        foundCurrentTrace = true;
                        trace.eligibility = 1.0d;
                    } else {
                        trace.eligibility = 0.0d;
                    }
                }
                double rate = this.learningRate.pollLearningRate(this.totalNumberOfSteps, trace.sh.s, trace.q.a);
                trace.q.q += rate * trace.eligibility * delta;
                trace.eligibility = trace.eligibility * this.lambda * discount;
                trackQChange(trace);
            }

            if (!foundCurrentTrace) {
                double rate = this.learningRate.pollLearningRate(this.totalNumberOfSteps, curQ.s, curQ.a);
                // Build the trace BEFORE touching curQ.q: the constructor snapshots the Q-value into
                // initialQ, and that snapshot must be the pre-update value or this first update would
                // never be reflected in maxQChangeInLastEpisode.
                EligibilityTrace newTrace = new EligibilityTrace(curState, curQ, this.lambda * discount);
                curQ.q += rate * delta;
                traces.add(newTrace);
                trackQChange(newTrace);
            }

            curState = nextState;
            curAction = nextAction;
            curQ = nextQ;
            this.totalNumberOfSteps++;
        }

        if (this.episodeHistory.size() >= this.numEpisodesToStore) {
            this.episodeHistory.poll();
        }
        this.episodeHistory.offer(episode);
        return episode;
    }

    /** Raises {@code maxQChangeInLastEpisode} if the trace's Q-value has drifted further from its snapshot. */
    private void trackQChange(EligibilityTrace trace) {
        double change = Math.abs(trace.initialQ - trace.q.q);
        if (change > this.maxQChangeInLastEpisode) {
            this.maxQChangeInLastEpisode = change;
        }
    }
}
