package burlap.behavior.singleagent.learning.modellearning.modelplanners;

import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.learning.modellearning.ModelPlanner;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.GreedyQPolicy;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.debugtools.DPrint;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learning/modellearning/modelplanners/VIModelPlanner.class */
/**
 * A {@link ModelPlanner} that plans with {@link ValueIteration}.
 * <p>
 * The planner remembers every state it has been told about (see {@link #observedStates}) and, whenever
 * the model changes or the policy is queried for a state the current planner has no value for, builds a
 * fresh {@link ValueIteration} instance, seeds it from all remembered states and runs it again.
 * Q-value queries are forwarded to whichever {@link ValueIteration} instance is current.
 */
public class VIModelPlanner implements ModelPlanner, QComputablePlanner {
    /** The current value iteration planner; replaced by a new instance on every {@link #rerunVI()}. */
    protected ValueIteration vi;
    /** The policy returned by {@link #modelPlannedPolicy()}. */
    protected Policy modelPolicy;
    /** The state most recently passed to {@link #initializePlannerIn(State)}. */
    protected State initialState;
    /** Domain handed to every {@link ValueIteration} instance this planner creates. */
    protected Domain domain;
    /** Reward function handed to every {@link ValueIteration} instance this planner creates. */
    protected RewardFunction rf;
    /** Terminal function handed to every {@link ValueIteration} instance this planner creates. */
    protected TerminalFunction tf;
    /** Gamma (discount) parameter forwarded to the {@link ValueIteration} constructor. */
    protected double gamma;
    /** Factory used to hash observed states and forwarded to the {@link ValueIteration} constructor. */
    protected StateHashFactory hashingFactory;
    /** Max-delta parameter forwarded to the {@link ValueIteration} constructor. */
    protected double maxDelta;
    /** Max-iterations parameter forwarded to the {@link ValueIteration} constructor. */
    protected int maxIterations;
    /** Hashed states from which reachability is performed each time value iteration is rerun. */
    protected Set<StateHashTuple> observedStates = new HashSet<StateHashTuple>();

    /**
     * A policy wrapper that, before delegating to the wrapped policy, checks whether the enclosing
     * planner's current {@link ValueIteration} has a computed value for the queried state and, if not,
     * records the state and reruns value iteration.
     */
    public class ReplanIfUnseenPolicy extends Policy {
        /**
         * The policy that actually selects actions. {@link VIModelPlanner#rerunVI()} re-points this field
         * on the planner's own wrapper so that it always follows the most recent planner.
         */
        Policy p;

        /**
         * @param policy the policy to delegate to once the state is known to the planner
         */
        public ReplanIfUnseenPolicy(Policy policy) {
            this.p = policy;
        }

        @Override // burlap.behavior.singleagent.Policy
        public AbstractGroundedAction getAction(State state) {
            replanIfUnseen(state);
            // Read the delegate only after the possible replan, so a refreshed delegate is the one used.
            return this.p.getAction(state);
        }

        @Override // burlap.behavior.singleagent.Policy
        public List<Policy.ActionProb> getActionDistributionForState(State state) {
            replanIfUnseen(state);
            // Read the delegate only after the possible replan, so a refreshed delegate is the one used.
            return this.p.getActionDistributionForState(state);
        }

        @Override // burlap.behavior.singleagent.Policy
        public boolean isStochastic() {
            return this.p.isStochastic();
        }

        @Override // burlap.behavior.singleagent.Policy
        public boolean isDefinedFor(State state) {
            return this.p.isDefinedFor(state);
        }

        /**
         * Adds {@code state} to the enclosing planner's observed states and reruns value iteration when
         * the current planner reports no computed value for it; does nothing otherwise.
         */
        private void replanIfUnseen(State state) {
            if (VIModelPlanner.this.vi.hasComputedValueFor(state)) {
                return;
            }
            VIModelPlanner.this.observedStates.add(VIModelPlanner.this.hashingFactory.hashState(state));
            VIModelPlanner.this.rerunVI();
        }
    }

    /**
     * A {@link ModelPlanner.ModelPlannerGenerator} that produces {@link VIModelPlanner} instances sharing
     * the hashing factory, max delta and max iterations given at construction.
     */
    public static class VIModelPlannerGenerator implements ModelPlanner.ModelPlannerGenerator {
        StateHashFactory hashingFactory;
        double maxDelta;
        int maxIterations;

        /**
         * @param hashingFactory the hashing factory passed to each generated planner
         * @param maxDelta the max-delta value passed to each generated planner
         * @param maxIterations the max-iterations value passed to each generated planner
         */
        public VIModelPlannerGenerator(StateHashFactory hashingFactory, double maxDelta, int maxIterations) {
            this.hashingFactory = hashingFactory;
            this.maxDelta = maxDelta;
            this.maxIterations = maxIterations;
        }

