package burlap.behavior.learningrate;

import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.State;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/learningrate/SoftTimeInverseDecayLR.class */
/**
 * A learning rate that decays with the inverse of a time index, softened by a constant shift.
 *
 * <p>For time index {@code t} the rate is
 * {@code initialLearningRate * (decayConstantShift + 1) / (decayConstantShift + t)},
 * never going below {@code minimumLR}. Time index 0 returns the undecayed initial rate.
 *
 * <p>The time index can be tracked in one of three ways, chosen by the constructor used:
 * <ul>
 *   <li>a single universal index shared by every query;</li>
 *   <li>one index per hashed state (or per feature id for the feature-based methods);</li>
 *   <li>one index per state and action name.</li>
 * </ul>
 *
 * <p>An index only advances when a poll arrives with a time stamp strictly greater than the
 * last poll time recorded for that index, so repeated polls at the same time stamp do not
 * decay the rate more than once.
 */
public class SoftTimeInverseDecayLR implements LearningRate {
    /** Rate returned before any decay has been applied. */
    protected double initialLearningRate;
    /** Shift added to the time index in the decay formula; larger values slow the decay. */
    protected double decayConstantShift;
    /** Lower bound on every returned learning rate. */
    protected double minimumLR = Double.MIN_NORMAL;
    /** Time index used when state-wise tracking is disabled. */
    protected int universalTime = 1;
    /** Per-state time indices; {@code null} unless state-wise tracking is enabled. */
    protected Map<StateHashTuple, StateWiseTimeIndex> stateWiseMap;
    /** Per-feature time indices; {@code null} unless state-wise tracking is enabled. */
    protected Map<Integer, StateWiseTimeIndex> featureWiseMap;
    /** Whether a separate time index is kept per state / feature. */
    protected boolean useStateWise = false;
    /** Whether a separate time index is kept per action name within each state. */
    protected boolean useStateActionWise = false;
    /** Produces the hashed state keys of {@link #stateWiseMap}. */
    protected StateHashFactory hashingFactory;
    /** Time stamp of the last poll that advanced {@link #universalTime}. */
    protected int lastPollTime = -1;

    /** A mutable time index together with the time stamp of the poll that last advanced it. */
    public class MutableInt {
        int mi;
        int lastPollTime = -1;

        public MutableInt(int i) {
            this.mi = i;
        }
    }

    /**
     * Time-index record for one state (or one feature). Holds a state-level index and, when
     * state-action-wise tracking is enabled on the enclosing instance, a map of per-action indices.
     */
    public class StateWiseTimeIndex {
        /** Per-action indices keyed by action name; {@code null} unless state-action-wise. */
        Map<String, MutableInt> actionLearningRates;
        int lastPollTime = -1;
        int timeIndex = 1;

        public StateWiseTimeIndex() {
            this.actionLearningRates =
                    SoftTimeInverseDecayLR.this.useStateActionWise ? new HashMap<String, MutableInt>() : null;
        }

        /**
         * Returns the time-index entry for the given action, creating one starting at 1 if absent.
         *
         * <p>Entries are keyed by {@code actionName()}, and the lookup must use that same key:
         * querying the {@code String}-keyed map with the action object itself never matches, which
         * would hand back a fresh entry on every call and prevent any decay.
         *
         * @param abstractGroundedAction the action whose entry is requested
         * @return the stored (possibly newly created) entry for the action's name
         */
        public MutableInt getActionTimeIndexEntry(AbstractGroundedAction abstractGroundedAction) {
            String actionKey = abstractGroundedAction.actionName();
            MutableInt entry = this.actionLearningRates.get(actionKey);
            if (entry == null) {
                entry = new MutableInt(1);
                this.actionLearningRates.put(actionKey, entry);
            }
            return entry;
        }
    }

