package burlap.behavior.singleagent.learnbydemo.mlirl;

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.BoltzmannPolicyGradient;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.QGradientPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.BoltzmannQPolicy;
import burlap.debugtools.DPrint;
import burlap.oomdp.core.State;
import burlap.oomdp.singleagent.GroundedAction;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learnbydemo/mlirl/MLIRL.class */
/**
 * Maximum-likelihood inverse reinforcement learning (MLIRL).
 *
 * <p>Performs gradient ascent on the parameters of the request's reward function so as to maximize
 * the weighted log-likelihood of the expert episodes. The likelihood of each expert action is taken
 * from a Boltzmann (softmax) policy over the planner's Q-values with temperature
 * {@code 1 / request.getBoltzmannBeta()}.
 *
 * <p>The request's planner is cast to {@link QComputablePlanner} (likelihood) and to
 * {@link QGradientPlanner} (gradient), so it must implement both interfaces.
 */
public class MLIRL {
    /** The IRL problem: expert episodes, their weights, the reward function and the planner. */
    protected MLIRLRequest request;

    /** Step size applied to the log-likelihood gradient on each ascent step. */
    protected double learningRate;

    /** Ascent stops once the absolute change in log-likelihood between steps falls below this. */
    protected double maxLikelihoodChange;

    /** Maximum number of gradient-ascent steps; {@code -1} means no limit. */
    protected int maxSteps;

    /** {@link DPrint} debug code used for this object's debug output. */
    protected int debugCode = 625420;

    /**
     * Creates an MLIRL instance.
     *
     * @param request the IRL problem to solve
     * @param learningRate gradient-ascent step size
     * @param maxLikelihoodChange convergence threshold on the absolute log-likelihood change
     * @param maxSteps maximum number of ascent steps, or {@code -1} for no limit
     * @throws RuntimeException if {@code request.isValid()} returns {@code false}
     */
    public MLIRL(MLIRLRequest request, double learningRate, double maxLikelihoodChange, int maxSteps) {
        // Fail fast before storing anything about an unusable request.
        if (!request.isValid()) {
            throw new RuntimeException("Provided MLIRLRequest object is not valid.");
        }
        this.request = request;
        this.learningRate = learningRate;
        this.maxLikelihoodChange = maxLikelihoodChange;
        this.maxSteps = maxSteps;
    }

    /**
     * Replaces the IRL request. Unlike the constructor, this does not check {@code isValid()}.
     *
     * @param request the new request
     */
    public void setRequest(MLIRLRequest request) {
        this.request = request;
    }

    /**
     * Enables or disables debug printing for this object and for the request's planner.
     *
     * @param printDebug {@code true} to enable debug output, {@code false} to disable it
     */
    public void toggleDebugPrinting(boolean printDebug) {
        DPrint.toggleCode(this.debugCode, printDebug);
        this.request.getPlanner().toggleDebugPrinting(printDebug);
    }

    /**
     * Returns the {@link DPrint} debug code used by this object.
     *
     * @return the debug code
     */
    public int getDebugCode() {
        return this.debugCode;
    }

    /**
     * Sets the {@link DPrint} debug code used by this object.
     *
     * @param debugCode the new debug code
     */
    public void setDebugCode(int debugCode) {
        this.debugCode = debugCode;
    }

    /**
     * Runs gradient ascent on the reward-function parameters.
     *
     * <p>Each step adds {@code learningRate * gradient} to the parameter array returned by
     * {@code request.getRf().getParameters()}. It then resets the planner and recomputes the
     * log-likelihood. Ascent stops when {@code maxSteps} steps have been taken (unless it is
     * {@code -1}). It also stops when the absolute likelihood change drops below
     * {@code maxLikelihoodChange}, or when that change is NaN (a degenerate likelihood that would
     * otherwise never satisfy the threshold).
     */
    public void performIRL() {
        this.request.getPlanner().resetPlannerResults();
        double lastLikelihood = this.logLikelihood();
        DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
        DPrint.cl(this.debugCode, "Log likelihood: " + lastLikelihood);

        int steps;
        for (steps = 0; steps < this.maxSteps || this.maxSteps == -1; steps++) {
            double[] gradient = this.logLikelihoodGradient();

            // Updates the returned array in place; this presumes getParameters() exposes the
            // reward function's backing array rather than a copy (TODO confirm in the RF class).
            double[] params = this.request.getRf().getParameters();
            for (int f = 0; f < params.length; f++) {
                params[f] += this.learningRate * gradient[f];
            }

            // The reward function changed, so any cached planning results are stale.
            this.request.getPlanner().resetPlannerResults();
            double newLikelihood = this.logLikelihood();
            double likelihoodChange = newLikelihood - lastLikelihood;
            lastLikelihood = newLikelihood;

            DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
            DPrint.cl(this.debugCode, "Log likelihood: " + lastLikelihood + " (change: " + likelihoodChange + ")");

            // A NaN change can never pass the threshold test, which would hang when maxSteps == -1.
            if (Math.abs(likelihoodChange) < this.maxLikelihoodChange || Double.isNaN(likelihoodChange)) {
                steps++; // count the step that just completed
                break;
            }
        }

        DPrint.cl(this.debugCode, "\nNum gradient ascent steps: " + steps);
        DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
    }

