package burlap.behavior.singleagent.learning.actorcritic.critics;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.learning.actorcritic.Critic;
import burlap.behavior.singleagent.learning.actorcritic.CritiqueResult;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.options.OptionEvaluatingRF;
import burlap.behavior.singleagent.planning.ValueFunction;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learning/actorcritic/critics/TDLambda.class */
/**
 * A TD(lambda) critic that learns a tabular state-value function using replacing eligibility
 * traces.
 *
 * <p>For every observed transition (s, a, s') the critic computes the TD error
 * {@code delta = r + discount * V(s') - V(s)}, where {@code V(s') = 0} when s' is terminal. If
 * {@code a} is an {@link Option}, the discount is {@code gamma^k}, with k taken from
 * {@link Option#getLastNumSteps()}. It then moves the value of every state in the current
 * episode's trace list toward that error, in proportion to the state's eligibility.
 */
public class TDLambda implements Critic, ValueFunction {
    /** Reward function evaluated on each observed transition. */
    protected RewardFunction rf;
    /** Terminal function; a terminal next state contributes a value of 0 to the TD error. */
    protected TerminalFunction tf;
    /** Per-step discount factor. */
    protected double gamma;
    /** Factory used to hash states into keys of {@link #vIndex}. */
    protected StateHashFactory hashingFactory;
    /** Learning rate polled once per value update. */
    protected LearningRate learningRate;
    /** Supplies the initial value of a state the first time it is looked up. */
    protected ValueFunctionInitialization vInitFunction;
    /** Trace-decay parameter; each trace is scaled by {@code lambda * discount} after every step. */
    protected double lambda;
    /**
     * Eligibility traces of the current episode. Initialized eagerly so that {@link #endEpisode()},
     * {@link #resetData()} and {@link #critiqueAndUpdate} do not throw before the first
     * {@link #initializeEpisode} call.
     */
    protected LinkedList<StateEligibilityTrace> traces = new LinkedList<>();
    /**
     * Number of {@link #critiqueAndUpdate} calls since construction. It is passed to the learning
     * rate and is not reset by {@link #resetData()}.
     */
    protected int totalNumberOfSteps = 0;
    /** Tabular value function, keyed by hashed state. */
    protected Map<StateHashTuple, VValue> vIndex = new HashMap<>();

    /** An eligibility trace entry: a hashed state, its eligibility, and its shared value cell. */
    public static class StateEligibilityTrace {
        /** Current eligibility of the state. */
        public double eligibility;
        /** The hashed state this trace belongs to. */
        public StateHashTuple sh;
        /** The value cell of the state; the same object stored in {@link TDLambda#vIndex}. */
        public VValue v;

        /**
         * Creates a trace entry.
         *
         * @param sh the hashed state
         * @param eligibility the initial eligibility
         * @param v the value cell of the state
         */
        public StateEligibilityTrace(StateHashTuple sh, double eligibility, VValue v) {
            this.sh = sh;
            this.eligibility = eligibility;
            this.v = v;
        }
    }

    /**
     * Mutable holder for a single state value, so that traces and {@link #vIndex} can share and
     * update the same cell.
     */
    public class VValue {
        /** The stored state value. */
        public double v;

        /**
         * Creates a value cell.
         *
         * @param v the initial value
         */
        public VValue(double v) {
            this.v = v;
        }
    }

    /**
     * Creates a critic with a constant learning rate and a constant initial state value.
     *
     * @param rf reward function
     * @param tf terminal function
     * @param gamma discount factor
     * @param hashingFactory state hashing factory
     * @param learningRate constant learning rate
     * @param vinit initial value assigned to every unseen state
     * @param lambda trace-decay parameter
     */
    public TDLambda(RewardFunction rf, TerminalFunction tf, double gamma, StateHashFactory hashingFactory,
            double learningRate, double vinit, double lambda) {
        this(rf, tf, gamma, hashingFactory, learningRate,
                new ValueFunctionInitialization.ConstantValueFunctionInitialization(vinit), lambda);
    }

    /**
     * Creates a critic with a constant learning rate and a custom value-function initialization.
     *
     * @param rf reward function
     * @param tf terminal function
     * @param gamma discount factor
     * @param hashingFactory state hashing factory
     * @param learningRate constant learning rate
     * @param vinit supplies the initial value of unseen states
     * @param lambda trace-decay parameter
     */
    public TDLambda(RewardFunction rf, TerminalFunction tf, double gamma, StateHashFactory hashingFactory,
            double learningRate, ValueFunctionInitialization vinit, double lambda) {
        this.rf = rf;
        this.tf = tf;
        this.gamma = gamma;
        this.hashingFactory = hashingFactory;
        this.learningRate = new ConstantLR(Double.valueOf(learningRate));
        this.vInitFunction = vinit;
        this.lambda = lambda;
    }

