package burlap.behavior.singleagent.learning.tdmethods;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.EpsilonGreedy;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import javax.management.RuntimeErrorException;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/QLearning.class */
/**
 * Tabular Q-learning agent that can also be driven as a planner.
 *
 * <p>Q-values are kept per hashed state in {@link #qIndex}; a state's entries are created lazily
 * (see {@link #getStateNode(StateHashTuple)}) using {@link #qInitFunction}. Each learning episode
 * follows {@link #learningPolicy} and applies the update
 * {@code Q(s,a) += alpha * (r + discount * max_a' Q(s',a') - Q(s,a))}, where the bootstrap term is
 * zero when {@code s'} is terminal. For non-primitive actions ({@link Option}s) the reward is the
 * option's last cumulative reward and the discount is {@code gamma^k} for the {@code k} steps the
 * option reported.
 *
 * <p>The class keeps mutable learning state and performs no synchronization.
 */
public class QLearning extends OOMDPPlanner implements QComputablePlanner, LearningAgent {
    /** Q-value storage: one node (holding the state's Q entries) per hashed state. */
    protected Map<StateHashTuple, QLearningStateNode> qIndex;
    /** Supplies the initial Q-value for a state-action pair the first time its state is seen. */
    protected ValueFunctionInitialization qInitFunction;
    /** Learning rate schedule polled once per Q update. */
    protected LearningRate learningRate;
    /** Policy used to choose actions while learning. */
    protected Policy learningPolicy;
    /** Default step cap used by {@link #runLearningEpisodeFrom(State)}. */
    protected int maxEpisodeSize;
    /** Steps taken in the most recent episode; an option contributes all the steps it reported. */
    protected int eStepCounter;
    /** Maximum number of episodes {@link #planFromState(State)} will run. */
    protected int numEpisodesForPlanning;
    /** Planning stops once an episode's largest Q change is not above this threshold. */
    protected double maxQChangeForPlanningTermination;
    /** Most recent learning episodes, oldest first. */
    protected LinkedList<EpisodeAnalysis> episodeHistory;
    /** Cap on the number of episodes retained in {@link #episodeHistory}. */
    protected int numEpisodesToStore;
    /** Largest absolute change made to any Q-value during the most recent episode. */
    protected double maxQChangeInLastEpisode = Double.POSITIVE_INFINITY;
    /** When true, an executed option's recorded results are merged step-by-step into the episode. */
    protected boolean shouldDecomposeOptions = true;
    /** Annotation setting pushed to every option at the start of each episode. */
    protected boolean shouldAnnotateOptions = true;
    /** Number of Q updates performed so far; passed to the learning rate as its time index. */
    protected int totalNumberOfSteps = 0;

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate, a 0.1
     * epsilon-greedy learning policy and an effectively unbounded episode length.
     *
     * @param domain           domain forwarded to {@code plannerInit}
     * @param rewardFunction   reward function forwarded to {@code plannerInit}
     * @param terminalFunction terminal function forwarded to {@code plannerInit}
     * @param d                forwarded to {@code plannerInit}; presumably the discount factor gamma
     * @param stateHashFactory hashing factory forwarded to {@code plannerInit}
     * @param d2               constant initial Q-value
     * @param d3               constant learning rate
     */
    public QLearning(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3) {
        this(domain, rewardFunction, terminalFunction, d, stateHashFactory, d2, d3, Integer.MAX_VALUE);
    }

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate and a 0.1
     * epsilon-greedy learning policy.
     *
     * @param domain           domain forwarded to {@code plannerInit}
     * @param rewardFunction   reward function forwarded to {@code plannerInit}
     * @param terminalFunction terminal function forwarded to {@code plannerInit}
     * @param d                forwarded to {@code plannerInit}; presumably the discount factor gamma
     * @param stateHashFactory hashing factory forwarded to {@code plannerInit}
     * @param d2               constant initial Q-value
     * @param d3               constant learning rate
     * @param i                default maximum number of steps per episode
     */
    public QLearning(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, int i) {
        // Cannot delegate via this(...): the default policy needs a reference to this instance.
        QLInit(domain, rewardFunction, terminalFunction, d, stateHashFactory, new ValueFunctionInitialization.ConstantValueFunctionInitialization(d2), d3, new EpsilonGreedy(this, 0.1d), i);
    }

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate and a
     * caller-supplied learning policy.
     *
     * @param domain           domain forwarded to {@code plannerInit}
     * @param rewardFunction   reward function forwarded to {@code plannerInit}
     * @param terminalFunction terminal function forwarded to {@code plannerInit}
     * @param d                forwarded to {@code plannerInit}; presumably the discount factor gamma
     * @param stateHashFactory hashing factory forwarded to {@code plannerInit}
     * @param d2               constant initial Q-value
     * @param d3               constant learning rate
     * @param policy           policy followed while learning
     * @param i                default maximum number of steps per episode
     */
    public QLearning(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, double d2, double d3, Policy policy, int i) {
        this(domain, rewardFunction, terminalFunction, d, stateHashFactory, new ValueFunctionInitialization.ConstantValueFunctionInitialization(d2), d3, policy, i);
    }

