package burlap.behavior.singleagent;

import burlap.behavior.singleagent.options.Option;
import burlap.debugtools.RandomFactory;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.common.NullAction;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

/* loaded from: input_file:burlap/behavior/singleagent/Policy.class */
/**
 * Abstract base for policies that select an {@link AbstractGroundedAction} for a {@link State}.
 * Subclasses supply action selection and the action distribution; this class builds
 * probability lookup, distribution sampling and episode roll-out helpers on top of them.
 */
public abstract class Policy {
    /**
     * When true, the roll-out helper executes a selected non-primitive action as a sequence
     * of one-step selections instead of as a single transition.
     */
    protected boolean evaluateDecomposesOptions = true;
    /**
     * When true, each step of a decomposed option is recorded under a synthesized name that
     * combines the option name, the step index and the step action's name.
     */
    protected boolean annotateOptionDecomposition = true;

    /* loaded from: input_file:burlap/behavior/singleagent/Policy$ActionProb.class */
    /**
     * Pairs an action with the probability of selecting it.
     */
    public static class ActionProb {
        /** The action this entry refers to. */
        public AbstractGroundedAction ga;
        /** The selection probability stored for {@link #ga}. */
        public double pSelection;

        /**
         * Creates an entry from an action and its selection probability.
         *
         * @param abstractGroundedAction the action to store
         * @param d the probability to store for that action
         */
        public ActionProb(AbstractGroundedAction abstractGroundedAction, double d) {
            ga = abstractGroundedAction;
            pSelection = d;
        }
    }

    /* loaded from: input_file:burlap/behavior/singleagent/Policy$PolicyUndefinedException.class */
    /**
     * Unchecked exception thrown when a policy is asked about a state it has no
     * action or distribution for.
     */
    public static class PolicyUndefinedException extends RuntimeException {
        private static final long serialVersionUID = 1L;

        /** Creates the exception with its fixed message. */
        public PolicyUndefinedException() {
            super("Policy is undefined for provided state");
        }
    }

    /* loaded from: input_file:burlap/behavior/singleagent/Policy$RandomPolicy.class */
    /**
     * A policy that picks uniformly at random among the grounded actions that
     * {@link Action#getAllApplicableGroundedActionsFromActionList} reports for the
     * configured action list in the queried state.
     */
    public static class RandomPolicy extends Policy {
        /** Actions whose applicable groundings are the selection candidates. */
        protected List<Action> actions;
        /** Random source used to pick among the candidates. */
        protected Random rand;

        /**
         * Creates a random policy over a copy of the domain's action list.
         *
         * @param domain the domain whose actions are copied into this policy
         */
        public RandomPolicy(Domain domain) {
            this.rand = RandomFactory.getMapped(0);
            this.actions = new ArrayList<Action>(domain.getActions());
        }

        /**
         * Creates a random policy over a copy of the given action list.
         *
         * @param list the actions to select from; copied, so later changes to the
         *             caller's list do not affect this policy
         */
        public RandomPolicy(List<Action> list) {
            this.rand = RandomFactory.getMapped(0);
            // Copy the parameter; copying the not-yet-assigned field would always throw NPE.
            this.actions = new ArrayList<Action>(list);
        }

        /**
         * Appends an action to the selection list.
         *
         * @param action the action to add
         */
        public void addAction(Action action) {
            this.actions.add(action);
        }

        /** Removes every action from the selection list. */
        public void clearActions() {
            this.actions.clear();
        }

        /**
         * Removes the first action whose name equals the given string; does nothing
         * when no action has that name.
         *
         * @param str the name of the action to remove
         */
        public void removeAction(String str) {
            for (Iterator<Action> it = this.actions.iterator(); it.hasNext();) {
                if (it.next().getName().equals(str)) {
                    it.remove();
                    return;
                }
            }
        }

        /**
         * Returns the live internal action list (not a copy); changes made through it
         * affect this policy.
         *
         * @return the list of actions this policy selects from
         */
        public List<Action> getSelectionActions() {
            return this.actions;
        }

        /**
         * @return the random generator currently used for selection
         */
        public Random getRandomGenerator() {
            return this.rand;
        }

        /**
         * Replaces the random generator used for selection.
         *
         * @param random the generator to use from now on
         */
        public void setRandomGenerator(Random random) {
            this.rand = random;
        }

