package burlap.behavior.singleagent.learnbydemo.mlirl.differentiableplanners;

import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.learnbydemo.mlirl.differentiableplanners.diffvinit.DifferentiableVInit;
import burlap.behavior.singleagent.learnbydemo.mlirl.differentiableplanners.diffvinit.VanillaDiffVinit;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.BoltzmannPolicyGradient;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.QGradientPlanner;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.QGradientTuple;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.singleagent.planning.stochastic.sparsesampling.SparseSampling;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.datastructures.BoltzmannDistribution;
import burlap.debugtools.DPrint;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.GroundedAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/**
 * A differentiable variant of sparse sampling: a finite-horizon planner that estimates, for a
 * queried root state, both the Q-value of every action and the gradient of those Q-values with
 * respect to the parameters of a {@link DifferentiableRF}.
 * <p>
 * State values are backed up with a Boltzmann (softmax) operator with inverse temperature
 * {@link #boltzBeta} instead of a hard max, i.e. {@code V(s) = sum_a pi(a|s) Q(s,a)}. The value
 * gradient is accumulated with the matching product rule.
 * <p>
 * Tree nodes are cached per (state, height) pair. Height is the number of remaining expansion
 * steps: the root of a plan is expanded at height {@link #h}, and nodes at height 0 take their
 * Q-values from the leaf value initialization {@link #vinit}.
 * <p>
 * If the (possibly height-dependent) sample count is positive, each Q-value is a sample average
 * over that many sampled transitions. If it is negative, the full transition distribution is used.
 * If it is 0, leaf values are used.
 */
public class DifferentiableSparseSampling extends OOMDPPlanner implements QGradientPlanner {

    /** Planning horizon: the height at which the root node of a plan is expanded. */
    protected int h;

    /**
     * Number of transition samples per action. A negative value requests full Bellman backups;
     * 0 makes every Q-value come straight from the leaf value initialization.
     */
    protected int c;

    /** Differentiable value initialization used for leaf nodes. */
    protected DifferentiableVInit vinit;

    /** Cache of tree nodes keyed by (hashed state, height). */
    protected Map<SparseSampling.HashedHeightState, DiffStateNode> nodesByHeight;

    /** Q-values and Q-gradients for every root state planned from so far. */
    protected Map<StateHashTuple, QAndQGradient> rootLevelQValues;

    /** Inverse temperature of the Boltzmann backup operator. */
    protected double boltzBeta;

    /** Dimensionality of the reward function's parameter vector. */
    protected int rfDim;

    /** When true, the per-node sample count shrinks with depth; see {@link #getCAtHeight(int)}. */
    protected boolean useVariableC = false;

    /** When true, cached root results and tree nodes are discarded around each planning call. */
    protected boolean forgetPreviousPlanResults = false;

    /** Number of node value estimates computed so far. */
    protected int numUpdates = 0;

    /**
     * A node in the sparse-sampling tree: a hashed state at a given height. Its value and value
     * gradient are computed lazily and cached once computed.
     */
    public class DiffStateNode {
        /** Hashed state this node represents. */
        StateHashTuple sh;
        /** Remaining expansion steps below this node. */
        int height;
        /** Cached Boltzmann-backed-up value (valid once {@link #closed} is true). */
        double v;
        /** Cached gradient of {@link #v} w.r.t. the reward parameters. */
        double[] vgrad;
        /** Whether {@link #v} and {@link #vgrad} have been computed. */
        boolean closed = false;

        public DiffStateNode(StateHashTuple sh, int height) {
            this.sh = sh;
            this.height = height;
        }

