package burlap.behavior.singleagent.options;

import burlap.behavior.singleagent.EpisodeAnalysis;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.planning.StateMapping;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.auxiliary.common.NullTermination;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.common.NullAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;

/* loaded from: input_file:burlap/behavior/singleagent/options/Option.class */
public abstract class Option extends Action {
    /** Random source used by continueFromState when rolling for stochastic termination. */
    protected Random rand;
    /** Record of the most recent execution; replaced in initiateInState and appended to in oneStep. */
    protected EpisodeAnalysis lastOptionExecutionResults;
    /** When true (the default set in init()), oneStep records each transition into lastOptionExecutionResults. */
    protected boolean shouldRecordResults;
    /** When true (the default set in init()), recorded actions are renamed with the option name and step index. */
    protected boolean shouldAnnotateExecution;
    /** Reward function used for reward tracking and expectation computation; null until keepTrackOfRewardWith is called. */
    protected RewardFunction rf;
    /** Whether oneStep accumulates discounted reward; switched on by keepTrackOfRewardWith. */
    protected boolean keepTrackOfReward;
    /** Per-step discount factor; 1.0 until changed by keepTrackOfRewardWith. */
    protected double discountFactor;
    /** Cumulative discounted reward collected during the most recent execution. */
    protected double lastCumulativeReward;
    /** Running discount multiplier applied to the next tracked reward; reset to 1.0 in initiateInState. */
    protected double cumulativeDiscount;
    /** Number of primitive steps taken during the most recent execution. */
    protected int lastNumSteps;
    /** Extra terminal test checked after each step in performActionHelper; never null (defaults to NullTermination). */
    protected TerminalFunction externalTerminalFunction;
    /** Hashing factory for the expectation caches; unset until setExpectationHashingFactory is called. */
    protected StateHashFactory expectationStateHashingFactory;
    /** Cache of computed transition lists keyed by hashed initiation state. */
    protected Map<StateHashTuple, List<TransitionProbability>> cachedExpectations;
    /** Cache of computed expected rewards keyed by hashed initiation state. */
    protected Map<StateHashTuple, Double> cachedExpectedRewards;
    /** Search nodes whose probability is not strictly above this value are not expanded; constructors set 0.001. */
    protected double expectationSearchCutoffProb;
    /** Optional state mapping applied by map(State); null means states are used unchanged. */
    protected StateMapping stateMapping;
    /** Optional shortcut; when non-null, performActionHelper asks it for the terminal state instead of stepping. */
    protected DirectOptionTerminateMapper terminateMapper;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:burlap/behavior/singleagent/options/Option$ExpectationSearchNode.class */
    /**
     * Node of the search performed by iterateExpectationScan: a state together with the
     * probability, accumulated discounted reward and step count of the path that reached it.
     */
    public class ExpectationSearchNode {
        /** State represented by this node. */
        public State s;
        /** Parameters the option was initiated with; shared by every node of one search. */
        public String[] optionParams;
        /** Product of the step probabilities along the path to this node. */
        public double probability;
        /** Sum of the (already discounted) rewards along the path to this node. */
        public double cumulativeDiscountedReward;
        /** Number of steps on the path to this node. */
        public int nSteps;

        /**
         * Creates the root node of a search: probability 1, no reward, zero steps.
         *
         * @param initialState state in which the search starts
         * @param params       option parameters carried through the search
         */
        public ExpectationSearchNode(State initialState, String[] params) {
            s = initialState;
            optionParams = params;
            nSteps = 0;
            cumulativeDiscountedReward = 0.0;
            probability = 1.0;
        }

        /**
         * Creates the successor of {@code parent} reached by one more step.
         *
         * @param parent           node the step starts from
         * @param nextState        state reached by the step
         * @param stepProbability  probability of the step; multiplied into the parent's probability
         * @param discountedReward reward of the step; added to the parent's cumulative reward
         */
        public ExpectationSearchNode(ExpectationSearchNode parent, State nextState, double stepProbability, double discountedReward) {
            s = nextState;
            optionParams = parent.optionParams;
            nSteps = parent.nSteps + 1;
            probability = parent.probability * stepProbability;
            cumulativeDiscountedReward = parent.cumulativeDiscountedReward + discountedReward;
        }
    }

