package burlap.behavior.stochasticgame.agents.naiveq;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.Policy;
import burlap.behavior.singleagent.QValue;
import burlap.behavior.singleagent.ValueFunctionInitialization;
import burlap.behavior.singleagent.planning.QComputablePlanner;
import burlap.behavior.singleagent.planning.commonpolicies.EpsilonGreedy;
import burlap.behavior.statehashing.StateHashFactory;
import burlap.behavior.statehashing.StateHashTuple;
import burlap.oomdp.auxiliary.StateAbstraction;
import burlap.oomdp.auxiliary.common.NullAbstractionNoCopy;
import burlap.oomdp.core.AbstractGroundedAction;
import burlap.oomdp.core.State;
import burlap.oomdp.stochasticgames.Agent;
import burlap.oomdp.stochasticgames.GroundedSingleAction;
import burlap.oomdp.stochasticgames.JointAction;
import burlap.oomdp.stochasticgames.SGDomain;
import burlap.oomdp.stochasticgames.SingleAction;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/stochasticgame/agents/naiveq/SGNaiveQLAgent.class */
/**
 * Stochastic-game agent that keeps a table of Q-values for its own actions,
 * indexed by hashed state, and updates the table from each observed outcome.
 *
 * <p>Only this agent's component of the joint action is used when updating
 * (see {@link #observeOutcome}); the actions of other agents are ignored.
 * States are passed through {@link #storedMapAbstraction} before hashing, so
 * two states with the same abstraction share one set of Q-value entries.
 */
public class SGNaiveQLAgent extends Agent implements QComputablePlanner {
    /** Q-value entries for each hashed (abstracted) state. */
    protected Map<StateHashTuple, List<QValue>> qMap;
    /** The state stored for each hash key the first time that key is seen. */
    protected Map<StateHashTuple, State> stateRepresentations;
    /** Abstraction applied to a state before it is hashed. */
    protected StateAbstraction storedMapAbstraction;
    /** Factor multiplying the next state's maximum Q-value in the update. */
    protected double discount;
    /** Source of the step size used in the Q-value update. */
    protected LearningRate learningRate;
    /** Supplies the initial value for newly created Q-value entries. */
    protected ValueFunctionInitialization qInit;
    /** Policy queried by {@link #getAction(State)}. */
    protected Policy policy;
    /** Factory used to hash abstracted states. */
    protected StateHashFactory hashFactory;
    /** Number of outcomes observed so far; passed to the learning rate on each update. */
    protected int totalNumberOfSteps = 0;

    /**
     * Creates an agent whose Q-values are initialized to 0.
     *
     * @param sGDomain         domain passed to {@code init}
     * @param d                discount factor
     * @param d2               value used to build the {@link ConstantLR} learning rate
     * @param stateHashFactory factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain sGDomain, double d, double d2, StateHashFactory stateHashFactory) {
        this(sGDomain, d, d2, new ValueFunctionInitialization.ConstantValueFunctionInitialization(0.0d), stateHashFactory);
    }

    /**
     * Creates an agent whose Q-values are initialized to a constant.
     *
     * @param sGDomain         domain passed to {@code init}
     * @param d                discount factor
     * @param d2               value used to build the {@link ConstantLR} learning rate
     * @param d3               constant initial Q-value
     * @param stateHashFactory factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain sGDomain, double d, double d2, double d3, StateHashFactory stateHashFactory) {
        this(sGDomain, d, d2, new ValueFunctionInitialization.ConstantValueFunctionInitialization(d3), stateHashFactory);
    }

    /**
     * Creates an agent with a caller-supplied Q-value initializer. The policy
     * defaults to an {@link EpsilonGreedy} over this agent constructed with
     * 0.1 (presumably the exploration probability), and the state abstraction
     * defaults to {@link NullAbstractionNoCopy}.
     *
     * @param sGDomain                    domain passed to {@code init}
     * @param d                           discount factor
     * @param d2                          value used to build the {@link ConstantLR} learning rate
     * @param valueFunctionInitialization initializer for new Q-value entries
     * @param stateHashFactory            factory used to hash states
     */
    public SGNaiveQLAgent(SGDomain sGDomain, double d, double d2, ValueFunctionInitialization valueFunctionInitialization, StateHashFactory stateHashFactory) {
        init(sGDomain);
        this.discount = d;
        this.learningRate = new ConstantLR(Double.valueOf(d2));
        this.hashFactory = stateHashFactory;
        this.qInit = valueFunctionInitialization;
        this.qMap = new HashMap<StateHashTuple, List<QValue>>();
        this.stateRepresentations = new HashMap<StateHashTuple, State>();
        this.policy = new EpsilonGreedy(this, 0.1d);
        this.storedMapAbstraction = new NullAbstractionNoCopy();
    }