        /**
         * Estimates the Q-value and Q-gradient of every applicable action in this node's state.
         * The result comes from leaf initialization, sampled backups or full backups, depending
         * on this node's height and the sample count at that height.
         *
         * @return the Q-values and Q-gradients, one pair per grounded action, in the same order
         */
        public QAndQGradient estimateQs() {
            List<GroundedAction> actions = DifferentiableSparseSampling.this.getAllGroundedActions(this.sh.s);
            QAndQGradient result = new QAndQGradient(actions.size());
            int cAtHeight = DifferentiableSparseSampling.this.getCAtHeight(this.height);
            boolean useLeafValues = this.height == 0 || cAtHeight == 0;

            for (GroundedAction ga : actions) {
                if (useLeafValues) {
                    DifferentiableVInit leafInit = DifferentiableSparseSampling.this.vinit;
                    result.add(new QValue(this.sh.s, ga, leafInit.value(this.sh.s)),
                            new QGradientTuple(this.sh.s, ga, leafInit.getQGradient(this.sh.s, ga)));
                } else if (cAtHeight > 0) {
                    sampledBellmanQEstimate(ga, result);
                } else {
                    fulldBellmanQEstimate(ga, result);
                }
            }
            return result;
        }

        /**
         * Estimates Q(s, a) and its gradient as a sample average over transitions sampled with
         * {@code executeIn}. The average is appended to {@code qAndQGradient}.
         * <p>
         * The number of samples is {@code getCAtHeight(height)}, so variable-C schedules are
         * honored. Callers are expected to invoke this only when that count is positive.
         */
        public void sampledBellmanQEstimate(GroundedAction groundedAction, QAndQGradient qAndQGradient) {
            int numSamples = DifferentiableSparseSampling.this.getCAtHeight(this.height);
            double gamma = DifferentiableSparseSampling.this.gamma;
            double[] gradSum = new double[DifferentiableSparseSampling.this.rfDim];
            double qSum = 0.0d;

            for (int i = 0; i < numSamples; i++) {
                State next = groundedAction.executeIn(this.sh.s);
                double r = DifferentiableSparseSampling.this.rf.reward(this.sh.s, groundedAction, next);
                double[] rGrad = ((DifferentiableRF) DifferentiableSparseSampling.this.rf)
                        .getGradient(this.sh.s, groundedAction, next);
                VAndVGradient nextV = DifferentiableSparseSampling.this.getStateNode(next, this.height - 1).estimateV();
                qSum += r + (gamma * nextV.v);
                for (int j = 0; j < rGrad.length; j++) {
                    gradSum[j] = gradSum[j] + rGrad[j] + (gamma * nextV.vGrad[j]);
                }
            }

            double q = qSum / numSamples;
            for (int j = 0; j < gradSum.length; j++) {
                gradSum[j] = gradSum[j] / numSamples;
            }
            qAndQGradient.add(new QValue(this.sh.s, groundedAction, q),
                    new QGradientTuple(this.sh.s, groundedAction, gradSum));
        }

        /**
         * Computes Q(s, a) and its gradient as an expectation over the full transition
         * distribution returned by {@code getTransitions}. The result is appended to
         * {@code qAndQGradient}.
         */
        public void fulldBellmanQEstimate(GroundedAction groundedAction, QAndQGradient qAndQGradient) {
            double gamma = DifferentiableSparseSampling.this.gamma;
            double[] gradSum = new double[DifferentiableSparseSampling.this.rfDim];
            double q = 0.0d;

            for (TransitionProbability tp : groundedAction.action.getTransitions(this.sh.s, groundedAction.params)) {
                State next = tp.s;
                double r = DifferentiableSparseSampling.this.rf.reward(this.sh.s, groundedAction, next);
                double[] rGrad = ((DifferentiableRF) DifferentiableSparseSampling.this.rf)
                        .getGradient(this.sh.s, groundedAction, next);
                VAndVGradient nextV = DifferentiableSparseSampling.this.getStateNode(next, this.height - 1).estimateV();
                q += tp.p * (r + (gamma * nextV.v));
                for (int j = 0; j < rGrad.length; j++) {
                    gradSum[j] = gradSum[j] + (tp.p * (rGrad[j] + (gamma * nextV.vGrad[j])));
                }
            }
            qAndQGradient.add(new QValue(this.sh.s, groundedAction, q),
                    new QGradientTuple(this.sh.s, groundedAction, gradSum));
        }

