package burlap.behavior.singleagent.learning.actorcritic.critics;

import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.learning.actorcritic.CritiqueResult;
import burlap.behavior.singleagent.learning.actorcritic.critics.TDLambda;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learning/actorcritic/critics/TimeIndexedTDLambda.class */
/**
 * A {@link TDLambda} critic that keeps a separate value table for every time index of an
 * episode, so the value entry of a state is looked up by (state hash, time index) rather
 * than by state hash alone.
 *
 * <p>The time index is reset to 0 by {@link #initializeEpisode(State)} and advanced by
 * {@link #critiqueAndUpdate(State, GroundedAction, State)}: by 1 for ordinary actions and by
 * {@link Option#getLastNumSteps()} when the executed action is an {@link Option}.
 */
public class TimeIndexedTDLambda extends TDLambda {

    /** Value tables by time index: element {@code t} maps a hashed state to its value entry at time {@code t}. */
    protected List<Map<StateHashTuple, TDLambda.VValue>> vTIndex;

    /** Time index that the next call to {@code critiqueAndUpdate} will treat as the current time. */
    protected int curTime;

    /**
     * Episode-length bound used when bootstrapping: the next state's value is only used while
     * {@code curTime < maxEpisodeSize - 1}. Defaults to {@link Integer#MAX_VALUE}.
     */
    protected int maxEpisodeSize = Integer.MAX_VALUE;

    /**
     * Eligibility trace that additionally records the time index at which it was created.
     */
    public static class StateTimeElibilityTrace extends TDLambda.StateEligibilityTrace {

        /** Time index associated with this trace. */
        public int timeIndex;

        /**
         * Creates a trace for a state at a given time index.
         *
         * @param stateHashTuple hashed state, forwarded to the superclass constructor
         * @param i time index stored in {@link #timeIndex}
         * @param d forwarded to the superclass constructor (used by this file as the initial eligibility)
         * @param vValue value entry forwarded to the superclass constructor
         */
        public StateTimeElibilityTrace(StateHashTuple stateHashTuple, int i, double d, TDLambda.VValue vValue) {
            super(stateHashTuple, d, vValue);
            this.timeIndex = i;
        }
    }

    /**
     * Creates a critic with the default episode-length bound of {@link Integer#MAX_VALUE}.
     * All arguments are forwarded unchanged, in order, to the matching {@link TDLambda} constructor.
     *
     * @param rewardFunction forwarded to {@code TDLambda}
     * @param terminalFunction forwarded to {@code TDLambda}
     * @param d forwarded to {@code TDLambda}
     * @param stateHashFactory forwarded to {@code TDLambda}
     * @param d2 forwarded to {@code TDLambda}
     * @param d3 forwarded to {@code TDLambda}
     * @param d4 forwarded to {@code TDLambda}
     */
    public TimeIndexedTDLambda(RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, double d4) {
        super(rewardFunction, terminalFunction, d, stateHashFactory, d2, d3, d4);
        this.vTIndex = new ArrayList<Map<StateHashTuple, TDLambda.VValue>>();
    }

    /**
     * Creates a critic with an explicit episode-length bound.
     * All arguments except {@code i} are forwarded unchanged, in order, to the matching
     * {@link TDLambda} constructor.
     *
     * @param rewardFunction forwarded to {@code TDLambda}
     * @param terminalFunction forwarded to {@code TDLambda}
     * @param d forwarded to {@code TDLambda}
     * @param stateHashFactory forwarded to {@code TDLambda}
     * @param d2 forwarded to {@code TDLambda}
     * @param d3 forwarded to {@code TDLambda}
     * @param d4 forwarded to {@code TDLambda}
     * @param i value stored in {@link #maxEpisodeSize}
     */
    public TimeIndexedTDLambda(RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, double d4, int i) {
        super(rewardFunction, terminalFunction, d, stateHashFactory, d2, d3, d4);
        this.maxEpisodeSize = i;
        this.vTIndex = new ArrayList<Map<StateHashTuple, TDLambda.VValue>>();
    }

    /**
     * Creates a critic with an explicit value-function initialization and episode-length bound.
     * All arguments except {@code i} are forwarded unchanged, in order, to the matching
     * {@link TDLambda} constructor.
     *
     * @param rewardFunction forwarded to {@code TDLambda}
     * @param terminalFunction forwarded to {@code TDLambda}
     * @param d forwarded to {@code TDLambda}
     * @param stateHashFactory forwarded to {@code TDLambda}
     * @param d2 forwarded to {@code TDLambda}
     * @param valueFunctionInitialization forwarded to {@code TDLambda}
     * @param d3 forwarded to {@code TDLambda}
     * @param i value stored in {@link #maxEpisodeSize}
     */
    public TimeIndexedTDLambda(RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, ValueFunctionInitialization valueFunctionInitialization, double d3, int i) {
        super(rewardFunction, terminalFunction, d, stateHashFactory, d2, valueFunctionInitialization, d3);
        this.maxEpisodeSize = i;
        this.vTIndex = new ArrayList<Map<StateHashTuple, TDLambda.VValue>>();
    }

