package burlap.behavior.singleagent.learnbydemo.mlirl.commonrfs;

import burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.vfa.StateToFeatureVectorGenerator;
import burlap.oomdp.core.State;
import burlap.oomdp.singleagent.GroundedAction;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learnbydemo/mlirl/commonrfs/LinearStateActionDifferentiableRF.class */
/**
 * A differentiable reward function that is linear in state features, with a separate
 * weight block for every known action.
 *
 * <p>The parameter vector is laid out as {@code numActions} consecutive blocks of
 * {@code numStateFeatures} weights. The block for an action starts at
 * {@code actionMap.get(action) * numStateFeatures}. The reward for a state-action pair is
 * the dot product of that action's weight block with the feature vector that
 * {@link #fvGen} produces for the source state.
 */
public class LinearStateActionDifferentiableRF extends DifferentiableRF {
    /** Maps each known action to the index of its weight block in the parameter vector. */
    protected Map<GroundedAction, Integer> actionMap;

    /** Produces the feature vector for a state. */
    protected StateToFeatureVectorGenerator fvGen;

    /** Number of weights in each per-action block. */
    protected int numStateFeatures;

    /** Number of per-action weight blocks in the parameter vector. */
    int numActions;

    /**
     * Creates a reward function with one zero-initialized weight block per supplied action.
     *
     * @param fvGen the generator used to obtain state feature vectors
     * @param numStateFeatures the number of weights reserved for each action
     * @param actions the actions to register; the i-th action is assigned block index i
     */
    public LinearStateActionDifferentiableRF(StateToFeatureVectorGenerator fvGen, int numStateFeatures, GroundedAction... actions) {
        this.fvGen = fvGen;
        this.numStateFeatures = numStateFeatures;
        this.actionMap = new HashMap<GroundedAction, Integer>(actions.length);
        for (int idx = 0; idx < actions.length; idx++) {
            this.actionMap.put(actions[idx], Integer.valueOf(idx));
        }
        this.numActions = actions.length;
        this.parameters = new double[this.numActions * this.numStateFeatures];
        this.dim = this.numActions * this.numStateFeatures;
    }

    /**
     * Registers an action under the next free block index and grows the parameter vector
     * by one block.
     *
     * <p>Note: the parameter vector is re-allocated, so all existing parameter values are
     * reset to zero. If the action is already registered it is re-mapped to the new block
     * index and the block count still grows by one.
     *
     * @param groundedAction the action to register
     */
    public void addAction(GroundedAction groundedAction) {
        this.actionMap.put(groundedAction, Integer.valueOf(this.numActions));
        this.numActions++;
        this.parameters = new double[this.numActions * this.numStateFeatures];
        this.dim = this.numActions * this.numStateFeatures;
    }

    /**
     * Creates a copy that shares the feature generator and has its own action map and
     * parameter array.
     *
     * <p>The copy is built through the no-action constructor, so the action count,
     * dimensionality and parameters must be carried over explicitly. Without this the copy
     * would report zero actions and {@link #getGradient} on it would allocate an empty
     * gradient array and fail when writing the feature vector into it.
     *
     * @return a copy of this reward function
     */
    @Override
    protected DifferentiableRF copyHelper() {
        LinearStateActionDifferentiableRF copy = new LinearStateActionDifferentiableRF(this.fvGen, this.numStateFeatures);
        copy.actionMap.putAll(this.actionMap);
        copy.numActions = this.numActions;
        copy.dim = this.dim;
        if (this.parameters != null) {
            copy.parameters = this.parameters.clone();
        }
        return copy;
    }

    /**
     * Computes the reward as the dot product of the action's weight block with the feature
     * vector of {@code state}. The resulting state {@code state2} is not used.
     *
     * @param state the state in which the action is taken
     * @param groundedAction the action taken; must be registered in {@link #actionMap}
     * @param state2 the resulting state (ignored)
     * @return the linear reward value
     */
    @Override
    public double reward(State state, GroundedAction groundedAction, State state2) {
        double[] features = this.fvGen.generateFeatureVectorFrom(state);
        int offset = this.actionMap.get(groundedAction).intValue() * this.numStateFeatures;
        double sum = 0.0d;
        for (int f = 0; f < this.numStateFeatures; f++) {
            sum += this.parameters[offset + f] * features[f];
        }
        return sum;
    }

    /**
     * Computes the gradient of the reward with respect to the parameters: the feature vector
     * of {@code state} placed in the action's block, and zero everywhere else.
     *
     * @param state the state in which the action is taken
     * @param groundedAction the action taken; must be registered in {@link #actionMap}
     * @param state2 the resulting state (ignored)
     * @return a new array of length {@code numStateFeatures * numActions}
     */
    @Override
    public double[] getGradient(State state, GroundedAction groundedAction, State state2) {
        double[] features = this.fvGen.generateFeatureVectorFrom(state);
        int offset = this.actionMap.get(groundedAction).intValue() * this.numStateFeatures;
        double[] gradient = new double[this.numStateFeatures * this.numActions];
        copyInto(features, gradient, offset);
        return gradient;
    }

    /**
     * Copies every element of {@code dArr} into {@code dArr2}, starting at index {@code i}
     * of the destination.
     *
     * @param dArr the source array
     * @param dArr2 the destination array
     * @param i the destination index at which the first source element is written
     */
    protected void copyInto(double[] dArr, double[] dArr2, int i) {
        System.arraycopy(dArr, 0, dArr2, i, dArr.length);
    }
}