        /**
         * Returns this node's value and value gradient, computing and caching them on first call.
         * Terminal states have value 0 and a zero gradient.
         */
        public VAndVGradient estimateV() {
            if (this.closed) {
                return new VAndVGradient(this.v, this.vgrad);
            }
            if (DifferentiableSparseSampling.this.tf.isTerminal(this.sh.s)) {
                this.v = 0.0d;
                this.vgrad = new double[DifferentiableSparseSampling.this.rfDim];
                this.closed = true;
                return new VAndVGradient(this.v, this.vgrad);
            }
            QAndQGradient qs = estimateQs();
            setV(qs);
            setVGrad(qs);
            this.closed = true;
            DifferentiableSparseSampling.this.numUpdates++;
            return new VAndVGradient(this.v, this.vgrad);
        }

        /**
         * Sets {@link #v} to the Boltzmann-weighted average of the given Q-values, using
         * temperature {@code 1 / boltzBeta}.
         */
        protected void setV(QAndQGradient qAndQGradient) {
            int numActions = qAndQGradient.qs.size();
            double[] qs = new double[numActions];
            for (int a = 0; a < numActions; a++) {
                qs[a] = qAndQGradient.qs.get(a).q;
            }
            double[] probs = new BoltzmannDistribution(qs, 1.0d / DifferentiableSparseSampling.this.boltzBeta)
                    .getProbabilities();
            double value = 0.0d;
            for (int a = 0; a < qs.length; a++) {
                value += qs[a] * probs[a];
            }
            this.v = value;
        }

        /**
         * Sets {@link #vgrad} to the gradient of {@code V = sum_a pi_a Q_a} using the product rule
         * {@code dV = sum_a (pi_a dQ_a + Q_a dpi_a)}. The policy gradient {@code dpi_a} is
         * obtained from {@link BoltzmannPolicyGradient}.
         */
        protected void setVGrad(QAndQGradient qAndQGradient) {
            int dim = DifferentiableSparseSampling.this.rfDim;
            double beta = DifferentiableSparseSampling.this.boltzBeta;
            int numActions = qAndQGradient.qs.size();

            double[] qs = new double[numActions];
            for (int a = 0; a < numActions; a++) {
                qs[a] = qAndQGradient.qs.get(a).q;
            }
            // Copy the Q-gradients into a matrix (only the first rfDim entries of each gradient are used).
            double[][] qGrads = new double[numActions][dim];
            for (int a = 0; a < numActions; a++) {
                System.arraycopy(qAndQGradient.qGrads.get(a).gradient, 0, qGrads[a], 0, dim);
            }

            // Shared normalisation terms, computed once for every action.
            double maxBetaScaled = BoltzmannPolicyGradient.maxBetaScaled(qs, beta);
            double logSum = BoltzmannPolicyGradient.logSum(qs, maxBetaScaled, beta);

            this.vgrad = new double[dim];
            for (int a = 0; a < numActions; a++) {
                // presumably pi_a, assuming logSum = log(sum_b exp(beta * Q_b))
                double prob = Math.exp((beta * qs[a]) - logSum);
                double[] probGrad = BoltzmannPolicyGradient.computePolicyGradient(
                        (DifferentiableRF) DifferentiableSparseSampling.this.rf, beta, qs, maxBetaScaled, logSum, qGrads, a);
                for (int j = 0; j < dim; j++) {
                    this.vgrad[j] = this.vgrad[j] + (prob * qGrads[a][j]) + (qs[a] * probGrad[j]);
                }
            }
        }
    }

    /** Parallel lists of Q-values and Q-gradients, one entry per grounded action. */
    public static class QAndQGradient {
        List<QValue> qs;
        List<QGradientTuple> qGrads;

        public QAndQGradient(List<QValue> qs, List<QGradientTuple> qGrads) {
            this.qs = qs;
            this.qGrads = qGrads;
        }

        /** Creates empty lists pre-sized for {@code capacity} actions. */
        public QAndQGradient(int capacity) {
            this.qs = new ArrayList<QValue>(capacity);
            this.qGrads = new ArrayList<QGradientTuple>(capacity);
        }

        /** Appends a Q-value and its matching gradient, keeping the two lists aligned. */
        public void add(QValue qValue, QGradientTuple qGradientTuple) {
            this.qs.add(qValue);
            this.qGrads.add(qGradientTuple);
        }
    }

