package burlap.behavior.singleagent.learnbydemo.mlirl.support;

import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.oomdp.core.State;
import burlap.oomdp.singleagent.GroundedAction;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learnbydemo/mlirl/support/BoltzmannPolicyGradient.class */
/**
 * Static helpers for computing the gradient of a Boltzmann (softmax) policy over Q-values with
 * respect to the parameters of a {@link DifferentiableRF}.
 *
 * <p>The policy is {@code pi(a|s) = exp(beta*Q(s,a)) / sum_b exp(beta*Q(s,b))}, whose gradient is
 * {@code beta * pi(a|s) * sum_b pi(b|s) * (dQ(s,a) - dQ(s,b))}. All exponentials are computed
 * relative to a shift (the maximum beta-scaled Q-value) to avoid overflow.
 */
public class BoltzmannPolicyGradient {

    /**
     * Computes the gradient of the Boltzmann policy probability of {@code groundedAction} in
     * {@code state}, with respect to the reward function parameters.
     *
     * <p>NOTE(review): {@code qGradientPlanner} is cast to {@link OOMDPPlanner} and its reward
     * function to {@link DifferentiableRF}; either cast throws {@code ClassCastException} if the
     * planner does not satisfy that.
     *
     * @param state the state in which the policy is evaluated
     * @param groundedAction the action whose selection probability is differentiated
     * @param qGradientPlanner planner supplying Q-values and their parameter gradients
     * @param beta the Boltzmann inverse temperature
     * @return the policy gradient, one entry per reward-function parameter
     * @throws IllegalArgumentException if {@code groundedAction} is not among the planner's Q-values
     *     for {@code state}
     */
    public static double[] computeBoltzmannPolicyGradient(State state, GroundedAction groundedAction,
            QGradientPlanner qGradientPlanner, double beta) {
        DifferentiableRF rf = (DifferentiableRF) ((OOMDPPlanner) qGradientPlanner).getRF();
        int dim = rf.getParameterDimension();

        List<QValue> qValues = qGradientPlanner.getQs(state);
        int numActions = qValues.size();

        // Extract the raw Q-values and locate the first entry matching the query action.
        double[] qs = new double[numActions];
        int actionIndex = -1;
        for (int i = 0; i < numActions; i++) {
            QValue qv = qValues.get(i);
            qs[i] = qv.q;
            if (actionIndex == -1 && qv.a.equals(groundedAction)) {
                actionIndex = i;
            }
        }
        if (actionIndex == -1) {
            throw new IllegalArgumentException("Error in computing BoltzmannPolicyGradient: Could not find query action in Q-value list.");
        }

        // Copy each action's Q-gradient (first `dim` entries) so later math never aliases planner data.
        double[][] qGradients = new double[numActions][dim];
        for (int i = 0; i < numActions; i++) {
            double[] g = qGradientPlanner.getQGradient(state, (GroundedAction) qValues.get(i).a).gradient;
            System.arraycopy(g, 0, qGradients[i], 0, dim);
        }

        double shift = maxBetaScaled(qs, beta);
        double logNormalizer = logSum(qs, shift, beta);
        return computePolicyGradient(rf, beta, qs, shift, logNormalizer, qGradients, actionIndex);
    }

    /**
     * Computes the Boltzmann policy gradient from precomputed Q-values and Q-gradients.
     *
     * @param differentiableRF reward function; only its parameter dimension is used
     * @param beta the Boltzmann inverse temperature
     * @param qs Q-values of every action
     * @param maxBetaScaled the shift used for the exponentials (normally {@link #maxBetaScaled})
     * @param logSum {@code log(sum_b exp(beta*qs[b]))}, computed with the same shift
     * @param qGradients {@code qGradients[b]} is the parameter gradient of {@code qs[b]}
     * @param actionIndex index of the action whose probability is differentiated
     * @return the policy gradient, one entry per reward-function parameter
     */
    public static double[] computePolicyGradient(DifferentiableRF differentiableRF, double beta, double[] qs,
            double maxBetaScaled, double logSum, double[][] qGradients, int actionIndex) {
        int dim = differentiableRF.getParameterDimension();
        double[] gradient = new double[dim];

        // beta * pi(a) * exp(shift) / Z, written in log space: pi(a) = exp(beta*q_a - logSum), and the
        // extra exp(shift - logSum) compensates for the shifted weights summed below.
        double scale = beta * Math.exp(beta * qs[actionIndex] + maxBetaScaled - logSum - logSum);

        for (int b = 0; b < qs.length; b++) {
            // Shifted unnormalized probability of action b; identical for every parameter, so hoisted.
            double weight = Math.exp(beta * qs[b] - maxBetaScaled);
            for (int j = 0; j < dim; j++) {
                gradient[j] += (qGradients[actionIndex][j] - qGradients[b][j]) * weight;
            }
        }
        for (int j = 0; j < dim; j++) {
            gradient[j] *= scale;
        }
        return gradient;
    }

    /**
     * Returns the largest beta-scaled Q-value, {@code max_i(beta * qs[i])}, for use as a numerically
     * stable log-sum-exp shift. For {@code beta >= 0} this equals {@code beta * max_i(qs[i])}; computing
     * it on the scaled values keeps the shift correct for negative {@code beta} as well.
     *
     * @param qs the Q-values; NaN entries are ignored
     * @param beta the Boltzmann inverse temperature
     * @return the maximum scaled value, or negative infinity if {@code qs} is empty
     */
    public static double maxBetaScaled(double[] qs, double beta) {
        double max = Double.NEGATIVE_INFINITY;
        for (double q : qs) {
            double scaled = beta * q;
            if (scaled > max) {
                max = scaled;
            }
        }
        return max;
    }

    /**
     * Computes {@code log(sum_i exp(beta * qs[i]))} stably as
     * {@code maxBetaScaled + log(sum_i exp(beta * qs[i] - maxBetaScaled))}.
     *
     * @param qs the Q-values
     * @param maxBetaScaled the shift, normally the result of {@link #maxBetaScaled}
     * @param beta the Boltzmann inverse temperature
     * @return the log of the sum of exponentiated scaled Q-values
     */
    public static double logSum(double[] qs, double maxBetaScaled, double beta) {
        // An infinite maximum dominates the sum; subtracting it would yield inf - inf = NaN.
        if (Double.isInfinite(maxBetaScaled)) {
            return maxBetaScaled;
        }
        double sum = 0.0d;
        for (double q : qs) {
            sum += Math.exp(beta * q - maxBetaScaled);
        }
        return maxBetaScaled + Math.log(sum);
    }
}
