package burlap.behavior.singleagent.learnbydemo.mlirl;

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.learnbydemo.IRLRequest;
import burlap.behavior.singleagent.learnbydemo.mlirl.differentiableplanners.DifferentiableVI;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.learnbydemo.mlirl.support.QGradientPlanner;
import burlap.behavior.singleagent.planning.OOMDPPlanner;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.oomdp.auxiliary.common.NullTermination;
import burlap.oomdp.core.Domain;
import java.util.Arrays;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learnbydemo/mlirl/MLIRLRequest.class */
/**
 * An {@link IRLRequest} specialised for maximum-likelihood IRL (MLIRL).
 * <p>
 * In addition to the data held by {@link IRLRequest}, this request carries a {@link DifferentiableRF},
 * an optional per-expert-episode weight vector, and a Boltzmann beta parameter. Any non-null planner held
 * by this request is required to be an instance of {@link QGradientPlanner}; this is enforced by the
 * constructors and {@link #setPlanner(OOMDPPlanner)} and reported by {@link #isValid()}.
 */
public class MLIRLRequest extends IRLRequest {

    /** Error message used whenever a planner that is not a {@link QGradientPlanner} is supplied. */
    private static final String PLANNER_TYPE_ERROR =
            "Error: MLIRLRequest requires the planner to be an instance of QGradientPlanner";

    /** Boltzmann beta assigned by both constructors. */
    private static final double DEFAULT_BOLTZMANN_BETA = 0.5d;

    /** Optional per-expert-episode weights; {@code null} means every episode is weighted 1.0. */
    protected double[] episodeWeights;

    /** Boltzmann beta parameter; {@link #setBoltzmannBeta(double)} also forwards it to the planner. */
    protected double boltzmannBeta;

    /** The differentiable reward function associated with this request. */
    protected DifferentiableRF rf;

    /**
     * Creates a request that uses the supplied planner.
     *
     * @param domain the domain of the problem
     * @param planner the planner to use; must be {@code null} or an instance of {@link QGradientPlanner}
     * @param expertEpisodes the expert demonstrations
     * @param rf the differentiable reward function
     * @throws RuntimeException if {@code planner} is non-null and not a {@link QGradientPlanner}
     */
    public MLIRLRequest(Domain domain, OOMDPPlanner planner, List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf) {
        super(domain, planner, expertEpisodes);
        this.episodeWeights = null;
        this.boltzmannBeta = DEFAULT_BOLTZMANN_BETA;
        if (planner != null && !(planner instanceof QGradientPlanner)) {
            throw new RuntimeException(PLANNER_TYPE_ERROR);
        }
        this.rf = rf;
    }

    /**
     * Creates a request whose planner is a newly constructed {@link DifferentiableVI}.
     * <p>
     * The planner is built from the given domain and reward function, a {@link NullTermination}, the
     * superclass's {@code gamma} value as it stands at construction time, the default Boltzmann beta (0.5),
     * the given hashing factory, and the literal arguments {@code 0.01} and {@code 500}
     * (NOTE(review): presumably a convergence threshold and an iteration cap — confirm against
     * {@link DifferentiableVI}).
     *
     * @param domain the domain of the problem
     * @param expertEpisodes the expert demonstrations
     * @param rf the differentiable reward function, also handed to the created planner
     * @param hashingFactory the state hashing factory handed to the created planner
     */
    public MLIRLRequest(Domain domain, List<EpisodeAnalysis> expertEpisodes, DifferentiableRF rf, StateHashFactory hashingFactory) {
        super(domain, null, expertEpisodes);
        this.episodeWeights = null;
        this.boltzmannBeta = DEFAULT_BOLTZMANN_BETA;
        this.rf = rf;
        this.planner = new DifferentiableVI(domain, rf, new NullTermination(), this.gamma, this.boltzmannBeta,
                hashingFactory, 0.01d, 500);
    }

    /**
     * Returns whether this request is complete and consistent.
     * <p>
     * In addition to the superclass checks, requires that the episode weights are either unset or have
     * exactly one entry per expert episode, that the planner is a {@link QGradientPlanner}, and that a
     * reward function is set.
     *
     * @return {@code true} if the request is valid
     */
    @Override
    public boolean isValid() {
        if (!super.isValid()) {
            return false;
        }
        boolean weightsMatchEpisodes =
                this.episodeWeights == null || this.episodeWeights.length == this.expertEpisodes.size();
        return weightsMatchEpisodes && this.planner instanceof QGradientPlanner && this.rf != null;
    }

    /**
     * Sets the planner.
     * <p>
     * The supplied planner (not the current one) is what gets type-checked, mirroring the
     * constructor: {@code null} is accepted, anything else must be a {@link QGradientPlanner}.
     *
     * @param planner the new planner; {@code null} or an instance of {@link QGradientPlanner}
     * @throws RuntimeException if {@code planner} is non-null and not a {@link QGradientPlanner}
     */
    @Override
    public void setPlanner(OOMDPPlanner planner) {
        if (planner != null && !(planner instanceof QGradientPlanner)) {
            throw new RuntimeException(PLANNER_TYPE_ERROR);
        }
        this.planner = planner;
    }

    /**
     * Returns the per-episode weights.
     * <p>
     * If weights have been set, the stored array itself (not a copy) is returned. Otherwise a fresh
     * array with one entry of 1.0 per expert episode is returned.
     *
     * @return the episode weights
     */
    public double[] getEpisodeWeights() {
        if (this.episodeWeights != null) {
            return this.episodeWeights;
        }
        double[] uniformWeights = new double[this.expertEpisodes.size()];
        Arrays.fill(uniformWeights, 1.0d);
        return uniformWeights;
    }

    /**
     * @return the Boltzmann beta parameter
     */
    public double getBoltzmannBeta() {
        return this.boltzmannBeta;
    }

    /**
     * @return the differentiable reward function
     */
    public DifferentiableRF getRf() {
        return this.rf;
    }

    /**
     * Sets the per-episode weights. No length check is performed here; {@link #isValid()} reports a
     * length mismatch with the expert episodes.
     *
     * @param weights the weights, or {@code null} to weight every episode 1.0
     */
    public void setEpisodeWeights(double[] weights) {
        this.episodeWeights = weights;
    }

    /**
     * Sets the Boltzmann beta parameter and forwards it to the planner when the planner is a
     * {@link QGradientPlanner}. A planner of any other type cannot accept the value; such a
     * configuration is already reported as invalid by {@link #isValid()}.
     *
     * @param beta the new Boltzmann beta
     */
    public void setBoltzmannBeta(double beta) {
        this.boltzmannBeta = beta;
        if (this.planner instanceof QGradientPlanner) {
            ((QGradientPlanner) this.planner).setBoltzmannBetaParameter(beta);
        }
    }

    /**
     * @param rf the differentiable reward function
     */
    public void setRf(DifferentiableRF rf) {
        this.rf = rf;
    }
}
