package burlap.behavior.singleagent.planning;

import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.auxiliary.common.NullTermination;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/planning/ValueFunctionPlanner.class */
public abstract class ValueFunctionPlanner extends OOMDPPlanner implements ValueFunction, QComputablePlanner {
    /**
     * Per-state list of action transitions, keyed by hashed state. Entries are added by
     * {@code getActionsTransitions} only while {@link #useCachedTransitions} is true.
     */
    protected Map<StateHashTuple, List<ActionTransitions>> transitionDynamics;
    /** Tabular value function: hashed state mapped to its most recently stored value. */
    protected Map<StateHashTuple, Double> valueFunction;
    /** Whether generated transitions are stored in {@link #transitionDynamics} for reuse; on by default. */
    protected boolean useCachedTransitions = true;
    /** Supplies the value reported for states that have no entry in {@link #valueFunction}. */
    protected ValueFunctionInitialization valueInitializer = new ValueFunctionInitialization.ConstantValueFunctionInitialization();

    /* loaded from: input_file:burlap/behavior/singleagent/planning/ValueFunctionPlanner$StaticVFPlanner.class */
    /**
     * A {@link ValueFunctionPlanner} that performs no planning of its own. It is initialized with
     * an existing value table and exists so that the table can be queried and manipulated through
     * the planner API.
     */
    public static class StaticVFPlanner extends ValueFunctionPlanner {
        /**
         * Builds a planner around a copy of the supplied value table. The planner is initialized
         * with a {@link NullTermination} terminal function.
         *
         * @param domain the domain passed to {@code VFPInit}
         * @param rewardFunction the reward function passed to {@code VFPInit}
         * @param d the discount factor passed to {@code VFPInit}
         * @param stateHashFactory the hashing factory passed to {@code VFPInit}
         * @param list actions to register via {@code addNonDomainReferencedAction}
         * @param map value table whose entries are copied into this planner
         */
        public StaticVFPlanner(Domain domain, RewardFunction rewardFunction, double d, StateHashFactory stateHashFactory, List<Action> list, Map<StateHashTuple, Double> map) {
            VFPInit(domain, rewardFunction, new NullTermination(), d, stateHashFactory);
            for (Action action : list) {
                addNonDomainReferencedAction(action);
            }
            // Copy the entries rather than aliasing the caller's map.
            this.valueFunction.putAll(map);
        }

        /**
         * Always fails: this class has no planning algorithm.
         *
         * @param state ignored
         * @throws UnsupportedOperationException always
         */
        @Override // burlap.behavior.singleagent.planning.ValueFunctionPlanner, burlap.behavior.singleagent.planning.OOMDPPlanner
        public void planFromState(State state) {
            throw new UnsupportedOperationException("StaticVFPlanner has no planning method defined. It is used for manually querying and manipulating a value function.");
        }
    }

    /**
     * Planning entry point; the algorithm is supplied by concrete subclasses.
     *
     * @param state the state handed to the planner
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public abstract void planFromState(State state);

    /**
     * Initializes this planner: forwards the common settings to {@code plannerInit} and then
     * allocates empty transition-cache and value-function tables.
     *
     * @param domain the domain forwarded to {@code plannerInit}
     * @param rewardFunction the reward function forwarded to {@code plannerInit}
     * @param terminalFunction the terminal function forwarded to {@code plannerInit}
     * @param d the discount factor forwarded to {@code plannerInit}
     * @param stateHashFactory the hashing factory forwarded to {@code plannerInit}
     */
    public void VFPInit(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double d, StateHashFactory stateHashFactory) {
        plannerInit(domain, rewardFunction, terminalFunction, d, stateHashFactory);
        // Parameterized instead of raw maps so the assignments are checked by the compiler.
        this.transitionDynamics = new HashMap<StateHashTuple, List<ActionTransitions>>();
        this.valueFunction = new HashMap<StateHashTuple, Double>();
    }

    /**
     * Discards everything accumulated so far: the cached transitions, the stored state values
     * and the state index map are all emptied.
     */
    @Override // burlap.behavior.singleagent.planning.OOMDPPlanner
    public void resetPlannerResults() {
        transitionDynamics.clear();
        valueFunction.clear();
        mapToStateIndex.clear();
    }

