package burlap.behavior.singleagent.planning.stochastic.rtdp;

import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.planning.ValueFunctionPlanner;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.debugtools.DPrint;
import burlap.debugtools.RandomFactory;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/planning/stochastic/rtdp/BoundedRTDP.class */
/**
 * Rollout-based planner that keeps two value tables, a lower bound and an upper bound, and
 * repeatedly backs both up along sampled trajectories until the difference ("gap") between
 * them at the start state is no larger than {@link #maxDiff} or the rollout budget is spent.
 *
 * <p>The inherited {@code valueFunction}/{@code valueInitializer} fields are re-pointed at
 * whichever bound is currently "active" (see {@link #setValueFunctionToLowerBound()} and
 * {@link #setValueFunctionToUpperBound()}); the inherited value/Q-value queries then operate on
 * that bound. Because of this shared mutable state the class is not safe for concurrent use.
 */
public class BoundedRTDP extends ValueFunctionPlanner {
    /** Initializer installed as {@code valueInitializer} while the lower bound is active. */
    protected ValueFunctionInitialization lowerVInit;
    /** Initializer installed as {@code valueInitializer} while the upper bound is active. */
    protected ValueFunctionInitialization upperVInit;
    /** Rollout budget used by {@link #planFromState(State)}; {@code -1} means unlimited. */
    protected int maxRollouts;
    /** Gap threshold: planning and individual rollouts stop once the gap drops to/below it. */
    protected double maxDiff;
    /** Stored lower-bound values, keyed by hashed state. */
    protected Map<StateHashTuple, Double> lowerBoundV = new HashMap<StateHashTuple, Double>();
    /** Stored upper-bound values, keyed by hashed state. */
    protected Map<StateHashTuple, Double> upperBoundV = new HashMap<StateHashTuple, Double>();
    /** Maximum rollout depth; {@code -1} means unlimited. */
    protected int maxDepth = -1;
    /** Whether the lower bound is the currently active value function. */
    protected boolean currentValueFunctionIsLower = false;
    /** If true, the lower bound is left active when a rollout finishes; otherwise the upper bound. */
    protected boolean defaultToLowerValueAfterPlanning = true;
    /** Strategy used to choose the successor state during a rollout. */
    protected StateSelectionMode selectionMode = StateSelectionMode.MODELBASED;
    /** Running count of backups performed (each visited state contributes two: lower and upper). */
    protected int numBellmanUpdates = 0;
    /** Running count of forward rollout steps across all rollouts. */
    protected int numSteps = 0;
    /** If true, the states visited in a rollout are backed up a second time in reverse order. */
    protected boolean runRolloutsInReverse = true;

    /**
     * Pair of a selected successor state and the expected gap over all successors of the
     * state/action it was selected for.
     */
    public static class StateSelectionAndExpectedGap {
        /** The selected (hashed) successor state. */
        public StateHashTuple sh;
        /** The transition-probability-weighted sum of the successors' gaps. */
        public double expectedGap;

        /**
         * @param stateHashTuple the selected successor state
         * @param d the expected gap associated with the selection
         */
        public StateSelectionAndExpectedGap(StateHashTuple stateHashTuple, double d) {
            this.sh = stateHashTuple;
            this.expectedGap = d;
        }
    }

    /** How the next state of a rollout is chosen; see {@link #getNextState(State, GroundedAction)}. */
    public enum StateSelectionMode {
        /** Execute the action in the state and use whatever state results. */
        MODELBASED,
        /** Sample a successor with probability proportional to (transition probability * gap). */
        WEIGHTEDMARGIN,
        /** Pick a successor with the largest gap, breaking ties at random. */
        MAXMARGIN
    }

    /**
     * Initializes the planner.
     *
     * @param domain forwarded to {@code VFPInit}
     * @param rewardFunction forwarded to {@code VFPInit}
     * @param terminalFunction forwarded to {@code VFPInit}
     * @param d forwarded to {@code VFPInit} (presumably the discount factor; confirm against
     *          {@code ValueFunctionPlanner})
     * @param stateHashFactory forwarded to {@code VFPInit}
     * @param valueFunctionInitialization initializer for the lower bound
     * @param valueFunctionInitialization2 initializer for the upper bound
     * @param d2 gap threshold ({@link #maxDiff})
     * @param i maximum number of rollouts ({@link #maxRollouts}); {@code -1} for unlimited
     */
    public BoundedRTDP(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory, ValueFunctionInitialization valueFunctionInitialization, ValueFunctionInitialization valueFunctionInitialization2, double d2, int i) {
        VFPInit(domain, rewardFunction, terminalFunction, d, stateHashFactory);
        this.lowerVInit = valueFunctionInitialization;
        this.upperVInit = valueFunctionInitialization2;
        this.maxDiff = d2;
        // Previously this argument was silently dropped, leaving maxRollouts at 0.
        this.maxRollouts = i;
        this.useCachedTransitions = false;
    }