        /**
         * Picks one applicable grounded action uniformly at random.
         *
         * @param state the state to select an action for
         * @return a randomly chosen applicable grounded action
         * @throws PolicyUndefinedException if no grounded action is applicable in the state
         */
        @Override
        public AbstractGroundedAction getAction(State state) {
            List<GroundedAction> applicable = Action.getAllApplicableGroundedActionsFromActionList(this.actions, state);
            if (applicable.isEmpty()) {
                throw new PolicyUndefinedException();
            }
            // The bound must be the size of the list being indexed: the number of applicable
            // groundings can differ from the number of configured actions.
            return applicable.get(this.rand.nextInt(applicable.size()));
        }

        /**
         * Returns a uniform distribution over the applicable grounded actions.
         *
         * @param state the state to build the distribution for
         * @return one entry per applicable grounded action, each with probability 1/n
         * @throws PolicyUndefinedException if no grounded action is applicable in the state
         */
        @Override
        public List<ActionProb> getActionDistributionForState(State state) {
            List<GroundedAction> applicable = Action.getAllApplicableGroundedActionsFromActionList(this.actions, state);
            if (applicable.isEmpty()) {
                throw new PolicyUndefinedException();
            }
            double uniform = 1.0d / applicable.size();
            List<ActionProb> distribution = new ArrayList<ActionProb>(applicable.size());
            for (GroundedAction candidate : applicable) {
                distribution.add(new ActionProb(candidate, uniform));
            }
            return distribution;
        }

        /**
         * @return always {@code true}
         */
        @Override
        public boolean isStochastic() {
            return true;
        }

        /**
         * @param state the state to check
         * @return {@code true} when at least one grounded action is applicable in the state
         */
        @Override
        public boolean isDefinedFor(State state) {
            return !Action.getAllApplicableGroundedActionsFromActionList(this.actions, state).isEmpty();
        }
    }

    /**
     * Selects the action this policy takes in the given state.
     * Callers in this class treat a {@code null} return as the policy being undefined
     * for the state and throw {@link PolicyUndefinedException}.
     *
     * @param state the state to select an action for
     * @return the selected action
     */
    public abstract AbstractGroundedAction getAction(State state);

    /**
     * Returns the action selection distribution for the given state as a list of
     * action/probability pairs. Callers in this class treat a {@code null} or empty
     * list as the policy being undefined for the state.
     *
     * @param state the state to return the distribution for
     * @return the action/probability pairs for the state
     */
    public abstract List<ActionProb> getActionDistributionForState(State state);

    /**
     * Reports whether this policy is stochastic.
     *
     * @return {@code true} if the policy is stochastic
     */
    public abstract boolean isStochastic();

    /**
     * Reports whether this policy is defined for the given state.
     *
     * @param state the state to check
     * @return {@code true} if the policy is defined for the state
     */
    public abstract boolean isDefinedFor(State state);

    /**
     * Looks up the probability this policy assigns to an action in a state, using
     * {@link #getActionDistributionForState(State)}.
     *
     * @param state the state in which the action would be taken
     * @param abstractGroundedAction the action to look up
     * @return the probability of the first distribution entry whose action equals the
     *         given one, or 0 when no entry matches
     * @throws PolicyUndefinedException if the distribution for the state is null or empty
     */
    public double getProbOfAction(State state, AbstractGroundedAction abstractGroundedAction) {
        List<ActionProb> distribution = getActionDistributionForState(state);
        boolean undefined = distribution == null || distribution.size() == 0;
        if (undefined) {
            throw new PolicyUndefinedException();
        }
        Iterator<ActionProb> entries = distribution.iterator();
        while (entries.hasNext()) {
            ActionProb entry = entries.next();
            if (entry.ga.equals(abstractGroundedAction)) {
                return entry.pSelection;
            }
        }
        // No entry for the action: it is never selected in this state.
        return 0.0d;
    }

    /**
     * Looks up the probability of an action in an explicit distribution. The state
     * argument is not used.
     *
     * @param state ignored
     * @param abstractGroundedAction the action to look up
     * @param list the distribution to search
     * @return the probability of the first entry whose action equals the given one,
     *         or 0 when no entry matches
     * @throws RuntimeException if the distribution is null or empty
     * @deprecated use {@link #getProbOfActionGivenDistribution(AbstractGroundedAction, List)},
     *             which performs the same lookup without the unused state argument
     */
    @Deprecated
    public static double getProbOfActionGivenDistribution(State state, AbstractGroundedAction abstractGroundedAction, List<ActionProb> list) {
        // Same check, same message and same search as the two-argument overload.
        return getProbOfActionGivenDistribution(abstractGroundedAction, list);
    }