    /**
     * Returns the current time index.
     *
     * @return the value of {@link #curTime}
     */
    public int getCurTime() {
        return this.curTime;
    }

    /**
     * Overrides the current time index. The value is not validated here.
     *
     * @param i the new value of {@link #curTime}
     */
    public void setCurTime(int i) {
        this.curTime = i;
    }

    /**
     * Delegates to the superclass and then resets the time index to 0.
     *
     * @param state initial state of the episode, forwarded to the superclass
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.critics.TDLambda, burlap.behavior.singleagent.learning.actorcritic.Critic
    public void initializeEpisode(State state) {
        super.initializeEpisode(state);
        this.curTime = 0;
    }

    /**
     * Delegates to the superclass; no time-indexed state is modified here.
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.critics.TDLambda, burlap.behavior.singleagent.learning.actorcritic.Critic
    public void endEpisode() {
        super.endEpisode();
    }

    /**
     * Computes the TD error for one transition using the time-indexed value tables, applies it
     * to every existing eligibility trace and to the value entry of {@code state} at the current
     * time index, records a new trace for that entry, and advances the time index.
     *
     * @param state state in which the action was taken
     * @param groundedAction action that was executed; if its {@code action} is an {@link Option},
     *        the discount and time advance use {@link Option#getLastNumSteps()}
     * @param state2 resulting state
     * @return a {@link CritiqueResult} built from the three arguments and the TD error
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.critics.TDLambda, burlap.behavior.singleagent.learning.actorcritic.Critic
    public CritiqueResult critiqueAndUpdate(State state, GroundedAction groundedAction, State state2) {
        StateHashTuple sh = this.hashingFactory.hashState(state);
        StateHashTuple shPrime = this.hashingFactory.hashState(state2);
        double reward = this.rf.reward(state, groundedAction, state2);

        // Ordinary actions take one step; options take as many steps as their last execution
        // and are discounted accordingly.
        double discount = this.gamma;
        int numSteps = 1;
        if (groundedAction.action instanceof Option) {
            numSteps = ((Option) groundedAction.action).getLastNumSteps();
            discount = Math.pow(this.gamma, numSteps);
        }

        TDLambda.VValue vs = getV(sh, this.curTime);

        // The next state's value stays 0 when it is terminal or the episode bound is reached.
        // NOTE(review): the bound is checked against curTime only, not curTime + numSteps;
        // confirm this is intended for multi-step options.
        double nextV = 0.0d;
        if (!this.tf.isTerminal(state2) && this.curTime < this.maxEpisodeSize - 1) {
            nextV = getV(shPrime, this.curTime + numSteps).v;
        }

        double delta = (reward + (discount * nextV)) - vs.v;

        // Apply the TD error to all existing traces, then decay each trace.
        for (TDLambda.StateEligibilityTrace trace : this.traces) {
            trace.v.v += this.learningRate.pollLearningRate(this.totalNumberOfSteps, trace.sh.s, null) * delta * trace.eligibility;
            trace.eligibility = trace.eligibility * this.lambda * discount;
        }

        // Update the current (state, time) entry and add its trace, already decayed once.
        vs.v += this.learningRate.pollLearningRate(this.totalNumberOfSteps, sh.s, null) * delta;
        this.traces.add(new StateTimeElibilityTrace(sh, this.curTime, discount * this.lambda, vs));

        this.curTime += numSteps;
        CritiqueResult critique = new CritiqueResult(state, groundedAction, state2, delta);
        this.totalNumberOfSteps++;
        return critique;
    }

    /**
     * Returns the value entry for a hashed state at a time index, creating it (and any missing
     * per-time tables up to and including {@code i}) on first access. New entries are
     * initialized from {@code vInitFunction}.
     *
     * @param stateHashTuple hashed state to look up
     * @param i time index of the table to consult
     * @return the existing or newly created value entry
     */
    protected TDLambda.VValue getV(StateHashTuple stateHashTuple, int i) {
        // Grow the list so that index i is valid.
        while (this.vTIndex.size() <= i) {
            this.vTIndex.add(new HashMap<StateHashTuple, TDLambda.VValue>());
        }
        Map<StateHashTuple, TDLambda.VValue> timeTable = this.vTIndex.get(i);
        TDLambda.VValue entry = timeTable.get(stateHashTuple);
        if (entry == null) {
            entry = new TDLambda.VValue(this.vInitFunction.value(stateHashTuple.s));
            timeTable.put(stateHashTuple, entry);
        }
        return entry;
    }

    /**
     * Delegates to the superclass and then discards all time-indexed value tables.
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.critics.TDLambda, burlap.behavior.singleagent.learning.actorcritic.Critic
    public void resetData() {
        super.resetData();
        this.vTIndex.clear();
    }
}