    /**
     * Creates an agent with a caller-supplied Q-value initializer and learning policy and a
     * constant learning rate.
     *
     * @param domain                      domain forwarded to {@code plannerInit}
     * @param rewardFunction              reward function forwarded to {@code plannerInit}
     * @param terminalFunction            terminal function forwarded to {@code plannerInit}
     * @param d                           forwarded to {@code plannerInit}; presumably the discount factor gamma
     * @param stateHashFactory            hashing factory forwarded to {@code plannerInit}
     * @param valueFunctionInitialization source of initial Q-values
     * @param d2                          constant learning rate
     * @param policy                      policy followed while learning
     * @param i                           default maximum number of steps per episode
     */
    public QLearning(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, ValueFunctionInitialization valueFunctionInitialization, double d2, Policy policy, int i) {
        QLInit(domain, rewardFunction, terminalFunction, d, stateHashFactory, valueFunctionInitialization, d2, policy, i);
    }

    /**
     * Shared initialization used by the constructors: initializes the planner base state and sets
     * this class's defaults (store one episode, plan with one episode, zero Q-change threshold).
     *
     * @param domain                      domain forwarded to {@code plannerInit}
     * @param rewardFunction              reward function forwarded to {@code plannerInit}
     * @param terminalFunction            terminal function forwarded to {@code plannerInit}
     * @param d                           forwarded to {@code plannerInit}; presumably the discount factor gamma
     * @param stateHashFactory            hashing factory forwarded to {@code plannerInit}
     * @param valueFunctionInitialization source of initial Q-values
     * @param d2                          constant learning rate
     * @param policy                      policy followed while learning
     * @param i                           default maximum number of steps per episode
     */
    protected void QLInit(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, ValueFunctionInitialization valueFunctionInitialization, double d2, Policy policy, int i) {
        plannerInit(domain, rewardFunction, terminalFunction, d, stateHashFactory);
        this.qIndex = new HashMap<>();
        this.learningRate = new ConstantLR(Double.valueOf(d2));
        this.learningPolicy = policy;
        this.maxEpisodeSize = i;
        this.qInitFunction = valueFunctionInitialization;
        this.numEpisodesToStore = 1;
        this.episodeHistory = new LinkedList<>();
        this.numEpisodesForPlanning = 1;
        this.maxQChangeForPlanningTermination = 0.0d;
    }

    /**
     * Replaces the learning rate schedule.
     *
     * @param learningRate schedule polled on every Q update
     */
    public void setLearningRateFunction(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Replaces the Q-value initializer. Only states whose node has not been created yet are
     * affected, because initial values are assigned when a state node is first built.
     *
     * @param valueFunctionInitialization source of initial Q-values
     */
    public void setQInitFunction(ValueFunctionInitialization valueFunctionInitialization) {
        this.qInitFunction = valueFunctionInitialization;
    }

    /**
     * Replaces the policy followed during learning episodes.
     *
     * @param policy the new learning policy
     */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /**
     * Sets how many episodes {@link #planFromState(State)} may run at most.
     *
     * @param i requested maximum; any value below 1 is replaced by 1
     */
    public void setMaximumEpisodesForPlanning(int i) {
        this.numEpisodesForPlanning = i > 0 ? i : 1;
    }

    /**
     * Sets the Q-change threshold at which {@link #planFromState(State)} stops early.
     * (The misspelled method name is part of the existing public interface.)
     *
     * @param d requested threshold; any value that is not strictly positive is replaced by 0
     */
    public void setMaxQChangeForPlanningTerminaiton(double d) {
        this.maxQChangeForPlanningTermination = d > 0.0d ? d : 0.0d;
    }

    /**
     * @return the number of steps counted in the most recent learning episode
     */
    public int getLastNumSteps() {
        return this.eStepCounter;
    }

    /**
     * Chooses whether executed options are expanded into their recorded steps in the episode
     * record, and tells every option among this planner's actions to record its results accordingly.
     *
     * @param z true to decompose options, false to record each option as a single transition
     */
    public void toggleShouldDecomposeOption(boolean z) {
        this.shouldDecomposeOptions = z;
        for (Action candidate : this.actions) {
            if (candidate instanceof Option) {
                ((Option) candidate).toggleShouldRecordResults(z);
            }
        }
    }

    /**
     * Chooses whether options annotate their recorded results, and pushes that setting to every
     * option among this planner's actions.
     *
     * @param z the annotation setting to apply
     */
    public void toggleShouldAnnotateOptionDecomposition(boolean z) {
        this.shouldAnnotateOptions = z;
        for (Action candidate : this.actions) {
            if (candidate instanceof Option) {
                ((Option) candidate).toggleShouldAnnotateResults(z);
            }
        }
    }

    /**
     * Returns the Q entries of the given state, creating them if the state is new.
     *
     * @param state the state to query
     * @return the live list of Q-values stored for the state
     */
    @Override
    public List<QValue> getQs(State state) {
        return getQs(stateHash(state));
    }

    /**
     * Returns the Q entry for one state-action pair.
     *
     * @param state                 the state to query
     * @param abstractGroundedAction the action; must be a {@link GroundedAction}
     * @return the stored Q-value, or null when the state's node has no entry equal to the action
     */
    @Override
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        return getQ(stateHash(state), (GroundedAction) abstractGroundedAction);
    }