    /**
     * Replaces the object that supplies values for states without a stored value.
     *
     * @param valueFunctionInitialization the initializer to use from now on
     */
    public void setValueFunctionInitialization(ValueFunctionInitialization valueFunctionInitialization) {
        valueInitializer = valueFunctionInitialization;
    }

    /**
     * @return the initializer currently used for states without a stored value
     */
    public ValueFunctionInitialization getValueFunctionInitialization() {
        return valueInitializer;
    }

    /**
     * Tells whether the value table holds an entry for the given state.
     *
     * @param state the state to look up
     * @return true if the hashed state is a key of the value table
     */
    public boolean hasComputedValueFor(State state) {
        StateHashTuple key = this.hashingFactory.hashState(state);
        return this.valueFunction.containsKey(key);
    }

    /**
     * Returns the value of a state by hashing it and delegating to
     * {@link #value(StateHashTuple)}.
     *
     * @param state the state to evaluate
     * @return the value reported for the hashed state
     */
    @Override // burlap.behavior.singleagent.planning.ValueFunction
    public double value(State state) {
        StateHashTuple hashed = this.hashingFactory.hashState(state);
        return value(hashed);
    }

    /**
     * Returns the value of a hashed state.
     *
     * @param stateHashTuple the hashed state to evaluate
     * @return 0 if the terminal function marks the state terminal; otherwise the stored value,
     *         or the default from {@link #getDefaultValue(State)} when nothing is stored
     */
    public double value(StateHashTuple stateHashTuple) {
        if (this.tf.isTerminal(stateHashTuple.s)) {
            return 0.0d;
        }
        Double stored = this.valueFunction.get(stateHashTuple);
        if (stored != null) {
            return stored.doubleValue();
        }
        return getDefaultValue(stateHashTuple.s);
    }

    /**
     * Turns the transition cache on or off.
     *
     * @param z true to store and reuse generated transitions, false to regenerate them on demand
     */
    public void toggleUseCachedTransitionDynamics(boolean z) {
        useCachedTransitions = z;
    }

    /**
     * Returns one Q-value for every grounded action applicable in the given state.
     *
     * <p>With the transition cache enabled each Q-value is resolved through the cached
     * transitions. With the cache disabled each Q-value is computed directly from the action's
     * own transitions, mirroring {@link #getQ(State, AbstractGroundedAction)}; this avoids
     * rebuilding the full transition list once per action. In both modes a terminal state
     * yields Q-values of 0.
     *
     * @param state the state whose Q-values are requested
     * @return the Q-values, in action-list order and then grounding order
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public List<QValue> getQs(State state) {
        StateHashTuple queryHash = stateHash(state);

        // Register the state the first time it is seen so later lookups share one representative.
        StateHashTuple indexHash = this.mapToStateIndex.get(queryHash);
        if (indexHash == null) {
            indexHash = queryHash;
            this.mapToStateIndex.put(indexHash, indexHash);
        }

        // The object-name matching is only needed to look actions up in the cache.
        Map<String, String> matching = null;
        if (this.useCachedTransitions && this.containsParameterizedActions && !this.domain.isObjectIdentifierDependent()) {
            matching = queryHash.s.getObjectMatchingTo(indexHash.s, false);
        }

        boolean terminalWithoutCache = !this.useCachedTransitions && this.tf.isTerminal(queryHash.s);

        List<QValue> qValues = new ArrayList<QValue>();
        for (Action action : this.actions) {
            for (GroundedAction ga : action.getAllApplicableGroundedActions(state)) {
                if (this.useCachedTransitions) {
                    qValues.add(getQ(queryHash, ga, matching));
                } else {
                    double q = terminalWithoutCache ? 0.0d : computeQ(queryHash, ga);
                    qValues.add(new QValue(queryHash.s, ga, q));
                }
            }
        }
        return qValues;
    }

    /**
     * Returns the Q-value of one action in one state.
     *
     * <p>With the transition cache disabled the value is computed directly from the action's
     * transitions; otherwise it is resolved through the cached transitions. Either way a state
     * the terminal function marks terminal has Q-value 0, so the result does not depend on
     * the cache setting.
     *
     * @param state the state in which the action is evaluated
     * @param abstractGroundedAction the action to evaluate; must be a {@link GroundedAction}
     * @return the Q-value for the state/action pair
     * @throws ClassCastException if the action is not a {@link GroundedAction}
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        StateHashTuple queryHash = stateHash(state);

        if (!this.useCachedTransitions) {
            // Same terminal convention as the cached path and value(): terminal states are worth 0.
            double q = this.tf.isTerminal(queryHash.s)
                    ? 0.0d
                    : computeQ(queryHash, (GroundedAction) abstractGroundedAction);
            return new QValue(state, abstractGroundedAction, q);
        }

        // Register the state the first time it is seen so later lookups share one representative.
        StateHashTuple indexHash = this.mapToStateIndex.get(queryHash);
        if (indexHash == null) {
            indexHash = queryHash;
            this.mapToStateIndex.put(indexHash, indexHash);
        }

        Map<String, String> matching = null;
        if (this.containsParameterizedActions && !this.domain.isObjectIdentifierDependent() && abstractGroundedAction.parametersAreObjects()) {
            matching = queryHash.s.getObjectMatchingTo(indexHash.s, false);
        }
        return getQ(queryHash, (GroundedAction) abstractGroundedAction, matching);
    }

    /**
     * Lists every state that currently has an entry in the value table.
     *
     * @return a new list holding the state of each key in the value table
     */
    public List<State> getAllStates() {
        // Typed list instead of a raw ArrayList so element types are checked at compile time.
        List<State> states = new ArrayList<State>(this.valueFunction.size());
        for (StateHashTuple hashed : this.valueFunction.keySet()) {
            states.add(hashed.s);
        }
        return states;
    }