    /**
     * Replaces the abstraction applied to states before hashing.
     *
     * @param stateAbstraction the abstraction to use from now on
     */
    public void setStoredMapAbstraction(StateAbstraction stateAbstraction) {
        this.storedMapAbstraction = stateAbstraction;
    }

    /**
     * Replaces the policy used by {@link #getAction(State)}.
     *
     * @param policy the policy to follow from now on
     */
    public void setStrategy(Policy policy) {
        this.policy = policy;
    }

    /**
     * Replaces the initializer used for Q-value entries created after this call.
     *
     * @param valueFunctionInitialization the new initializer
     */
    public void setQValueInitializer(ValueFunctionInitialization valueFunctionInitialization) {
        this.qInit = valueFunctionInitialization;
    }

    /**
     * Replaces the learning rate used in subsequent updates.
     *
     * @param learningRate the new learning rate
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /** No-op. */
    @Override // burlap.oomdp.stochasticgames.Agent
    public void gameStarting() {
    }

    /**
     * Returns the action chosen by the current policy for the given state.
     *
     * @param state the state to act in
     * @return the policy's action, cast to {@link GroundedSingleAction}
     */
    @Override // burlap.oomdp.stochasticgames.Agent
    public GroundedSingleAction getAction(State state) {
        return (GroundedSingleAction) this.policy.getAction(state);
    }

    /**
     * Updates the Q-value of this agent's own action in {@code state}:
     * {@code q += lr * (reward + discount * maxQ(state2) - q)}.
     *
     * @param state       state whose Q-value entry is updated
     * @param jointAction joint action from which this agent's action is extracted
     * @param map         rewards keyed by agent name; replaced by the internal
     *                    reward function's output when one is set
     * @param state2      state whose maximum Q-value is used as the bootstrap term
     * @param z           when {@code true}, the bootstrap term is 0 instead of
     *                    the maximum Q-value of {@code state2}
     */
    @Override // burlap.oomdp.stochasticgames.Agent
    public void observeOutcome(State state, JointAction jointAction, Map<String, Double> map, State state2, boolean z) {
        Map<String, Double> rewards = map;
        if (this.internalRewardFunction != null) {
            rewards = this.internalRewardFunction.reward(state, jointAction, state2);
        }
        GroundedSingleAction ownAction = jointAction.action(this.worldAgentName);
        double reward = rewards.get(this.worldAgentName).doubleValue();
        QValue entry = getQ(state, ownAction);
        double nextValue = z ? 0.0d : getMaxQValue(state2);
        double stepSize = this.learningRate.pollLearningRate(this.totalNumberOfSteps, state, ownAction);
        entry.q += stepSize * ((reward + (this.discount * nextValue)) - entry.q);
        this.totalNumberOfSteps++;
    }

    /** No-op. */
    @Override // burlap.oomdp.stochasticgames.Agent
    public void gameTerminated() {
    }

    /**
     * Returns the largest Q-value among the entries {@link #getQs(State)}
     * yields for the given state.
     *
     * @param state the state to evaluate
     * @return the maximum Q-value, or {@link Double#NEGATIVE_INFINITY} when
     *         {@code getQs} returns an empty list
     */
    public double getMaxQValue(State state) {
        double best = Double.NEGATIVE_INFINITY;
        for (QValue entry : getQs(state)) {
            best = Math.max(best, entry.q);
        }
        return best;
    }

    /**
     * Applies the stored abstraction to a state and hashes the result.
     *
     * @param state the state to hash
     * @return the hash tuple of the abstracted state
     */
    protected StateHashTuple stateHash(State state) {
        return this.hashFactory.hashState(this.storedMapAbstraction.abstraction(state));
    }

    /**
     * Builds a copy of an action, owned by this agent, in which every
     * parameter is replaced by its image under {@code map}.
     *
     * @param groundedSingleAction the action whose parameters are renamed
     * @param map                  lookup from original parameter to replacement;
     *                             a parameter missing from the map becomes {@code null}
     * @return a new action with the same single action and the renamed parameters
     */
    protected GroundedSingleAction translateAction(GroundedSingleAction groundedSingleAction, Map<String, String> map) {
        String[] translated = new String[groundedSingleAction.params.length];
        for (int i = 0; i < translated.length; i++) {
            translated[i] = map.get(groundedSingleAction.params[i]);
        }
        return new GroundedSingleAction(this.worldAgentName, groundedSingleAction.action, translated);
    }