        @Override // burlap.behavior.singleagent.learning.modellearning.ModelPlanner.ModelPlannerGenerator
        public ModelPlanner getModelPlanner(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double gamma) {
            return new VIModelPlanner(domain, rewardFunction, terminalFunction, gamma, this.hashingFactory, this.maxDelta, this.maxIterations);
        }
    }

    /**
     * Creates the planner together with an initial {@link ValueIteration} instance and a
     * {@link ReplanIfUnseenPolicy} wrapping a {@link GreedyQPolicy} over it. No planning is run here.
     *
     * @param domain the domain to plan in
     * @param rewardFunction the reward function to plan with
     * @param terminalFunction the terminal function to plan with
     * @param gamma the gamma (discount) parameter forwarded to {@link ValueIteration}
     * @param hashingFactory the state hashing factory
     * @param maxDelta the max-delta parameter forwarded to {@link ValueIteration}
     * @param maxIterations the max-iterations parameter forwarded to {@link ValueIteration}
     */
    public VIModelPlanner(Domain domain, RewardFunction rewardFunction, TerminalFunction terminalFunction, double gamma, StateHashFactory hashingFactory, double maxDelta, int maxIterations) {
        this.domain = domain;
        this.rf = rewardFunction;
        this.tf = terminalFunction;
        this.gamma = gamma;
        this.hashingFactory = hashingFactory;
        this.maxDelta = maxDelta;
        this.maxIterations = maxIterations;
        this.vi = new ValueIteration(domain, rewardFunction, terminalFunction, gamma, hashingFactory, maxDelta, maxIterations);
        // Switch off the DPrint code used by the value iteration planner.
        DPrint.toggleCode(this.vi.getDebugCode(), false);
        this.modelPolicy = new ReplanIfUnseenPolicy(new GreedyQPolicy(this.vi));
    }

    /**
     * Remembers {@code state} as the initial state and adds it to the observed states.
     * Planning itself is deferred until the model changes or the policy meets an unseen state.
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.ModelPlanner
    public void initializePlannerIn(State state) {
        this.initialState = state;
        this.observedStates.add(this.hashingFactory.hashState(state));
    }

    /**
     * Adds {@code state} to the observed states and reruns value iteration.
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.ModelPlanner
    public void modelChanged(State state) {
        this.observedStates.add(this.hashingFactory.hashState(state));
        rerunVI();
    }

    /**
     * Returns the planner's policy. The returned object keeps following the latest plan across calls to
     * {@link #rerunVI()}.
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.ModelPlanner
    public Policy modelPlannedPolicy() {
        return this.modelPolicy;
    }

    /**
     * Resets the results of the current {@link ValueIteration} instance.
     * The set of observed states is left untouched.
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.ModelPlanner
    public void resetPlanner() {
        this.vi.resetPlannerResults();
    }

    /** Forwards to the current {@link ValueIteration} instance. */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public List<QValue> getQs(State state) {
        return this.vi.getQs(state);
    }

    /** Forwards to the current {@link ValueIteration} instance. */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        return this.vi.getQ(state, abstractGroundedAction);
    }

    /**
     * @return the current {@link ValueIteration} instance; a different instance is returned after each
     *         call to {@link #rerunVI()}
     */
    public ValueIteration getValueIterationPlanner() {
        return this.vi;
    }

    /**
     * Replaces the current planner with a fresh {@link ValueIteration} built from the stored parameters,
     * performs reachability from every observed state and runs value iteration.
     * <p>
     * The existing {@link ReplanIfUnseenPolicy} is kept and only its delegate is swapped for a greedy
     * policy over the new planner. Creating a new wrapper instead would leave every policy reference
     * previously obtained from {@link #modelPlannedPolicy()} bound to the discarded planner.
     */
    protected void rerunVI() {
        this.vi = new ValueIteration(this.domain, this.rf, this.tf, this.gamma, this.hashingFactory, this.maxDelta, this.maxIterations);
        Policy greedy = new GreedyQPolicy(this.vi);
        if (this.modelPolicy instanceof ReplanIfUnseenPolicy) {
            ((ReplanIfUnseenPolicy) this.modelPolicy).p = greedy;
        } else {
            // modelPolicy was replaced by something else (e.g. a subclass); fall back to a new wrapper.
            this.modelPolicy = new ReplanIfUnseenPolicy(greedy);
        }
        for (StateHashTuple observed : this.observedStates) {
            this.vi.performReachabilityFrom(observed.s);
        }
        this.vi.runVI();
    }
}