    /**
     * Computes the weighted log-likelihood of all expert episodes under the current reward function.
     *
     * @return the sum over episodes of {@code logLikelihoodOfTrajectory(episode, weight)}
     */
    public double logLikelihood() {
        double[] weights = this.request.getEpisodeWeights();
        List<EpisodeAnalysis> episodes = this.request.getExpertEpisodes();
        double total = 0.0;
        for (int i = 0; i < episodes.size(); i++) {
            total += this.logLikelihoodOfTrajectory(episodes.get(i), weights[i]);
        }
        return total;
    }

    /**
     * Computes the weighted log-likelihood of a single episode.
     *
     * <p>For every state except the last, the planner is run from that state. The natural log of
     * the Boltzmann-policy probability of the recorded action is then accumulated.
     *
     * @param episodeAnalysis the episode to score
     * @param weight multiplier applied to the episode's summed log-likelihood
     * @return {@code weight * sum_t log pi(a_t | s_t)}
     */
    public double logLikelihoodOfTrajectory(EpisodeAnalysis episodeAnalysis, double weight) {
        BoltzmannQPolicy policy = new BoltzmannQPolicy((QComputablePlanner) this.request.getPlanner(), 1.0d / this.request.getBoltzmannBeta());
        double sum = 0.0;
        // The final state has no recorded action, hence numTimeSteps() - 1.
        for (int t = 0; t < episodeAnalysis.numTimeSteps() - 1; t++) {
            State s = episodeAnalysis.getState(t);
            this.request.getPlanner().planFromState(s);
            sum += Math.log(policy.getProbOfAction(s, episodeAnalysis.getAction(t)));
        }
        return sum * weight;
    }

    /**
     * Computes the gradient of {@link #logLikelihood()} with respect to the reward parameters.
     *
     * @return a new array of length {@code request.getRf().getParameterDimension()} holding the
     *     weighted sum of per-step log-policy gradients over all expert episodes
     */
    public double[] logLikelihoodGradient() {
        double[] gradient = new double[this.request.getRf().getParameterDimension()];
        double[] weights = this.request.getEpisodeWeights();
        List<EpisodeAnalysis> episodes = this.request.getExpertEpisodes();
        for (int i = 0; i < episodes.size(); i++) {
            EpisodeAnalysis episode = episodes.get(i);
            double weight = weights[i];
            for (int t = 0; t < episode.numTimeSteps() - 1; t++) {
                State s = episode.getState(t);
                // Q-values (and their gradients) must be available for s before differentiating.
                this.request.getPlanner().planFromState(s);
                double[] stepGradient = this.logPolicyGrad(s, episode.getAction(t));
                for (int f = 0; f < stepGradient.length; f++) {
                    stepGradient[f] *= weight;
                }
                addToVector(gradient, stepGradient);
            }
        }
        return gradient;
    }

    /**
     * Computes the gradient of {@code log pi(a | s)} for the Boltzmann policy.
     *
     * <p>Uses {@code d log pi = (d pi) / pi}: the Boltzmann policy gradient is divided by the
     * action's probability.
     *
     * @param state the state in which the action was taken
     * @param groundedAction the action taken
     * @return the log-policy gradient with respect to the reward parameters
     */
    public double[] logPolicyGrad(State state, GroundedAction groundedAction) {
        BoltzmannQPolicy policy = new BoltzmannQPolicy((QComputablePlanner) this.request.getPlanner(), 1.0d / this.request.getBoltzmannBeta());
        double invProb = 1.0d / policy.getProbOfAction(state, groundedAction);
        double[] grad = BoltzmannPolicyGradient.computeBoltzmannPolicyGradient(state, groundedAction, (QGradientPlanner) this.request.getPlanner(), this.request.getBoltzmannBeta());
        for (int f = 0; f < grad.length; f++) {
            grad[f] *= invProb;
        }
        return grad;
    }

    /**
     * Adds {@code source} element-wise into {@code target}, in place.
     *
     * <p>Iterates over {@code target.length}; {@code source} must be at least that long.
     *
     * @param target the vector to accumulate into (modified)
     * @param source the vector to add
     */
    protected static void addToVector(double[] target, double[] source) {
        for (int i = 0; i < target.length; i++) {
            target[i] += source[i];
        }
    }
}