    /**
     * Returns one Q-value entry for every action this agent can take in the
     * given state, creating and storing entries that do not exist yet.
     *
     * <p>On the first visit to a hashed state the returned list is the stored
     * list itself; on later visits it is a fresh list holding the stored
     * entries.
     *
     * @param state the state to query
     * @return the Q-value entries for the agent's possible actions
     * @throws RuntimeException if the hashed state was seen before and no
     *                          action is possible in {@code state}
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public List<QValue> getQs(State state) {
        List<GroundedSingleAction> actions = SingleAction.getAllPossibleGroundedSingleActions(state, this.worldAgentName, this.agentType.actions);
        StateHashTuple key = stateHash(state);
        State storedState = this.stateRepresentations.get(key);
        if (storedState == null) {
            // First visit: this state becomes the stored representation, so
            // its actions need no translation.
            this.stateRepresentations.put(key, key.s);
            List<QValue> created = new ArrayList<QValue>();
            for (GroundedSingleAction action : actions) {
                created.add(new QValue(key.s, action, this.qInit.qValue(key.s, action)));
            }
            this.qMap.put(key, created);
            return created;
        }

        List<QValue> storedEntries = this.qMap.get(key);
        List<QValue> result = new ArrayList<QValue>(actions.size());
        Map<String, String> objectMatching = null; // computed lazily, at most once
        for (GroundedSingleAction action : actions) {
            GroundedSingleAction storedAction = action;
            if (needsTranslation(action)) {
                if (objectMatching == null) {
                    // presumably maps object names of the query state onto those
                    // of the stored representation -- confirm against State API
                    objectMatching = key.s.getObjectMatchingTo(storedState, false);
                }
                storedAction = translateAction(action, objectMatching);
            }
            QValue entry = findStoredQ(storedEntries, storedAction);
            if (entry == null) {
                entry = new QValue(key.s, storedAction, this.qInit.qValue(key.s, storedAction));
                storedEntries.add(entry);
            }
            result.add(entry);
        }
        if (result.isEmpty()) {
            throw new RuntimeException("getQs found no possible actions for agent '" + this.worldAgentName + "' in a previously seen state");
        }
        return result;
    }

    /**
     * Returns the stored Q-value entry for a state-action pair, creating and
     * storing it with the configured initial value if it does not exist.
     *
     * @param state                 the state to query
     * @param abstractGroundedAction the action to query; must be a
     *                               {@link GroundedSingleAction}
     * @return the stored entry, whose {@code q} field may be updated in place
     */
    @Override // burlap.behavior.singleagent.planning.QComputablePlanner
    public QValue getQ(State state, AbstractGroundedAction abstractGroundedAction) {
        GroundedSingleAction action = (GroundedSingleAction) abstractGroundedAction;
        StateHashTuple key = stateHash(state);
        State storedState = this.stateRepresentations.get(key);
        if (storedState == null) {
            // First visit: store this state as the representation. The entry is
            // built from key.s (not the null lookup result) so that it carries
            // a state like every other entry this class creates.
            this.stateRepresentations.put(key, key.s);
            QValue created = new QValue(key.s, action, this.qInit.qValue(key.s, action));
            List<QValue> entries = new ArrayList<QValue>();
            entries.add(created);
            this.qMap.put(key, entries);
            return created;
        }

        if (needsTranslation(action)) {
            action = translateAction(action, key.s.getObjectMatchingTo(storedState, false));
        }
        List<QValue> storedEntries = this.qMap.get(key);
        QValue entry = findStoredQ(storedEntries, action);
        if (entry == null) {
            entry = new QValue(key.s, action, this.qInit.qValue(key.s, action));
            storedEntries.add(entry);
        }
        return entry;
    }

    /** Tells whether an action's parameters must be renamed before lookup in the stored entries. */
    private boolean needsTranslation(GroundedSingleAction action) {
        return action.isParameterized() && !this.domain.isObjectIdentifierDependent() && action.parametersAreObjects();
    }

    /** Returns the first entry whose action equals {@code action}, or {@code null} if there is none. */
    private static QValue findStoredQ(List<QValue> entries, GroundedSingleAction action) {
        for (QValue candidate : entries) {
            if (candidate.a.equals(action)) {
                return candidate;
            }
        }
        return null;
    }
}
