package burlap.behavior.singleagent.learnbydemo.mlirl.commonrfs;

import burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.vfa.StateToFeatureVectorGenerator;
import burlap.oomdp.core.State;
import burlap.oomdp.singleagent.GroundedAction;

/* loaded from: input_file:burlap/behavior/singleagent/learnbydemo/mlirl/commonrfs/LinearStateDifferentiableRF.class */
/**
 * A differentiable reward function that is linear in a state feature vector:
 * {@code R(s, a, s') = sum_i f_i * parameters[i]}, where {@code f} is produced by a
 * {@link StateToFeatureVectorGenerator} applied either to the resulting state {@code s'}
 * (the default) or to the source state {@code s}.
 *
 * <p>Because the reward is linear in {@code parameters}, its gradient with respect to the
 * parameters is simply the feature vector itself.
 */
public class LinearStateDifferentiableRF extends DifferentiableRF {

    /**
     * When {@code true} (the default), features are generated from the resulting state
     * {@code s'}; when {@code false}, from the source state {@code s}.
     */
    protected boolean featuresAreForNextState = true;

    /** Produces the feature vector that is combined with {@code parameters}. */
    protected StateToFeatureVectorGenerator fvGen;

    /**
     * Creates a linear reward function whose features are computed from the resulting state.
     * All parameters start at zero.
     *
     * @param fvGen the generator used to turn a state into a feature vector
     * @param dim   the number of parameters (length of the parameter array)
     */
    public LinearStateDifferentiableRF(StateToFeatureVectorGenerator fvGen, int dim) {
        this(fvGen, dim, true);
    }

    /**
     * Creates a linear reward function. All parameters start at zero.
     *
     * @param fvGen                   the generator used to turn a state into a feature vector
     * @param dim                     the number of parameters (length of the parameter array)
     * @param featuresAreForNextState {@code true} to compute features from the resulting state
     *                                {@code s'}, {@code false} to use the source state {@code s}
     */
    public LinearStateDifferentiableRF(StateToFeatureVectorGenerator fvGen, int dim,
            boolean featuresAreForNextState) {
        this.featuresAreForNextState = featuresAreForNextState;
        this.dim = dim;
        this.parameters = new double[dim];
        this.fvGen = fvGen;
    }

    /**
     * Selects whether features are computed from the resulting state or the source state.
     *
     * @param featuresAreForNextState {@code true} for {@code s'}, {@code false} for {@code s}
     */
    public void setFeaturesAreForNextState(boolean featuresAreForNextState) {
        this.featuresAreForNextState = featuresAreForNextState;
    }

    /**
     * Returns a new instance with the same feature generator, dimensionality, and
     * feature-source setting. The new instance's parameter array is freshly zero-initialized
     * by the constructor.
     * NOTE(review): presumably the caller in {@code DifferentiableRF} copies the parameter
     * values afterwards; confirm.
     */
    @Override // burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF
    protected DifferentiableRF copyHelper() {
        return new LinearStateDifferentiableRF(this.fvGen, this.dim, this.featuresAreForNextState);
    }

    /**
     * Returns the gradient of {@link #reward} with respect to the parameters. For a linear
     * reward this is exactly the feature vector of the relevant state.
     *
     * @param state          the source state {@code s}
     * @param groundedAction the action taken (unused by this reward function)
     * @param state2         the resulting state {@code s'}
     * @return the feature vector of {@code state2} or {@code state}, depending on
     *         {@link #featuresAreForNextState}
     */
    @Override // burlap.behavior.singleagent.learnbydemo.mlirl.support.DifferentiableRF
    public double[] getGradient(State state, GroundedAction groundedAction, State state2) {
        return this.featuresFor(state, state2);
    }

    /**
     * Computes the dot product of the selected state's feature vector with the parameters.
     *
     * @param state          the source state {@code s}
     * @param groundedAction the action taken (unused by this reward function)
     * @param state2         the resulting state {@code s'}
     * @return {@code sum_i features[i] * parameters[i]}, iterating over the feature vector
     */
    @Override // burlap.oomdp.singleagent.RewardFunction
    public double reward(State state, GroundedAction groundedAction, State state2) {
        double[] features = this.featuresFor(state, state2);
        double sum = 0.0d;
        for (int i = 0; i < features.length; i++) {
            sum += features[i] * this.parameters[i];
        }
        return sum;
    }

    /**
     * Single place that decides which state the features come from, so that
     * {@link #reward} and {@link #getGradient} cannot disagree.
     */
    private double[] featuresFor(State state, State state2) {
        return this.featuresAreForNextState
                ? this.fvGen.generateFeatureVectorFrom(state2)
                : this.fvGen.generateFeatureVectorFrom(state);
    }
}