    /**
     * Sets the rollout budget for {@link #planFromState(State)}.
     *
     * @param i maximum number of rollouts; {@code -1} for unlimited
     */
    public void setMaxNumberOfRollouts(int i) {
        this.maxRollouts = i;
    }

    /**
     * Sets the maximum depth of a single rollout.
     *
     * @param i maximum depth; {@code -1} for unlimited
     */
    public void setMaxRolloutDepth(int i) {
        this.maxDepth = i;
    }

    /**
     * Sets the gap threshold used as the stopping criterion.
     *
     * @param d the new threshold
     */
    public void setMaxDifference(double d) {
        this.maxDiff = d;
    }

    /**
     * Sets how successor states are chosen during rollouts.
     *
     * @param stateSelectionMode the selection strategy
     */
    public void setStateSelectionMode(StateSelectionMode stateSelectionMode) {
        this.selectionMode = stateSelectionMode;
    }

    /**
     * Chooses which bound is left active when a rollout finishes.
     *
     * @param z {@code true} to leave the lower bound active, {@code false} for the upper bound
     */
    public void setDefaultValueFunctionAfterARollout(boolean z) {
        this.defaultToLowerValueAfterPlanning = z;
    }

    /**
     * Enables or disables the reverse backup pass at the end of each rollout. The method name
     * is misspelled but retained for compatibility; {@link #setRunRolloutsInReverse(boolean)}
     * is the equivalent, correctly spelled setter.
     *
     * @param z {@code true} to run the reverse pass
     */
    public void setRunRolloutsInRevere(boolean z) {
        this.runRolloutsInReverse = z;
    }

    /**
     * Enables or disables the reverse backup pass at the end of each rollout.
     *
     * @param z {@code true} to run the reverse pass
     */
    public void setRunRolloutsInReverse(boolean z) {
        setRunRolloutsInRevere(z);
    }

    /**
     * Runs rollouts from {@code state} until a rollout reports a gap of at most
     * {@link #maxDiff} or the rollout budget is exhausted.
     *
     * <p>The rollout is executed before the count is checked, so when the gap never closes up
     * to {@code maxRollouts + 1} rollouts are run (and at least one always runs).
     *
     * @param state the state to plan from
     */
    @Override // burlap.behavior.singleagent.planning.ValueFunctionPlanner, burlap.behavior.singleagent.planning.OOMDPPlanner
    public void planFromState(State state) {
        DPrint.cl(this.debugCode, "Beginning Planning.");
        int rolloutsCounted = 0;
        while (runRollout(state) > this.maxDiff && (rolloutsCounted < this.maxRollouts || this.maxRollouts == -1)) {
            rolloutsCounted++;
        }
        DPrint.cl(this.debugCode, "Finished planning with a total of " + this.numBellmanUpdates + " backups.");
    }

    /** Makes the upper bound (table and initializer) the active value function. */
    public void setValueFunctionToUpperBound() {
        this.valueFunction = this.upperBoundV;
        this.valueInitializer = this.upperVInit;
        this.currentValueFunctionIsLower = false;
    }

    /** Makes the lower bound (table and initializer) the active value function. */
    public void setValueFunctionToLowerBound() {
        this.valueFunction = this.lowerBoundV;
        this.valueInitializer = this.lowerVInit;
        this.currentValueFunctionIsLower = true;
    }

    /**
     * @return the total number of backups performed so far (lower and upper counted separately)
     */
    public int getNumberOfBellmanUpdates() {
        return this.numBellmanUpdates;
    }

    /**
     * @return the total number of forward rollout steps taken so far
     */
    public int getNumberOfSteps() {
        return this.numSteps;
    }