    /** A state value paired with its gradient w.r.t. the reward parameters. */
    public static class VAndVGradient {
        double v;
        double[] vGrad;

        public VAndVGradient(double v, double[] vGrad) {
            this.v = v;
            this.vGrad = vGrad;
        }
    }

    /**
     * @param domain the planning domain
     * @param rf the differentiable reward function whose parameter gradients are propagated
     * @param tf the terminal function
     * @param gamma the discount factor
     * @param hashingFactory the state hashing factory
     * @param h the planning horizon (height of the root node)
     * @param c samples per action; negative for full Bellman backups
     * @param boltzBeta inverse temperature of the Boltzmann backup
     */
    public DifferentiableSparseSampling(Domain domain, DifferentiableRF rf, TerminalFunction tf, double gamma,
                                        StateHashFactory hashingFactory, int h, int c, double boltzBeta) {
        plannerInit(domain, rf, tf, gamma, hashingFactory);
        this.h = h;
        this.c = c;
        this.boltzBeta = boltzBeta;
        this.nodesByHeight = new HashMap<SparseSampling.HashedHeightState, DiffStateNode>();
        this.rootLevelQValues = new HashMap<StateHashTuple, QAndQGradient>();
        this.rfDim = rf.getParameterDimension();
        this.vinit = new VanillaDiffVinit(new ValueFunctionInitialization.ConstantValueFunctionInitialization(), rf);
        this.debugCode = 6368290;
    }

    /** Enables or disables the depth-dependent sample count of {@link #getCAtHeight(int)}. */
    public void setUseVariableCSize(boolean useVariableC) {
        this.useVariableC = useVariableC;
    }

    public void setC(int c) {
        this.c = c;
    }

    public void setH(int h) {
        this.h = h;
    }

    public int getC() {
        return this.c;
    }

    public int getH() {
        return this.h;
    }

    /** Sets whether cached results are discarded around each plan; enabling it clears the node cache now. */
    public void setForgetPreviousPlanResults(boolean forgetPreviousPlanResults) {
        this.forgetPreviousPlanResults = forgetPreviousPlanResults;
        if (this.forgetPreviousPlanResults) {
            this.nodesByHeight.clear();
        }
    }

    /**
     * Sets the leaf value initialization. A non-differentiable initialization is wrapped in a
     * {@link VanillaDiffVinit} using this planner's reward function.
     */
    public void setValueForLeafNodes(ValueFunctionInitialization vinit) {
        if (vinit instanceof DifferentiableVInit) {
            this.vinit = (DifferentiableVInit) vinit;
        } else {
            this.vinit = new VanillaDiffVinit(vinit, (DifferentiableRF) this.rf);
        }
    }

    @Override
    public int getDebugCode() {
        return this.debugCode;
    }

    @Override
    public void setDebugCode(int debugCode) {
        this.debugCode = debugCode;
    }

    /** Returns the cumulative number of node value estimates computed (name kept for compatibility). */
    public int getNumberOfValueEsitmates() {
        return this.numUpdates;
    }

    @Override
    public void setBoltzmannBetaParameter(double beta) {
        this.boltzBeta = beta;
    }

    @Override
    public List<QValue> getQs(State state) {
        return getRootQAndQGradient(state, this.hashingFactory.hashState(state)).qs;
    }