    /**
     * Creates a {@link StaticVFPlanner} built from this planner's domain, reward function,
     * discount, hashing factory, actions and value table.
     *
     * <p>The copy also receives this planner's value initializer, so that states missing from
     * the table get the same default value from the copy as they do from this planner.
     *
     * @return a new static planner holding this planner's value function
     */
    public StaticVFPlanner getCopyOfValueFunction() {
        StaticVFPlanner copy = new StaticVFPlanner(this.domain, this.rf, this.gamma, this.hashingFactory, this.actions, this.valueFunction);
        // Without this the copy would silently fall back to its own default initializer.
        copy.setValueFunctionInitialization(this.valueInitializer);
        return copy;
    }

    /**
     * Computes the Q-value of an action in a hashed state using the state's action transitions.
     *
     * <p>If a matching is supplied and the action's parameters are objects, the action is first
     * translated with {@code translateAction} before it is compared against the transitions.
     * When no transition entry matches, the Q-value is computed directly from the action's own
     * transitions instead of failing with a {@code NullPointerException}.
     *
     * @param stateHashTuple the hashed state in which the action is evaluated
     * @param groundedAction the action as the caller expressed it; also the action in the result
     * @param map object-name matching used to translate the action, or null for no translation
     * @return the Q-value; 0 when the terminal function marks the state terminal
     */
    protected QValue getQ(StateHashTuple stateHashTuple, GroundedAction groundedAction, Map<String, String> map) {
        GroundedAction lookupAction = groundedAction;
        if (map != null && groundedAction.parametersAreObjects()) {
            lookupAction = translateAction(lookupAction, map);
        }

        ActionTransitions matched = null;
        for (ActionTransitions candidate : getActionsTransitions(stateHashTuple)) {
            if (candidate.matchingTransitions(lookupAction)) {
                matched = candidate;
                break;
            }
        }

        double q = 0.0d;
        if (!this.tf.isTerminal(stateHashTuple.s)) {
            if (matched != null) {
                q = computeQ(stateHashTuple.s, matched);
            } else {
                // No transition entry for this action: evaluate the untranslated action against
                // this state directly, as the non-cached code path does.
                q = computeQ(stateHashTuple, groundedAction);
            }
        }
        return new QValue(stateHashTuple.s, groundedAction, q);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Returns the action transitions of a hashed state, one entry per applicable grounded action.
     *
     * <p>A stored list is returned as is. Otherwise the state is recorded in the state index
     * map, a fresh list is built, and the list is stored for reuse only when the transition
     * cache is enabled.
     *
     * @param stateHashTuple the hashed state whose transitions are requested
     * @return the stored or newly built list of action transitions
     */
    public List<ActionTransitions> getActionsTransitions(StateHashTuple stateHashTuple) {
        List<ActionTransitions> stored = this.transitionDynamics.get(stateHashTuple);
        if (stored != null) {
            return stored;
        }

        this.mapToStateIndex.put(stateHashTuple, stateHashTuple);
        List<GroundedAction> applicable = Action.getAllApplicableGroundedActionsFromActionList(this.actions, stateHashTuple.s);
        List<ActionTransitions> built = new ArrayList<ActionTransitions>(applicable.size());
        for (GroundedAction ga : applicable) {
            built.add(new ActionTransitions(stateHashTuple.s, ga, this.hashingFactory));
        }

        if (this.useCachedTransitions) {
            this.transitionDynamics.put(stateHashTuple, built);
        }
        return built;
    }

    /**
     * Hashes the state and runs a Bellman update on it.
     *
     * @param state the state to update
     * @return the value stored for the state by the update
     */
    public double performBellmanUpdateOn(State state) {
        StateHashTuple hashed = stateHash(state);
        return performBellmanUpdateOn(hashed);
    }

    /**
     * Hashes the state and runs a fixed-policy Bellman update on it.
     *
     * @param state the state to update
     * @param policy the policy whose action distribution weights the Q-values
     * @return the value stored for the state by the update
     */
    public double performFixedPolicyBellmanUpdateOn(State state, Policy policy) {
        StateHashTuple hashed = stateHash(state);
        return performFixedPolicyBellmanUpdateOn(hashed, policy);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Stores the maximum Q-value over the applicable actions as the new value of a hashed state.
     *
     * <p>Terminal states are stored as 0. Q-values come from the cached transitions when the
     * cache is enabled and from the actions' own transitions otherwise. If no action is
     * applicable, the stored value is {@code Double.NEGATIVE_INFINITY}.
     *
     * @param stateHashTuple the hashed state to update
     * @return the value written to the value table
     */
    public double performBellmanUpdateOn(StateHashTuple stateHashTuple) {
        double best;
        if (this.tf.isTerminal(stateHashTuple.s)) {
            best = 0.0d;
        } else {
            best = Double.NEGATIVE_INFINITY;
            if (this.useCachedTransitions) {
                for (ActionTransitions transitions : getActionsTransitions(stateHashTuple)) {
                    double q = computeQ(stateHashTuple.s, transitions);
                    if (q > best) {
                        best = q;
                    }
                }
            } else {
                List<GroundedAction> applicable = Action.getAllApplicableGroundedActionsFromActionList(this.actions, stateHashTuple.s);
                for (GroundedAction ga : applicable) {
                    double q = computeQ(stateHashTuple, ga);
                    if (q > best) {
                        best = q;
                    }
                }
            }
        }
        this.valueFunction.put(stateHashTuple, Double.valueOf(best));
        return best;
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Stores the policy-weighted sum of Q-values as the new value of a hashed state.
     *
     * <p>Terminal states are stored as 0. Each applicable action contributes its Q-value
     * multiplied by the probability the policy assigns to it; actions with probability 0 are
     * skipped without computing their Q-value.
     *
     * @param stateHashTuple the hashed state to update
     * @param policy the policy providing the action distribution for the state
     * @return the value written to the value table
     */
    public double performFixedPolicyBellmanUpdateOn(StateHashTuple stateHashTuple, Policy policy) {
        if (this.tf.isTerminal(stateHashTuple.s)) {
            this.valueFunction.put(stateHashTuple, Double.valueOf(0.0d));
            return 0.0d;
        }

        List<Policy.ActionProb> distribution = policy.getActionDistributionForState(stateHashTuple.s);
        double expected = 0.0d;
        if (!this.useCachedTransitions) {
            for (GroundedAction ga : Action.getAllApplicableGroundedActionsFromActionList(this.actions, stateHashTuple.s)) {
                double prob = Policy.getProbOfActionGivenDistribution(ga, distribution);
                if (prob == 0.0d) {
                    continue;
                }
                expected += prob * computeQ(stateHashTuple, ga);
            }
        } else {
            for (ActionTransitions transitions : getActionsTransitions(stateHashTuple)) {
                double prob = Policy.getProbOfActionGivenDistribution(transitions.ga, distribution);
                if (prob == 0.0d) {
                    continue;
                }
                expected += prob * computeQ(stateHashTuple.s, transitions);
            }
        }

        this.valueFunction.put(stateHashTuple, Double.valueOf(expected));
        return expected;
    }

    /**
     * Computes a Q-value from an already generated set of action transitions.
     *
     * <p>For an {@link Option} the result is the option's expected rewards plus the
     * probability-weighted values of the outcome states; the discount factor is not applied in
     * this branch. For any other action each outcome contributes its probability times the
     * reward plus the discounted value of the outcome state.
     *
     * @param state the state in which the action is taken
     * @param actionTransitions the action together with its hashed outcome transitions
     * @return the resulting Q-value
     */
    protected double computeQ(State state, ActionTransitions actionTransitions) {
        double q = 0.0d;
        if (!(actionTransitions.ga.action instanceof Option)) {
            for (HashedTransitionProbability outcome : actionTransitions.transitions) {
                double nextValue = value(outcome.sh);
                double reward = this.rf.reward(state, actionTransitions.ga, outcome.sh.s);
                q += outcome.p * (reward + (this.gamma * nextValue));
            }
            return q;
        }

        Option option = (Option) actionTransitions.ga.action;
        q += option.getExpectedRewards(state, actionTransitions.ga.params);
        for (HashedTransitionProbability outcome : actionTransitions.transitions) {
            q += outcome.p * value(outcome.sh);
        }
        return q;
    }

    /**
     * Computes a Q-value by asking the action itself for its transitions from the given state.
     *
     * <p>For an {@link Option} the result is the option's expected rewards plus the
     * probability-weighted values of the outcome states; the discount factor is not applied in
     * this branch. For any other action each outcome contributes its probability times the
     * reward plus the discounted value of the outcome state.
     *
     * @param stateHashTuple the hashed state in which the action is taken
     * @param groundedAction the action to evaluate
     * @return the resulting Q-value
     */
    protected double computeQ(StateHashTuple stateHashTuple, GroundedAction groundedAction) {
        double q = 0.0d;
        if (!(groundedAction.action instanceof Option)) {
            for (TransitionProbability outcome : groundedAction.action.getTransitions(stateHashTuple.s, groundedAction.params)) {
                double nextValue = value(outcome.s);
                double reward = this.rf.reward(stateHashTuple.s, groundedAction, outcome.s);
                q += outcome.p * (reward + (this.gamma * nextValue));
            }
            return q;
        }

        Option option = (Option) groundedAction.action;
        q += option.getExpectedRewards(stateHashTuple.s, groundedAction.params);
        for (TransitionProbability outcome : option.getTransitions(stateHashTuple.s, groundedAction.params)) {
            q += outcome.p * value(outcome.s);
        }
        return q;
    }

    /**
     * Asks the value initializer for the value of a state that has nothing stored.
     *
     * @param state the state to evaluate
     * @return the initializer's value for the state
     */
    protected double getDefaultValue(State state) {
        ValueFunctionInitialization initializer = this.valueInitializer;
        return initializer.value(state);
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /**
     * Hands this planner's hashing factory to every {@link Option} in the action list via
     * {@code setExpectationHashingFactory}; non-option actions are left alone.
     */
    public void initializeOptionsForExpectationComputations() {
        for (Action candidate : this.actions) {
            if (!(candidate instanceof Option)) {
                continue;
            }
            Option option = (Option) candidate;
            option.setExpectationHashingFactory(this.hashingFactory);
        }
    }
}