    /**
     * Performs one rollout from {@code state}.
     *
     * <p>At each step both bounds of the current state are backed up with the maximum Q-value
     * under the respective bound, the action that is greedy under the upper bound is used to
     * pick the next state, and the rollout stops at a terminal state, at the depth limit, or
     * when the expected gap of the successors falls below {@link #maxDiff}. A terminal end
     * state gets both bounds set to 0. If reverse rollouts are enabled, the visited states are
     * then backed up again from the most recently visited back to the start state.
     *
     * @param state the start state of the rollout
     * @return with reverse rollouts enabled, the upper-minus-lower difference computed for the
     *         last state of the reverse pass (the start state), or 0 if no state was visited;
     *         otherwise the current gap of the start state
     */
    public double runRollout(State state) {
        LinkedList<StateHashTuple> visited = new LinkedList<StateHashTuple>();
        StateHashTuple current = this.hashingFactory.hashState(state);
        // Depth is counted explicitly: the visited list is only filled in reverse mode, so
        // using its size as the depth would disable the depth limit for forward-only rollouts.
        int depth = 0;
        while (!this.tf.isTerminal(current.s) && (depth <= this.maxDepth || this.maxDepth == -1)) {
            if (this.runRolloutsInReverse) {
                // Newest state goes to the head so that pop() later yields reverse visit order.
                visited.offerFirst(current);
            }
            setValueFunctionToLowerBound();
            this.lowerBoundV.put(current, Double.valueOf(maxQ(current.s).q));
            setValueFunctionToUpperBound();
            QValue upperGreedy = maxQ(current.s);
            this.upperBoundV.put(current, Double.valueOf(upperGreedy.q));
            this.numBellmanUpdates += 2;
            this.numSteps++;
            depth++;
            StateSelectionAndExpectedGap selection = getNextState(current.s, (GroundedAction) upperGreedy.a);
            current = selection.sh;
            if (selection.expectedGap < this.maxDiff) {
                break;
            }
        }
        if (this.tf.isTerminal(current.s)) {
            this.lowerBoundV.put(current, Double.valueOf(0.0d));
            this.upperBoundV.put(current, Double.valueOf(0.0d));
        }
        double startGap = 0.0d;
        if (this.runRolloutsInReverse) {
            while (!visited.isEmpty()) {
                StateHashTuple sh = visited.pop();
                setValueFunctionToLowerBound();
                QValue lowerMax = maxQ(sh.s);
                this.lowerBoundV.put(sh, Double.valueOf(lowerMax.q));
                setValueFunctionToUpperBound();
                QValue upperMax = maxQ(sh.s);
                this.upperBoundV.put(sh, Double.valueOf(upperMax.q));
                this.numBellmanUpdates += 2;
                // Overwritten every iteration; the final value belongs to the start state.
                startGap = upperMax.q - lowerMax.q;
            }
        } else {
            startGap = getGap(this.hashingFactory.hashState(state));
        }
        if (this.defaultToLowerValueAfterPlanning) {
            setValueFunctionToLowerBound();
        } else {
            setValueFunctionToUpperBound();
        }
        return startGap;
    }

    /**
     * Chooses the next rollout state according to {@link #selectionMode}.
     *
     * @param state the current state
     * @param groundedAction the action taken in {@code state}
     * @return the selected successor and the gap estimate that goes with it (for
     *         {@code MODELBASED} this is the gap of the selected state itself)
     * @throws RuntimeException if the selection mode is not recognized
     */
    protected StateSelectionAndExpectedGap getNextState(State state, GroundedAction groundedAction) {
        if (this.selectionMode == StateSelectionMode.MODELBASED) {
            StateHashTuple next = this.hashingFactory.hashState(groundedAction.executeIn(state));
            return new StateSelectionAndExpectedGap(next, getGap(next));
        }
        if (this.selectionMode == StateSelectionMode.WEIGHTEDMARGIN) {
            return getNextStateBySampling(state, groundedAction);
        }
        if (this.selectionMode == StateSelectionMode.MAXMARGIN) {
            return getNextStateByMaxMargin(state, groundedAction);
        }
        throw new RuntimeException("Unknown state selection mode.");
    }