    /**
     * Reports whether this option is Markov.
     * NOTE(review): never consulted inside this class; presumably read by planners/learners — confirm against callers.
     */
    public abstract boolean isMarkov();

    /**
     * Reports whether this option's termination is deterministic.
     * NOTE(review): never consulted inside this class; presumably means probabilityOfTermination
     * only returns 0 or 1 — confirm against implementations.
     */
    public abstract boolean usesDeterministicTermination();

    /**
     * Reports whether this option's policy is deterministic.
     * NOTE(review): never consulted inside this class; presumably such options can build their
     * action distribution with getDeterministicPolicy — confirm against implementations.
     */
    public abstract boolean usesDeterministicPolicy();

    /**
     * Returns the probability that the option terminates in the given state.
     * continueFromState treats exactly 1.0 as certain termination and exactly 0.0 as certain
     * continuation, and compares any other value against a random roll.
     *
     * @param state  state being tested for termination
     * @param strArr parameters the option was invoked with
     * @return the termination probability for {@code state}
     */
    public abstract double probabilityOfTermination(State state, String[] strArr);

    /**
     * Subclass hook invoked as the last step of initiateInState, after the execution
     * bookkeeping of this class has been reset.
     *
     * @param state  state in which the option is being initiated
     * @param strArr parameters the option was invoked with
     */
    public abstract void initiateInStateHelper(State state, String[] strArr);

    /**
     * Selects the action the option takes in the given state. oneStep executes the returned
     * action, and getDeterministicPolicy wraps it with selection probability 1.
     *
     * @param state  state in which an action must be chosen
     * @param strArr parameters the option was invoked with
     * @return the grounded action to execute in {@code state}
     */
    public abstract GroundedAction oneStepActionSelection(State state, String[] strArr);

    /**
     * Returns the option's action-selection distribution for the given state.
     * iterateExpectationScan iterates the result, reading each entry's {@code pSelection}
     * and casting its {@code ga} to GroundedAction.
     *
     * @param state  state whose action distribution is requested
     * @param strArr parameters the option was invoked with
     * @return the actions the option may select in {@code state} with their probabilities
     */
    public abstract List<Policy.ActionProb> getActionDistributionForState(State state, String[] strArr);

    /**
     * Creates an option using the superclass's no-argument construction and the
     * default settings established by init().
     */
    public Option() {
        init();
    }

    /**
     * Creates an option, forwarding all three arguments unchanged to the Action constructor.
     *
     * @param name             forwarded to the superclass constructor
     * @param domain           forwarded to the superclass constructor
     * @param parameterClasses forwarded to the superclass constructor
     */
    public Option(String name, Domain domain, String parameterClasses) {
        super(name, domain, parameterClasses);
        init();
    }

    /**
     * Creates an option, forwarding all three arguments unchanged to the Action constructor.
     *
     * @param name             forwarded to the superclass constructor
     * @param domain           forwarded to the superclass constructor
     * @param parameterClasses forwarded to the superclass constructor
     */
    public Option(String name, Domain domain, String[] parameterClasses) {
        super(name, domain, parameterClasses);
        init();
    }

    /**
     * Creates an option, forwarding all four arguments unchanged to the Action constructor.
     *
     * @param name                 forwarded to the superclass constructor
     * @param domain               forwarded to the superclass constructor
     * @param parameterClasses     forwarded to the superclass constructor
     * @param parameterOrderGroups forwarded to the superclass constructor
     */
    public Option(String name, Domain domain, String[] parameterClasses, String[] parameterOrderGroups) {
        super(name, domain, parameterClasses, parameterOrderGroups);
        init();
    }