    /**
     * If {@code action} is an {@link Option} and the reward function is not already an
     * {@link OptionEvaluatingRF}, wraps the reward function in one. Otherwise does nothing.
     *
     * @param action the action being added
     */
    @Override
    public void addNonDomainReferencedAction(Action action) {
        if (action instanceof Option && !(this.rf instanceof OptionEvaluatingRF)) {
            this.rf = new OptionEvaluatingRF(this.rf);
        }
    }

    /**
     * Replaces the reward function.
     *
     * @param rf the new reward function
     */
    public void setRewardFunction(RewardFunction rf) {
        this.rf = rf;
    }

    /**
     * Starts a new episode with an empty trace list.
     *
     * @param s the initial state of the episode (unused)
     */
    @Override
    public void initializeEpisode(State s) {
        this.traces = new LinkedList<>();
    }

    /** Ends the current episode by discarding all eligibility traces. */
    @Override
    public void endEpisode() {
        this.traces.clear();
    }

    /**
     * Replaces the learning rate.
     *
     * @param lr the new learning rate
     */
    public void setLearningRate(LearningRate lr) {
        this.learningRate = lr;
    }

    /**
     * Computes the TD error of a transition and updates the values of all traced states.
     *
     * @param s the source state
     * @param ga the action taken
     * @param sprime the resulting state
     * @return the critique holding the transition and its TD error
     */
    @Override
    public CritiqueResult critiqueAndUpdate(State s, GroundedAction ga, State sprime) {
        StateHashTuple sh = this.hashingFactory.hashState(s);
        StateHashTuple shprime = this.hashingFactory.hashState(sprime);

        double r = this.rf.reward(s, ga, sprime);

        // An option covers several primitive steps, so discount by gamma^k for its last run length k.
        double discount = this.gamma;
        if (ga.action instanceof Option) {
            discount = Math.pow(this.gamma, ((Option) ga.action).getLastNumSteps());
        }

        VValue vs = this.getV(sh);
        // A terminal next state has no future value.
        double nextV = this.tf.isTerminal(sprime) ? 0.0 : this.getV(shprime).v;

        double delta = r + discount * nextV - vs.v;

        boolean foundTrace = false;
        for (StateEligibilityTrace t : this.traces) {
            if (t.sh.equals(sh)) {
                // Replacing trace: revisiting a state resets its eligibility to 1.
                foundTrace = true;
                t.eligibility = 1.0;
            }
            double alpha = this.learningRate.pollLearningRate(this.totalNumberOfSteps, t.sh.s, null);
            t.v.v += alpha * delta * t.eligibility;
            t.eligibility = t.eligibility * this.lambda * discount;
        }

        if (!foundTrace) {
            // First visit this episode: update with an implicit eligibility of 1, then store the
            // already-decayed trace for later steps.
            double alpha = this.learningRate.pollLearningRate(this.totalNumberOfSteps, sh.s, null);
            vs.v += alpha * delta;
            this.traces.add(new StateEligibilityTrace(sh, discount * this.lambda, vs));
        }

        CritiqueResult critique = new CritiqueResult(s, ga, sprime, delta);
        this.totalNumberOfSteps++;
        return critique;
    }

    /**
     * Returns the current value estimate of a state. A state not seen before is initialized via
     * {@link #vInitFunction} and inserted into {@link #vIndex} as a side effect.
     *
     * @param s the state to evaluate
     * @return the state's value
     */
    @Override
    public double value(State s) {
        return this.getV(this.hashingFactory.hashState(s)).v;
    }

    /**
     * Clears the value table and the eligibility traces, and resets the learning rate's decay.
     * {@link #totalNumberOfSteps} is left unchanged.
     */
    @Override
    public void resetData() {
        this.vIndex.clear();
        this.traces.clear();
        this.learningRate.resetDecay();
    }

    /**
     * Returns the value cell for a hashed state, creating and storing it with its initial value if
     * absent.
     *
     * @param sh the hashed state
     * @return the state's value cell (never null)
     */
    protected VValue getV(StateHashTuple sh) {
        VValue stored = this.vIndex.get(sh);
        if (stored == null) {
            stored = new VValue(this.vInitFunction.value(sh.s));
            this.vIndex.put(sh, stored);
        }
        return stored;
    }
}