    /**
     * Selects a successor with the largest gap, choosing uniformly at random among ties.
     *
     * @param state the current state
     * @param groundedAction the action taken in {@code state}
     * @return the selected successor and the probability-weighted sum of all successor gaps
     */
    protected StateSelectionAndExpectedGap getNextStateByMaxMargin(State state, GroundedAction groundedAction) {
        List<TransitionProbability> transitions = groundedAction.action.getTransitions(state, groundedAction.params);
        double expectedGap = 0.0d;
        double largestGap = Double.NEGATIVE_INFINITY;
        List<StateHashTuple> candidates = new ArrayList<StateHashTuple>(transitions.size());
        for (TransitionProbability tp : transitions) {
            StateHashTuple successor = this.hashingFactory.hashState(tp.s);
            double gap = getGap(successor);
            expectedGap += tp.p * gap;
            if (gap == largestGap) {
                candidates.add(successor);
            } else if (gap > largestGap) {
                candidates.clear();
                candidates.add(successor);
                largestGap = gap;
            }
        }
        StateHashTuple chosen = candidates.get(RandomFactory.getMapped(0).nextInt(candidates.size()));
        return new StateSelectionAndExpectedGap(chosen, expectedGap);
    }

    /**
     * Samples a successor with probability proportional to its transition probability times
     * its gap.
     *
     * <p>If the summed weighted gap is not positive (for instance every successor already has a
     * zero gap), normalizing by it is impossible; the successor is then sampled by transition
     * probability alone. If floating-point rounding leaves the cumulative weight just short of
     * the random draw, the last successor with positive weight is returned.
     *
     * @param state the current state
     * @param groundedAction the action taken in {@code state}
     * @return the sampled successor and the probability-weighted sum of all successor gaps
     * @throws RuntimeException if the action has no transitions from {@code state}
     */
    protected StateSelectionAndExpectedGap getNextStateBySampling(State state, GroundedAction groundedAction) {
        List<TransitionProbability> transitions = groundedAction.action.getTransitions(state, groundedAction.params);
        int count = transitions.size();
        if (count == 0) {
            throw new RuntimeException("Error: no transitions available for state selection.");
        }
        double expectedGap = 0.0d;
        double probabilityMass = 0.0d;
        double[] weightedGaps = new double[count];
        double[] probabilities = new double[count];
        StateHashTuple[] successors = new StateHashTuple[count];
        for (int i = 0; i < count; i++) {
            TransitionProbability tp = transitions.get(i);
            StateHashTuple successor = this.hashingFactory.hashState(tp.s);
            successors[i] = successor;
            probabilities[i] = tp.p;
            probabilityMass += tp.p;
            weightedGaps[i] = tp.p * getGap(successor);
            expectedGap += weightedGaps[i];
        }

        // Dividing by a zero (or NaN/negative) total would make every comparison below fail.
        boolean gapWeightsUsable = expectedGap > 0.0d;
        double[] weights = gapWeightsUsable ? weightedGaps : probabilities;
        double total = gapWeightsUsable ? expectedGap : probabilityMass;

        double roll = RandomFactory.getMapped(0).nextDouble();
        double cumulative = 0.0d;
        int fallback = count - 1;
        for (int i = 0; i < count; i++) {
            if (weights[i] > 0.0d) {
                fallback = i;
            }
            cumulative += weights[i] / total;
            if (roll < cumulative) {
                return new StateSelectionAndExpectedGap(successors[i], expectedGap);
            }
        }
        return new StateSelectionAndExpectedGap(successors[fallback], expectedGap);
    }

    /**
     * Computes the difference between the upper-bound and lower-bound value of a state using
     * the inherited {@code value} query under each bound in turn.
     *
     * <p>Side effect: the upper bound is left as the active value function.
     *
     * @param stateHashTuple the state to query
     * @return upper-bound value minus lower-bound value
     */
    protected double getGap(StateHashTuple stateHashTuple) {
        setValueFunctionToLowerBound();
        double lower = value(stateHashTuple);
        setValueFunctionToUpperBound();
        return value(stateHashTuple) - lower;
    }

    /**
     * Returns a Q-value with the maximum value in {@code state} under the currently active
     * bound, choosing uniformly at random among exact ties.
     *
     * @param state the state whose Q-values are examined
     * @return one of the maximal Q-values
     */
    protected QValue maxQ(State state) {
        List<QValue> qs = getQs(state);
        double best = Double.NEGATIVE_INFINITY;
        List<QValue> maximal = new ArrayList<QValue>(qs.size());
        for (QValue candidate : qs) {
            if (candidate.q == best) {
                maximal.add(candidate);
            } else if (candidate.q > best) {
                best = candidate.q;
                maximal.clear();
                maximal.add(candidate);
            }
        }
        return maximal.get(RandomFactory.getMapped(0).nextInt(maximal.size()));
    }
}