    /**
     * Looks up the probability of an action in an explicit distribution.
     *
     * @param abstractGroundedAction the action to look up
     * @param list the distribution to search
     * @return the probability of the first entry whose action equals the given one,
     *         or 0 when no entry matches
     * @throws RuntimeException if the distribution is null or empty
     */
    public static double getProbOfActionGivenDistribution(AbstractGroundedAction abstractGroundedAction, List<ActionProb> list) {
        boolean unusable = list == null || list.size() == 0;
        if (unusable) {
            throw new RuntimeException("Distribution is null or empty, cannot return probability for given action.");
        }
        double probability = 0.0d;
        for (ActionProb entry : list) {
            if (entry.ga.equals(abstractGroundedAction)) {
                probability = entry.pSelection;
                break;
            }
        }
        return probability;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Builds a one-entry distribution that gives probability 1 to the action returned by
     * {@link #getAction(State)}.
     *
     * @param state the state to build the distribution for
     * @return a new list holding a single entry for the selected action
     * @throws PolicyUndefinedException if {@link #getAction(State)} returns null
     */
    public List<ActionProb> getDeterministicPolicy(State state) {
        AbstractGroundedAction selected = getAction(state);
        if (selected == null) {
            throw new PolicyUndefinedException();
        }
        List<ActionProb> singleEntry = new ArrayList<ActionProb>();
        singleEntry.add(new ActionProb(selected, 1.0d));
        return singleEntry;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Samples an action from {@link #getActionDistributionForState(State)} by drawing one
     * number from {@code RandomFactory.getMapped(0)} and walking the cumulative probabilities.
     *
     * @param state the state to sample an action for
     * @return the first action whose cumulative probability exceeds the drawn number
     * @throws PolicyUndefinedException if the distribution for the state is null or empty
     * @throws RuntimeException if the drawn number is never exceeded by the cumulative sum
     */
    public AbstractGroundedAction sampleFromActionDistribution(State state) {
        // The draw happens before the distribution is requested; keep that order so the
        // sequence of values taken from the generator stays the same.
        double roll = RandomFactory.getMapped(0).nextDouble();
        List<ActionProb> distribution = getActionDistributionForState(state);
        boolean undefined = distribution == null || distribution.size() == 0;
        if (undefined) {
            throw new PolicyUndefinedException();
        }
        double cumulative = 0.0d;
        Iterator<ActionProb> entries = distribution.iterator();
        while (entries.hasNext()) {
            ActionProb entry = entries.next();
            cumulative += entry.pSelection;
            if (roll < cumulative) {
                return entry.ga;
            }
        }
        throw new RuntimeException("Tried to sample policy action distribution, but it did not sum to 1.");
    }

    /**
     * Sets whether the roll-out helper executes a selected non-primitive action step by
     * step rather than as a single transition.
     *
     * @param z the new value of the decomposition flag
     */
    public void evaluateMethodsShouldDecomposeOption(boolean z) {
        evaluateDecomposesOptions = z;
    }

    /**
     * Sets whether the steps of a decomposed option are recorded under a synthesized,
     * annotated action name instead of the step action itself.
     *
     * @param z the new value of the annotation flag
     */
    public void evaluateMethodsShouldAnnotateOptionDecomposition(boolean z) {
        annotateOptionDecomposition = z;
    }

    /**
     * Follows this policy from the given state until the terminal function reports a
     * terminal state, recording the episode. There is no step limit, so this does not
     * return if a terminal state is never reached.
     *
     * @param state the state the episode starts in
     * @param rewardFunction supplies the reward recorded for each transition
     * @param terminalFunction decides when the episode ends
     * @return the recorded episode, starting with the initial state
     */
    public EpisodeAnalysis evaluateBehavior(State state, RewardFunction rewardFunction, TerminalFunction terminalFunction) {
        EpisodeAnalysis episode = new EpisodeAnalysis();
        episode.addState(state);
        State current = state;
        while (!terminalFunction.isTerminal(current)) {
            current = followAndRecordPolicy(episode, current, rewardFunction);
        }
        return episode;
    }

    /**
     * Follows this policy from the given state, recording the episode, until the terminal
     * function reports a terminal state or the episode's reported time-step count reaches
     * the limit.
     *
     * @param state the state the episode starts in
     * @param rewardFunction supplies the reward recorded for each transition
     * @param terminalFunction decides when the episode ends
     * @param i the limit compared against {@code EpisodeAnalysis.numTimeSteps()}
     * @return the recorded episode, starting with the initial state
     */
    public EpisodeAnalysis evaluateBehavior(State state, RewardFunction rewardFunction, TerminalFunction terminalFunction, int i) {
        EpisodeAnalysis episode = new EpisodeAnalysis();
        episode.addState(state);
        State current = state;
        int stepsRecorded = 0;
        while (!terminalFunction.isTerminal(current) && stepsRecorded < i) {
            current = followAndRecordPolicy(episode, current, rewardFunction);
            // Re-read the count from the episode rather than incrementing: one call may
            // record several transitions when an option is decomposed.
            stepsRecorded = episode.numTimeSteps();
        }
        return episode;
    }

    /**
     * Follows this policy from the given state, recording the episode, until the episode's
     * reported time-step count reaches the limit. No terminal function is consulted.
     *
     * @param state the state the episode starts in
     * @param rewardFunction supplies the reward recorded for each transition
     * @param i the limit compared against {@code EpisodeAnalysis.numTimeSteps()}
     * @return the recorded episode, starting with the initial state
     */
    public EpisodeAnalysis evaluateBehavior(State state, RewardFunction rewardFunction, int i) {
        EpisodeAnalysis episode = new EpisodeAnalysis();
        episode.addState(state);
        State current = state;
        int stepsRecorded = 0;
        while (stepsRecorded < i) {
            current = followAndRecordPolicy(episode, current, rewardFunction);
            // Re-read the count from the episode rather than incrementing: one call may
            // record several transitions when an option is decomposed.
            stepsRecorded = episode.numTimeSteps();
        }
        return episode;
    }

    /**
     * Selects this policy's action for the given state, executes it, and records the
     * resulting transition(s) in the episode.
     *
     * <p>A primitive action, or any action when option decomposition is disabled, is
     * executed and recorded as a single transition. Otherwise the action is treated as an
     * {@link Option} and rolled out step by step.
     *
     * @param episodeAnalysis the episode to record transitions into
     * @param state the state to act from
     * @param rewardFunction supplies the reward recorded for each transition
     * @return the state reached after the action (or the whole option) has been executed
     * @throws PolicyUndefinedException if {@link #getAction(State)} returns null
     * @throws RuntimeException if the selected action is not a {@link GroundedAction}
     */
    private State followAndRecordPolicy(EpisodeAnalysis episodeAnalysis, State state, RewardFunction rewardFunction) {
        AbstractGroundedAction selected = getAction(state);
        if (selected == null) {
            throw new PolicyUndefinedException();
        }
        if (!(selected instanceof GroundedAction)) {
            throw new RuntimeException("cannot follow policy for non-single agent actions");
        }
        GroundedAction groundedAction = (GroundedAction) selected;
        if (!groundedAction.action.isPrimitive() && this.evaluateDecomposesOptions) {
            return followAndRecordOption(episodeAnalysis, state, rewardFunction, groundedAction);
        }
        State next = groundedAction.executeIn(state);
        episodeAnalysis.recordTransitionTo(groundedAction, next, rewardFunction.reward(state, groundedAction, next));
        return next;
    }

    /**
     * Rolls out a non-primitive action one step at a time, recording every step, until the
     * option reports that it should not continue.
     *
     * @param episodeAnalysis the episode to record transitions into
     * @param state the state the option is initiated in
     * @param rewardFunction supplies the reward recorded for each step
     * @param groundedAction the selected action; its {@code action} is cast to {@link Option}
     * @return the state reached when the option stops
     */
    private State followAndRecordOption(EpisodeAnalysis episodeAnalysis, State state, RewardFunction rewardFunction, GroundedAction groundedAction) {
        Option option = (Option) groundedAction.action;
        option.initiateInState(state, groundedAction.params);
        State current = state;
        int step = 0;
        // do-while: the option always takes at least one step before it is asked whether to continue.
        do {
            GroundedAction stepAction = option.oneStepActionSelection(current, groundedAction.params);
            State next = stepAction.executeIn(current);
            double reward = rewardFunction.reward(current, stepAction, next);
            if (this.annotateOptionDecomposition) {
                // Record under "<option>(<step>)-<action>" so the episode shows which option produced the step.
                NullAction label = new NullAction(option.getName() + "(" + step + ")-" + stepAction.action.getName());
                episodeAnalysis.recordTransitionTo(new GroundedAction(label, stepAction.params), next, reward);
            } else {
                episodeAnalysis.recordTransitionTo(stepAction, next, reward);
            }
            current = next;
            step++;
        } while (option.continueFromState(current, groundedAction.params));
        return current;
    }
}