    /**
     * Returns the Q-value of the given action in the given state, planning first if needed.
     * Returns {@code null} if the action is not among the state's grounded actions.
     */
    @Override
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        StateHashTuple hashState = this.hashingFactory.hashState(state);
        QAndQGradient qAndQGradient = getRootQAndQGradient(state, hashState);
        // Map object-parameterised actions onto the object names of the stored equivalent state.
        if (abstractGroundedAction.params.length > 0 && !this.domain.isObjectIdentifierDependent()
                && abstractGroundedAction.parametersAreObjects()) {
            abstractGroundedAction = abstractGroundedAction.translateParameters(state, this.mapToStateIndex.get(hashState).s);
        }
        for (QValue qValue : qAndQGradient.qs) {
            if (qValue.a.equals(abstractGroundedAction)) {
                return qValue;
            }
        }
        return null;
    }

    @Override
    public List<QGradientTuple> getAllQGradients(State state) {
        return getRootQAndQGradient(state, this.hashingFactory.hashState(state)).qGrads;
    }

    /**
     * Returns the Q-gradient of the given action in the given state, planning first if needed.
     * Returns {@code null} if the action is not among the state's grounded actions.
     */
    @Override
    public QGradientTuple getQGradient(State state, GroundedAction groundedAction) {
        StateHashTuple hashState = this.hashingFactory.hashState(state);
        QAndQGradient qAndQGradient = getRootQAndQGradient(state, hashState);
        // Map object-parameterised actions onto the object names of the stored equivalent state.
        if (groundedAction.params.length > 0 && !this.domain.isObjectIdentifierDependent()
                && groundedAction.parametersAreObjects()) {
            groundedAction = (GroundedAction) groundedAction.translateParameters(state, this.mapToStateIndex.get(hashState).s);
        }
        for (QGradientTuple qGradientTuple : qAndQGradient.qGrads) {
            if (qGradientTuple.a.equals(groundedAction)) {
                return qGradientTuple;
            }
        }
        return null;
    }

    /** Returns the cached root result for {@code hashState}, planning from {@code state} first if absent. */
    private QAndQGradient getRootQAndQGradient(State state, StateHashTuple hashState) {
        QAndQGradient result = this.rootLevelQValues.get(hashState);
        if (result == null) {
            planFromState(state);
            result = this.rootLevelQValues.get(hashState);
        }
        return result;
    }

    /**
     * Computes and caches the root Q-values and Q-gradients for {@code state} by expanding a
     * tree of height {@link #h}. Does nothing if the state already has cached results.
     */
    @Override
    public void planFromState(State state) {
        if (this.forgetPreviousPlanResults) {
            this.rootLevelQValues.clear();
        }
        StateHashTuple hashState = this.hashingFactory.hashState(state);
        if (this.rootLevelQValues.containsKey(hashState)) {
            return;
        }
        DPrint.cl(this.debugCode, "Beginning Planning.");
        int updatesBefore = this.numUpdates;
        this.rootLevelQValues.put(hashState, getStateNode(state, this.h).estimateQs());
        DPrint.cl(this.debugCode, "Finished Planning with " + (this.numUpdates - updatesBefore)
                + " value esitmates; for a cumulative total of: " + this.numUpdates);
        if (this.forgetPreviousPlanResults) {
            this.nodesByHeight.clear();
        }
        this.mapToStateIndex.put(hashState, hashState);
    }

    @Override
    public void resetPlannerResults() {
        this.nodesByHeight.clear();
        this.rootLevelQValues.clear();
        this.numUpdates = 0;
    }

    /**
     * Returns the number of transition samples per action for a node at the given height.
     * <p>
     * Without variable C, or when {@link #c} is non-positive (full-backup or leaf-only mode),
     * this is simply {@link #c}. With variable C, the count is
     * {@code c * gamma^(2 * depth)}, truncated and floored at 1. Here {@code depth = h - height}
     * is the distance from the root, so the root uses the full {@code c} and deeper nodes use fewer.
     * The planner's horizon {@link #h} is never modified.
     *
     * @param height the node's remaining expansion steps
     * @return the number of samples to draw per action
     */
    protected int getCAtHeight(int height) {
        if (!this.useVariableC || this.c <= 0) {
            return this.c;
        }
        int depth = this.h - height;
        int vc = (int) (this.c * Math.pow(this.gamma, 2 * depth));
        return vc == 0 ? 1 : vc;
    }

    /** Returns the cached node for (state, height), creating and caching it if absent. */
    protected DiffStateNode getStateNode(State state, int height) {
        StateHashTuple hashState = this.hashingFactory.hashState(state);
        SparseSampling.HashedHeightState key = new SparseSampling.HashedHeightState(hashState, height);
        DiffStateNode node = this.nodesByHeight.get(key);
        if (node == null) {
            node = new DiffStateNode(hashState, height);
            this.nodesByHeight.put(key, node);
        }
        return node;
    }
}