    /**
     * Creates a learning rate with a single universal time index and the default minimum rate.
     *
     * @param initialLearningRate the undecayed learning rate
     * @param decayConstantShift the shift constant of the decay formula
     */
    public SoftTimeInverseDecayLR(double initialLearningRate, double decayConstantShift) {
        this(initialLearningRate, decayConstantShift, Double.MIN_NORMAL);
    }

    /**
     * Creates a learning rate with a single universal time index.
     *
     * @param initialLearningRate the undecayed learning rate
     * @param decayConstantShift the shift constant of the decay formula
     * @param minimumLearningRate the lower bound on returned rates
     */
    public SoftTimeInverseDecayLR(double initialLearningRate, double decayConstantShift, double minimumLearningRate) {
        this.initialLearningRate = initialLearningRate;
        this.decayConstantShift = decayConstantShift;
        this.minimumLR = minimumLearningRate;
    }

    /**
     * Creates a state-wise (optionally state-action-wise) learning rate with the default minimum rate.
     *
     * @param initialLearningRate the undecayed learning rate
     * @param decayConstantShift the shift constant of the decay formula
     * @param hashingFactory factory used to hash states into map keys
     * @param useStateActionWise whether to keep a separate index per action name within each state
     */
    public SoftTimeInverseDecayLR(double initialLearningRate, double decayConstantShift,
            StateHashFactory hashingFactory, boolean useStateActionWise) {
        this(initialLearningRate, decayConstantShift, Double.MIN_NORMAL, hashingFactory, useStateActionWise);
    }

    /**
     * Creates a state-wise (optionally state-action-wise) learning rate.
     *
     * @param initialLearningRate the undecayed learning rate
     * @param decayConstantShift the shift constant of the decay formula
     * @param minimumLearningRate the lower bound on returned rates
     * @param hashingFactory factory used to hash states into map keys
     * @param useStateActionWise whether to keep a separate index per action name within each state
     */
    public SoftTimeInverseDecayLR(double initialLearningRate, double decayConstantShift, double minimumLearningRate,
            StateHashFactory hashingFactory, boolean useStateActionWise) {
        this.initialLearningRate = initialLearningRate;
        this.decayConstantShift = decayConstantShift;
        this.minimumLR = minimumLearningRate;
        this.useStateWise = true;
        this.useStateActionWise = useStateActionWise;
        this.hashingFactory = hashingFactory;
        this.stateWiseMap = new HashMap<StateHashTuple, StateWiseTimeIndex>();
        this.featureWiseMap = new HashMap<Integer, StateWiseTimeIndex>();
    }

    /**
     * Returns the current learning rate for the state/action without advancing any time index.
     * In state-wise mode a tracking entry is created for a state (and action) not seen before.
     */
    @Override
    public double peekAtLearningRate(State state, AbstractGroundedAction abstractGroundedAction) {
        if (!this.useStateWise) {
            return learningRate(this.universalTime);
        }
        StateWiseTimeIndex stateEntry = getStateWiseTimeIndex(state);
        if (!this.useStateActionWise) {
            return learningRate(stateEntry.timeIndex);
        }
        return learningRate(stateEntry.getActionTimeIndexEntry(abstractGroundedAction).mi);
    }

    /**
     * Returns the current learning rate for the state/action and then advances the matching time
     * index, provided {@code agentTime} is greater than the last poll time recorded for that index.
     *
     * @param agentTime time stamp of this poll
     * @param state the queried state (ignored unless state-wise)
     * @param abstractGroundedAction the queried action (ignored unless state-action-wise)
     * @return the learning rate in effect before this poll's decay
     */
    @Override
    public double pollLearningRate(int agentTime, State state, AbstractGroundedAction abstractGroundedAction) {
        if (!this.useStateWise) {
            return pollUniversal(agentTime);
        }
        StateWiseTimeIndex stateEntry = getStateWiseTimeIndex(state);
        if (!this.useStateActionWise) {
            return pollStateEntry(agentTime, stateEntry);
        }
        MutableInt actionEntry = stateEntry.getActionTimeIndexEntry(abstractGroundedAction);
        double rate = learningRate(actionEntry.mi);
        if (agentTime > actionEntry.lastPollTime) {
            actionEntry.mi++;
            actionEntry.lastPollTime = agentTime;
        }
        return rate;
    }