    /**
     * Returns the Q entries of the given hashed state, creating them if the state is new.
     *
     * @param stateHashTuple the hashed state to query
     * @return the live list of Q-values stored for the state
     */
    protected List<QValue> getQs(StateHashTuple stateHashTuple) {
        return getStateNode(stateHashTuple).qEntry;
    }

    /**
     * Looks up the Q entry for a hashed state and grounded action.
     *
     * <p>When the action has object parameters and the domain is not object-identifier dependent,
     * the action is first translated using the object matching between the queried state and the
     * state stored in the node, so that it can be compared against the stored entries.
     *
     * @param stateHashTuple the hashed state to query
     * @param groundedAction the action to look up
     * @return the stored Q-value, or null when no stored entry equals the (translated) action
     */
    /* JADX INFO: Access modifiers changed from: protected */
    public QValue getQ(StateHashTuple stateHashTuple, GroundedAction groundedAction) {
        QLearningStateNode node = getStateNode(stateHashTuple);
        GroundedAction lookup = groundedAction;
        if (lookup.params.length > 0 && !this.domain.isObjectIdentifierDependent() && lookup.parametersAreObjects()) {
            lookup = translateAction(lookup, stateHashTuple.s.getObjectMatchingTo(node.s.s, false));
        }
        for (QValue entry : node.qEntry) {
            if (entry.a.equals(lookup)) {
                return entry;
            }
        }
        return null;
    }

    /**
     * Returns the node holding the Q entries of a hashed state, building and indexing it on first
     * access with one entry per grounded action, each initialized from {@link #qInitFunction}.
     *
     * @param stateHashTuple the hashed state to look up
     * @return the existing or newly created node
     * @throws RuntimeErrorException if a new state has no grounded actions at all
     */
    protected QLearningStateNode getStateNode(StateHashTuple stateHashTuple) {
        QLearningStateNode node = this.qIndex.get(stateHashTuple);
        if (node == null) {
            node = new QLearningStateNode(stateHashTuple);
            List<GroundedAction> applicable = getAllGroundedActions(stateHashTuple.s);
            if (applicable.isEmpty()) {
                throw new RuntimeErrorException(new Error("No possible actions in this state, cannot continue Q-learning"));
            }
            for (GroundedAction candidate : applicable) {
                node.addQValue(candidate, this.qInitFunction.qValue(stateHashTuple.s, candidate));
            }
            // Index the node only once it is fully populated.
            this.qIndex.put(stateHashTuple, node);
        }
        return node;
    }

    /**
     * Returns the largest Q-value stored for a hashed state.
     *
     * @param stateHashTuple the hashed state to query
     * @return the maximum Q-value, or negative infinity if no entry exceeds it
     */
    protected double getMaxQ(StateHashTuple stateHashTuple) {
        double best = Double.NEGATIVE_INFINITY;
        for (QValue entry : getQs(stateHashTuple)) {
            if (entry.q > best) {
                best = entry.q;
            }
        }
        return best;
    }

    /**
     * Plans by running learning episodes from the given state. At least one episode is always run;
     * further episodes run until {@link #numEpisodesForPlanning} is reached or the last episode's
     * largest Q change is no longer above {@link #maxQChangeForPlanningTermination}.
     *
     * @param state the state every planning episode starts from
     */
    @Override
    public void planFromState(State state) {
        int episodesRun = 0;
        do {
            runLearningEpisodeFrom(state);
            episodesRun++;
        } while (episodesRun < this.numEpisodesForPlanning
                && this.maxQChangeInLastEpisode > this.maxQChangeForPlanningTermination);
    }