    /**
     * Puts every option-specific setting into its default state. Called by each constructor
     * after the superclass has been constructed.
     */
    private void init() {
        // Expectation search: prune branches whose probability is not above 0.1%.
        expectationSearchCutoffProb = 0.001;

        rand = new Random();

        // Reward tracking is off until keepTrackOfRewardWith is called.
        rf = null;
        keepTrackOfReward = false;
        discountFactor = 1.0;

        // Bookkeeping of the most recent execution.
        lastCumulativeReward = 0.0;
        cumulativeDiscount = 1.0;
        lastNumSteps = 0;

        // Optional collaborators; the external terminal function is never left null.
        stateMapping = null;
        terminateMapper = null;
        externalTerminalFunction = new NullTermination();

        // Execution recording and annotation are both on by default.
        shouldRecordResults = true;
        shouldAnnotateExecution = true;
    }

    /**
     * Sets the factory used to hash states for the expectation caches and replaces both
     * caches with fresh empty maps, discarding anything computed under a previous factory.
     * getTransitions and getExpectedRewards dereference the factory and the caches, so this
     * must be called before either of them.
     *
     * @param stateHashFactory factory used to hash states for cache lookups
     */
    public void setExpectationHashingFactory(StateHashFactory stateHashFactory) {
        this.expectationStateHashingFactory = stateHashFactory;
        // Explicit type arguments instead of raw HashMap, so the assignments are checked.
        this.cachedExpectations = new HashMap<StateHashTuple, List<TransitionProbability>>();
        this.cachedExpectedRewards = new HashMap<StateHashTuple, Double>();
    }

    /**
     * Sets the pruning threshold of the expectation search: iterateExpectationScan only
     * expands nodes whose probability is strictly greater than this value.
     *
     * @param cutoff new probability cutoff
     */
    public void setExpectationCalculationProbabilityCutoff(double cutoff) {
        expectationSearchCutoffProb = cutoff;
    }

    /**
     * Turns recording of executed transitions on or off. Despite the name this assigns the
     * flag rather than flipping it.
     *
     * @param record true to have oneStep record transitions, false to skip recording
     */
    public void toggleShouldRecordResults(boolean record) {
        shouldRecordResults = record;
    }

    /**
     * Turns annotation of recorded actions on or off. Despite the name this assigns the flag
     * rather than flipping it.
     *
     * @param annotate true to record actions under a name that includes the option name and step index
     */
    public void toggleShouldAnnotateResults(boolean annotate) {
        shouldAnnotateExecution = annotate;
    }

    /**
     * @return true if oneStep currently records the transitions it executes
     */
    public boolean isRecordingExecutionResults() {
        return shouldRecordResults;
    }

    /**
     * @return true if recorded actions are currently renamed with the option name and step index
     */
    public boolean isAnnotatingExecutionResults() {
        return shouldAnnotateExecution;
    }

    /**
     * @return the record created by the most recent initiateInState call, or null if the
     *         option has never been initiated
     */
    public EpisodeAnalysis getLastExecutionResults() {
        return lastOptionExecutionResults;
    }

    /**
     * Sets the mapping applied by map(State).
     *
     * @param mapping mapping to apply, or null to have map(State) return states unchanged
     */
    public void setStateMapping(StateMapping mapping) {
        stateMapping = mapping;
    }

    /**
     * Sets the direct terminate mapper. While one is set, performActionHelper obtains the
     * terminal state, step count and cumulative reward from it instead of stepping.
     *
     * @param mapper mapper to use, or null to execute the option step by step
     */
    public void setTerminateMapper(DirectOptionTerminateMapper mapper) {
        terminateMapper = mapper;
    }