    /**
     * Returns the current learning rate for the feature without advancing any time index.
     * In state-wise mode a tracking entry is created for a feature id not seen before.
     */
    @Override
    public double peekAtLearningRate(int featureId) {
        if (!this.useStateWise) {
            return learningRate(this.universalTime);
        }
        return learningRate(getFeatureWiseTimeIndex(featureId).timeIndex);
    }

    /**
     * Returns the current learning rate for the feature and then advances the matching time index,
     * provided {@code agentTime} is greater than the last poll time recorded for that index.
     *
     * @param agentTime time stamp of this poll
     * @param featureId the queried feature (ignored unless state-wise)
     * @return the learning rate in effect before this poll's decay
     */
    @Override
    public double pollLearningRate(int agentTime, int featureId) {
        if (!this.useStateWise) {
            return pollUniversal(agentTime);
        }
        return pollStateEntry(agentTime, getFeatureWiseTimeIndex(featureId));
    }

    /**
     * Restores the decay state to what it was at construction: the universal index and its poll
     * time are reset and all per-state / per-feature entries are discarded.
     */
    @Override
    public void resetDecay() {
        this.universalTime = 1;
        // Without this, a caller whose clock restarts after the reset would never advance the
        // universal index again; cleared map entries already restart with lastPollTime == -1.
        this.lastPollTime = -1;
        // The maps only exist in state-wise mode; guard so a universal instance can be reset too.
        if (this.stateWiseMap != null) {
            this.stateWiseMap.clear();
        }
        if (this.featureWiseMap != null) {
            this.featureWiseMap.clear();
        }
    }

    /**
     * Computes the learning rate for a time index, clamped below by {@link #minimumLR}.
     *
     * @param i the time index
     * @return the decayed learning rate
     */
    protected double learningRate(int i) {
        double decayed = i == 0
                ? this.initialLearningRate
                : this.initialLearningRate * ((this.decayConstantShift + 1.0d) / (this.decayConstantShift + i));
        return Math.max(decayed, this.minimumLR);
    }

    /** Returns the tracking entry for the hashed state, creating and storing one if absent. */
    protected StateWiseTimeIndex getStateWiseTimeIndex(State state) {
        StateHashTuple key = this.hashingFactory.hashState(state);
        StateWiseTimeIndex entry = this.stateWiseMap.get(key);
        if (entry == null) {
            entry = new StateWiseTimeIndex();
            this.stateWiseMap.put(key, entry);
        }
        return entry;
    }

    /** Returns the tracking entry for the feature id, creating and storing one if absent. */
    protected StateWiseTimeIndex getFeatureWiseTimeIndex(int i) {
        Integer key = Integer.valueOf(i);
        StateWiseTimeIndex entry = this.featureWiseMap.get(key);
        if (entry == null) {
            entry = new StateWiseTimeIndex();
            this.featureWiseMap.put(key, entry);
        }
        return entry;
    }

    /** Reads the rate for the universal index, then advances it if this poll is newer than the last. */
    private double pollUniversal(int agentTime) {
        double rate = learningRate(this.universalTime);
        if (agentTime > this.lastPollTime) {
            this.universalTime++;
            this.lastPollTime = agentTime;
        }
        return rate;
    }

    /** Reads the rate for an entry's own index, then advances it if this poll is newer than the last. */
    private double pollStateEntry(int agentTime, StateWiseTimeIndex entry) {
        double rate = learningRate(entry.timeIndex);
        if (agentTime > entry.lastPollTime) {
            entry.timeIndex++;
            entry.lastPollTime = agentTime;
        }
        return rate;
    }
}