    /**
     * Runs one learning episode capped at {@link #maxEpisodeSize} steps.
     *
     * @param state the state the episode starts from
     * @return the record of the episode
     */
    @Override
    public EpisodeAnalysis runLearningEpisodeFrom(State state) {
        return runLearningEpisodeFrom(state, this.maxEpisodeSize);
    }

    /**
     * Runs one learning episode, updating Q-values after every action, and stores the episode in
     * the history.
     *
     * <p>The episode ends when a terminal state is reached or the step counter reaches the cap. An
     * option may push the counter past the cap, since all of its reported steps are added at once.
     *
     * @param state the state the episode starts from
     * @param i     maximum number of steps for this episode
     * @return the record of the episode
     */
    @Override
    public EpisodeAnalysis runLearningEpisodeFrom(State state, int i) {
        // Re-apply the annotation setting so options see the current value on every episode.
        toggleShouldAnnotateOptionDecomposition(this.shouldAnnotateOptions);
        EpisodeAnalysis episode = new EpisodeAnalysis(state);
        StateHashTuple current = stateHash(state);
        this.eStepCounter = 0;
        this.maxQChangeInLastEpisode = 0.0d;
        while (!this.tf.isTerminal(current.s) && this.eStepCounter < i) {
            GroundedAction chosen = (GroundedAction) this.learningPolicy.getAction(current.s);
            QValue currentQ = getQ(current, chosen);
            StateHashTuple next = stateHash(chosen.executeIn(current.s));

            // Terminal successors contribute no future value.
            double nextMaxQ = 0.0d;
            if (!this.tf.isTerminal(next.s)) {
                nextMaxQ = getMaxQ(next);
            }

            double discount = this.gamma;
            double reward;
            if (chosen.action.isPrimitive()) {
                reward = this.rf.reward(current.s, chosen, next.s);
                this.eStepCounter++;
                episode.recordTransitionTo(chosen, next.s, reward);
            } else {
                // Options report their own accumulated reward and duration; discount by gamma^k.
                Option option = (Option) chosen.action;
                reward = option.getLastCumulativeReward();
                int optionSteps = option.getLastNumSteps();
                discount = Math.pow(this.gamma, optionSteps);
                this.eStepCounter += optionSteps;
                if (this.shouldDecomposeOptions) {
                    episode.appendAndMergeEpisodeAnalysis(option.getLastExecutionResults());
                } else {
                    episode.recordTransitionTo(chosen, next.s, reward);
                }
            }

            double oldQ = currentQ.q;
            double alpha = this.learningRate.pollLearningRate(this.totalNumberOfSteps, current.s, chosen);
            currentQ.q = oldQ + alpha * ((reward + (discount * nextMaxQ)) - oldQ);

            double change = Math.abs(oldQ - currentQ.q);
            if (change > this.maxQChangeInLastEpisode) {
                this.maxQChangeInLastEpisode = change;
            }

            current = next;
            this.totalNumberOfSteps++;
        }

        // Evict oldest episodes until there is room for the new one. A loop (rather than a single
        // poll) is needed so the history also shrinks after the storage cap has been lowered.
        while (!this.episodeHistory.isEmpty() && this.episodeHistory.size() >= this.numEpisodesToStore) {
            this.episodeHistory.poll();
        }
        this.episodeHistory.offer(episode);
        return episode;
    }

    /**
     * Returns the most recently stored learning episode.
     *
     * @return the last episode in the history
     * @throws java.util.NoSuchElementException if no episode is currently stored
     */
    @Override
    public EpisodeAnalysis getLastLearningEpisode() {
        return this.episodeHistory.getLast();
    }

    /**
     * Sets how many of the most recent episodes are retained. Already stored episodes are not
     * discarded by this call; the cap is enforced when the next episode is stored.
     *
     * @param i requested cap; any value below 1 is replaced by 1
     */
    @Override
    public void setNumEpisodesToStore(int i) {
        this.numEpisodesToStore = i > 0 ? i : 1;
    }

    /**
     * Returns the stored learning episodes, oldest first.
     *
     * @return the internal history list itself (not a copy)
     */
    @Override
    public List<EpisodeAnalysis> getAllStoredLearningEpisodes() {
        return this.episodeHistory;
    }

    /**
     * Discards everything learned so far: the state index, all Q-values and the episode history,
     * and resets the per-episode step counter and Q-change tracker. The running count of Q updates
     * ({@link #totalNumberOfSteps}) is left untouched.
     */
    @Override
    public void resetPlannerResults() {
        this.mapToStateIndex.clear();
        this.qIndex.clear();
        this.episodeHistory.clear();
        this.eStepCounter = 0;
        this.maxQChangeInLastEpisode = Double.POSITIVE_INFINITY;
    }
}