    /**
     * Sets the external terminal function checked after each step of performActionHelper.
     *
     * @param terminalFunction function to use; null installs a fresh NullTermination so the
     *                         field is never left null
     */
    public void setExernalTermination(TerminalFunction terminalFunction) {
        externalTerminalFunction = (terminalFunction != null) ? terminalFunction : new NullTermination();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Applies the configured state mapping, if any.
     *
     * @param state state to map
     * @return {@code state} itself when no mapping is set, otherwise the mapping's result
     */
    public State map(State state) {
        if (stateMapping == null) {
            return state;
        }
        return stateMapping.mapState(state);
    }

    /**
     * Enables reward tracking: from now on oneStep accumulates discounted reward using the
     * given reward function and discount factor.
     *
     * @param rewardFunction reward function queried for every executed step
     * @param discount       per-step discount factor
     */
    public void keepTrackOfRewardWith(RewardFunction rewardFunction, double discount) {
        rf = rewardFunction;
        discountFactor = discount;
        keepTrackOfReward = true;
    }

    /**
     * Overrides Action's initialisation by storing the four values directly in the inherited
     * fields; nothing else is done with them here.
     *
     * @param actionName           value stored as the action name
     * @param actionDomain         value stored as the domain
     * @param parameterClassNames  value stored as the parameter classes
     * @param parameterOrderGroups value stored as the parameter order groups
     */
    @Override // burlap.oomdp.singleagent.Action
    public void init(String actionName, Domain actionDomain, String[] parameterClassNames, String[] parameterOrderGroups) {
        name = actionName;
        domain = actionDomain;
        parameterClasses = parameterClassNames;
        parameterOrderGroup = parameterOrderGroups;
    }

    /**
     * @return the cumulative discounted reward stored by the most recent execution
     */
    public double getLastCumulativeReward() {
        return lastCumulativeReward;
    }

    /**
     * @return the number of steps stored by the most recent execution
     */
    public int getLastNumSteps() {
        return lastNumSteps;
    }

    /**
     * Options are never primitive actions.
     *
     * @return always false
     */
    @Override // burlap.oomdp.singleagent.Action
    public boolean isPrimitive() {
        final boolean primitive = false;
        return primitive;
    }

    /**
     * Prepares the option for a new execution starting in the given state: clears the step
     * count and reward bookkeeping, starts a fresh execution record rooted at that state,
     * and finally hands over to the subclass hook.
     *
     * @param state  state in which the option is initiated
     * @param strArr parameters the option was invoked with
     */
    public void initiateInState(State state, String[] strArr) {
        lastNumSteps = 0;
        cumulativeDiscount = 1.0;
        lastCumulativeReward = 0.0;
        lastOptionExecutionResults = new EpisodeAnalysis(state);

        // Subclass-specific initiation runs only after the shared bookkeeping is reset.
        initiateInStateHelper(state, strArr);
    }

    /**
     * Executes the option from the given state and returns the state in which it ends.
     * <p>
     * Without a terminate mapper the option is initiated and then stepped at least once; it
     * keeps stepping while continueFromState allows it and the external terminal function
     * does not flag the current state. With a terminate mapper, the terminal state, step
     * count and cumulative reward are taken from the mapper and no stepping happens.
     *
     * @param state  state in which the option is executed
     * @param strArr parameters the option was invoked with
     * @return the state reached when the option stops
     */
    @Override // burlap.oomdp.singleagent.Action
    protected State performActionHelper(State state, String[] strArr) {
        if (terminateMapper == null) {
            initiateInState(state, strArr);
            State current = state;
            boolean keepGoing;
            do {
                current = oneStep(current, strArr);
                // Short-circuit: the external terminal test is only consulted when the
                // option itself has not decided to stop.
                keepGoing = continueFromState(current, strArr)
                        && !externalTerminalFunction.isTerminal(current);
            } while (keepGoing);
            return current;
        }

        State terminalState = terminateMapper.generateOptionTerminalState(state);
        lastNumSteps = terminateMapper.getNumSteps(state, terminalState);
        lastCumulativeReward = terminateMapper.getCumulativeReward(state, terminalState, rf, discountFactor);
        return terminalState;
    }

    /**
     * Performs a single step of the option: selects an action, executes it, and updates the
     * step count, the tracked reward (when tracking is on) and the execution record (when
     * recording is on).
     *
     * @param state  state from which the step is taken
     * @param strArr parameters the option was invoked with
     * @return the state produced by executing the selected action
     */
    public State oneStep(State state, String[] strArr) {
        GroundedAction selected = oneStepActionSelection(state, strArr);
        State nextState = selected.executeIn(state);
        lastNumSteps++;

        // The recorded reward stays 0 when reward tracking is off.
        double reward = 0.0;
        if (keepTrackOfReward) {
            reward = rf.reward(state, selected, nextState);
            lastCumulativeReward += cumulativeDiscount * reward;
            cumulativeDiscount *= discountFactor;
        }

        if (!shouldRecordResults) {
            return nextState;
        }

        GroundedAction recorded;
        if (shouldAnnotateExecution) {
            // lastNumSteps was already incremented, so subtract one for a zero-based step index.
            String annotatedName = name + "(" + (lastNumSteps - 1) + ")-" + selected.action.getName();
            recorded = new GroundedAction(new NullAction(annotatedName), selected.params);
        } else {
            recorded = selected;
        }
        lastOptionExecutionResults.recordTransitionTo(recorded, nextState, reward);
        return nextState;
    }

    /**
     * Decides whether the option keeps executing from the given state.
     *
     * @param state  state reached by the last step
     * @param strArr parameters the option was invoked with
     * @return false if the option terminates in {@code state}, true if it continues
     */
    public boolean continueFromState(State state, String[] strArr) {
        double terminationProb = probabilityOfTermination(state, strArr);

        // The two deterministic cases are decided without consuming a random number.
        if (terminationProb == 1.0) {
            return false;
        }
        if (terminationProb == 0.0) {
            return true;
        }

        // Stochastic case: continue only when the roll is not below the termination probability.
        double roll = rand.nextDouble();
        return roll >= terminationProb;
    }

    /**
     * Returns the expected reward cached for executing the option from the given state,
     * running getTransitions first when nothing is cached yet (getTransitions stores the
     * value as a side effect). The hashing factory and caches must have been installed with
     * setExpectationHashingFactory beforehand.
     *
     * @param state  state from which the option would be executed
     * @param strArr parameters the option would be invoked with
     * @return the cached expected reward for {@code state}
     */
    public double getExpectedRewards(State state, String[] strArr) {
        StateHashTuple key = expectationStateHashingFactory.hashState(state);
        Double expected = cachedExpectedRewards.get(key);
        if (expected == null) {
            getTransitions(state, strArr);
            expected = cachedExpectedRewards.get(key);
        }
        return expected.doubleValue();
    }

    /**
     * Returns the transition list of this option from the given state, computing and caching
     * it on first request. The computation initiates the option in {@code state} (which
     * resets the last-execution bookkeeping), runs iterateExpectationScan from a root node,
     * and converts the accumulated per-state values into TransitionProbability objects. The
     * expected reward found by the same scan is cached as well, for getExpectedRewards.
     * <p>
     * The hashing factory and caches must have been installed with
     * setExpectationHashingFactory beforehand.
     *
     * @param state  state from which the option would be executed
     * @param strArr parameters the option would be invoked with
     * @return the cached or newly computed transitions for {@code state}
     */
    @Override // burlap.oomdp.singleagent.Action
    public List<TransitionProbability> getTransitions(State state, String[] strArr) {
        StateHashTuple hashState = this.expectationStateHashingFactory.hashState(state);
        List<TransitionProbability> cached = this.cachedExpectations.get(hashState);
        if (cached != null) {
            return cached;
        }

        initiateInState(state, strArr);
        ExpectationSearchNode root = new ExpectationSearchNode(state, strArr);
        // Parameterized (not raw) collections, so entries and results are type-checked.
        Map<StateHashTuple, Double> terminations = new HashMap<StateHashTuple, Double>();
        // One-element array acts as an out-parameter for the expected return.
        double[] expectedReturn = {0.0d};
        iterateExpectationScan(root, 1.0d, terminations, expectedReturn);
        this.cachedExpectedRewards.put(hashState, Double.valueOf(expectedReturn[0]));

        List<TransitionProbability> transitions = new ArrayList<TransitionProbability>(terminations.size());
        for (Map.Entry<StateHashTuple, Double> entry : terminations.entrySet()) {
            transitions.add(new TransitionProbability(entry.getKey().s, entry.getValue().doubleValue()));
        }
        this.cachedExpectations.put(hashState, transitions);
        return transitions;
    }

    /**
     * Recursively enumerates the trajectories the option can follow from a search node and
     * accumulates, over all of them, the discounted probability of terminating in each state
     * and the expected discounted return.
     * <p>
     * A node's {@code probability} is the probability of reaching it without having
     * terminated earlier. At each node reached after at least one step, the share
     * {@code probability * probTerm} terminates there: it is added (times the stacked
     * discount) to the termination map and weights the path's cumulative reward in the
     * expected return. The remaining share {@code 1 - probTerm} is carried into every child,
     * so a stochastically terminating option loses probability mass at each step and is
     * eventually pruned by the cutoff. The root node (zero steps) is never treated as a
     * termination point.
     *
     * @param expectationSearchNode node to expand
     * @param d                     stacked discount applying to this node (1.0 at the root,
     *                              multiplied by the discount factor per step)
     * @param map                   accumulator of discounted termination probability per hashed state
     * @param dArr                  one-element accumulator of the expected discounted return
     */
    protected void iterateExpectationScan(ExpectationSearchNode expectationSearchNode, double d, Map<StateHashTuple, Double> map, double[] dArr) {
        State current = expectationSearchNode.s;
        String[] optionParams = expectationSearchNode.optionParams;

        double probTerm = expectationSearchNode.nSteps > 0 ? probabilityOfTermination(current, optionParams) : 0.0d;
        double probContinue = 1.0d - probTerm;

        if (probTerm > 0.0d) {
            // Only the terminating share of this node's probability mass ends here.
            double terminatingMass = expectationSearchNode.probability * probTerm;
            accumulateDiscountedProb(map, current, terminatingMass * d);
            // Weight the path's return by its probability so the total is an expectation.
            dArr[0] += terminatingMass * expectationSearchNode.cumulativeDiscountedReward;
        }

        if (probContinue > 0.0d) {
            for (Policy.ActionProb actionProb : getActionDistributionForState(current, optionParams)) {
                GroundedAction ga = (GroundedAction) actionProb.ga;
                // Query the primitive action with its own parameters, not the option's.
                for (TransitionProbability transitionProbability : ga.action.getTransitions(current, ga.params)) {
                    double stepProb = probContinue * actionProb.pSelection * transitionProbability.p;
                    double discountedReward = d * this.rf.reward(current, ga, transitionProbability.s);
                    ExpectationSearchNode child = new ExpectationSearchNode(expectationSearchNode, transitionProbability.s, stepProb, discountedReward);
                    if (child.probability > this.expectationSearchCutoffProb) {
                        iterateExpectationScan(child, d * this.discountFactor, map, dArr);
                    }
                }
            }
        }
    }

    /**
     * Adds a value to the entry stored for the given state, creating the entry when the
     * state (as identified by the expectation hashing factory) is not in the map yet.
     *
     * @param map   accumulator keyed by hashed state
     * @param state state whose entry is increased
     * @param d     amount to add
     */
    protected void accumulateDiscountedProb(Map<StateHashTuple, Double> map, State state, double d) {
        StateHashTuple key = expectationStateHashingFactory.hashState(state);
        Double existing = map.get(key);
        double total = (existing == null) ? d : existing.doubleValue() + d;
        map.put(key, Double.valueOf(total));
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds the action distribution of a deterministic policy: a single-element list that
     * holds the action chosen by oneStepActionSelection with selection probability 1.
     *
     * @param state  state whose action distribution is requested
     * @param strArr parameters the option was invoked with
     * @return a new mutable list containing exactly one entry
     */
    public List<Policy.ActionProb> getDeterministicPolicy(State state, String[] strArr) {
        GroundedAction selected = oneStepActionSelection(state, strArr);
        // Parameterized list instead of a raw ArrayList, so the return is type-checked.
        List<Policy.ActionProb> distribution = new ArrayList<Policy.ActionProb>(1);
        distribution.add(new Policy.ActionProb(selected, 1.0d));
        return distribution;
    }
}
