package burlap.oomdp.singleagent.interfaces.rlglue;

import burlap.oomdp.auxiliary.StateGenerator;
import burlap.oomdp.core.Attribute;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.ObjectInstance;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.rlcommunity.rlglue.codec.EnvironmentInterface;
import org.rlcommunity.rlglue.codec.taskspec.TaskSpecVRLGLUE3;
import org.rlcommunity.rlglue.codec.taskspec.ranges.DoubleRange;
import org.rlcommunity.rlglue.codec.taskspec.ranges.IntRange;
import org.rlcommunity.rlglue.codec.types.Observation;
import org.rlcommunity.rlglue.codec.types.Reward_observation_terminal;
import org.rlcommunity.rlglue.codec.util.EnvironmentLoader;

/**
 * Exposes a BURLAP single-agent OO-MDP domain as an RL-Glue environment.
 * <p>
 * Every grounded action applicable in the state produced by the {@link StateGenerator} at
 * construction time is assigned a consecutive integer index; that index is the single discrete
 * RL-Glue action. Observations are the flattened attribute values of every object, grouped by
 * object class: DISC, BOOLEAN, INT and RELATIONAL attributes go into the observation's integer
 * array, REAL and REALUNBOUND attributes into its double array.
 */
public class RLGlueEnvironment implements EnvironmentInterface {
    /** Domain whose actions and object classes are exposed through RL-Glue. */
    protected Domain domain;
    /** Produces the start state of each episode. */
    protected StateGenerator stateGenerator;
    /** Reward function evaluated on every transition. */
    protected RewardFunction rf;
    /** Terminal function evaluated on every resulting state. */
    protected TerminalFunction tf;
    /** Reward range advertised in the task spec. */
    protected DoubleRange rewardRange;
    /** Whether the task spec declares an episodic (true) or continuing (false) task. */
    protected boolean isEpisodic;
    /** Discount factor advertised in the task spec. */
    protected double discount;
    /** Current environment state. */
    protected State curState;
    /** Maps an RL-Glue action index to the position-based grounded action it stands for. */
    protected Map<Integer, ActionIndexParameterization> actionMap = new HashMap<Integer, ActionIndexParameterization>();
    /** Number of integer observation entries declared by the last {@link #env_init()} call. */
    protected int numDiscreteAtts = 0;
    /** Number of double observation entries declared by the last {@link #env_init()} call. */
    protected int numContinuousAtts = 0;
    /** Whether the state generated in the constructor has already been used as a start state. */
    protected boolean usedConstructorState = false;
    /**
     * Object count per true class name in the constructor's state. Insertion-ordered so the
     * task-spec layout and every observation iterate classes in the same, deterministic order.
     */
    protected Map<String, Integer> numObjectsOfEachClass = new LinkedHashMap<String, Integer>();
    /** Total number of objects across all classes in the constructor's state. */
    protected int numObjects = 0;

    /**
     * A grounded action recorded as its {@link Action} plus the positions, within
     * {@link State#getAllObjects()}, of its object parameters. This lets it be re-grounded in any
     * later state by position rather than by object name.
     */
    protected class ActionIndexParameterization {
        public Action action;
        public int[] params;

        /**
         * @param groundedAction the grounded action to record
         * @param state the state whose object ordering defines the parameter positions
         */
        public ActionIndexParameterization(GroundedAction groundedAction, State state) {
            this.action = groundedAction.action;
            this.params = new int[groundedAction.params.length];
            for (int p = 0; p < groundedAction.params.length; p++) {
                this.params[p] = RLGlueEnvironment.this.objectIndex(state, groundedAction.params[p]);
            }
        }

        /**
         * Rebuilds the grounded action for {@code state} by naming the objects found at the
         * recorded positions of {@link State#getAllObjects()}.
         *
         * @param state the state to ground the action in
         * @return the grounded action for {@code state}
         */
        public GroundedAction generateGroundedActionForState(State state) {
            List<ObjectInstance> allObjects = state.getAllObjects();
            String[] objectNames = new String[this.params.length];
            for (int p = 0; p < objectNames.length; p++) {
                objectNames[p] = allObjects.get(this.params[p]).getName();
            }
            return new GroundedAction(this.action, objectNames);
        }
    }

    /**
     * Creates the environment. It generates one state immediately to enumerate the RL-Glue
     * action set and the per-class object counts, and reuses that state as the first episode's
     * start state.
     *
     * @param domain the BURLAP domain
     * @param stateGenerator generator for episode start states
     * @param rewardFunction reward function
     * @param terminalFunction terminal-state function
     * @param rewardRange reward range advertised in the task spec
     * @param isEpisodic true for an episodic task, false for a continuing one
     * @param discount discount factor advertised in the task spec
     */
    public RLGlueEnvironment(Domain domain, StateGenerator stateGenerator, RewardFunction rewardFunction, TerminalFunction terminalFunction, DoubleRange rewardRange, boolean isEpisodic, double discount) {
        this.domain = domain;
        this.stateGenerator = stateGenerator;
        this.rf = rewardFunction;
        this.tf = terminalFunction;
        this.rewardRange = rewardRange;
        this.isEpisodic = isEpisodic;
        this.discount = discount;

        State initialState = this.stateGenerator.generateState();

        // Number every applicable grounded action consecutively; the index is the RL-Glue action.
        int actionIndex = 0;
        for (Action a : this.domain.getActions()) {
            for (GroundedAction ga : a.getAllApplicableGroundedActions(initialState)) {
                this.actionMap.put(actionIndex, new ActionIndexParameterization(ga, initialState));
                actionIndex++;
            }
        }

        for (List<ObjectInstance> objectsOfClass : initialState.getAllObjectsByTrueClass()) {
            if (!objectsOfClass.isEmpty()) {
                this.numObjectsOfEachClass.put(objectsOfClass.get(0).getTrueClassName(), objectsOfClass.size());
            }
        }

        // Must be summed only after numObjectsOfEachClass is filled; otherwise numObjects stays 0
        // and relational observations are declared with the empty range [0, -1].
        for (int count : this.numObjectsOfEachClass.values()) {
            this.numObjects += count;
        }

        this.curState = initialState;
    }

    /**
     * Runs this environment through an RL-Glue {@link EnvironmentLoader} using the loader's
     * default connection settings.
     */
    public void load() {
        new EnvironmentLoader(this).run();
    }

    /**
     * Runs this environment through an RL-Glue {@link EnvironmentLoader} with explicit connection
     * settings.
     *
     * @param host presumably the RL-Glue host name (forwarded unchanged to the loader)
     * @param port presumably the RL-Glue port (forwarded unchanged to the loader)
     */
    public void load(String host, String port) {
        new EnvironmentLoader(host, port, this).run();
    }

    @Override // org.rlcommunity.rlglue.codec.EnvironmentInterface
    public void env_cleanup() {
    }

    /**
     * Builds the RL-Glue task spec. It declares one discrete action ranging over all action
     * indices, plus one observation entry per attribute per object, grouped by object class.
     *
     * @return the task spec string
     */
    @Override // org.rlcommunity.rlglue.codec.EnvironmentInterface
    public String env_init() {
        // Recount from zero so repeated env_init calls do not accumulate observation dimensions.
        this.numDiscreteAtts = 0;
        this.numContinuousAtts = 0;

        TaskSpecVRLGLUE3 taskSpec = new TaskSpecVRLGLUE3();
        if (this.isEpisodic) {
            taskSpec.setEpisodic();
        } else {
            taskSpec.setContinuing();
        }
        taskSpec.setDiscountFactor(this.discount);
        taskSpec.setRewardRange(this.rewardRange);
        taskSpec.addDiscreteAction(new IntRange(0, this.actionMap.size() - 1));

        for (Map.Entry<String, Integer> entry : this.numObjectsOfEachClass.entrySet()) {
            int objectCount = entry.getValue();
            List<Attribute> attributes = this.domain.getObjectClass(entry.getKey()).attributeList;
            for (int o = 0; o < objectCount; o++) {
                for (Attribute attribute : attributes) {
                    addAttribute(taskSpec, attribute);
                }
            }
        }
        return taskSpec.toTaskSpec();
    }

    /**
     * Declares one observation entry for {@code attribute} in the task spec and increments the
     * matching discrete or continuous counter.
     *
     * @param taskSpecVRLGLUE3 the task spec being built
     * @param attribute the BURLAP attribute to declare
     * @throws RuntimeException if the attribute type has no RL-Glue mapping
     */
    protected void addAttribute(TaskSpecVRLGLUE3 taskSpecVRLGLUE3, Attribute attribute) {
        Attribute.AttributeType attributeType = attribute.type;
        if (attributeType == Attribute.AttributeType.DISC || attributeType == Attribute.AttributeType.BOOLEAN || attributeType == Attribute.AttributeType.INT) {
            taskSpecVRLGLUE3.addDiscreteObservation(new IntRange((int) attribute.lowerLim, (int) attribute.upperLim));
            this.numDiscreteAtts++;
        } else if (attributeType == Attribute.AttributeType.RELATIONAL) {
            // A relational value is encoded as an object position in State.getAllObjects().
            taskSpecVRLGLUE3.addDiscreteObservation(new IntRange(0, this.numObjects - 1));
            this.numDiscreteAtts++;
        } else if (attributeType == Attribute.AttributeType.REAL || attributeType == Attribute.AttributeType.REALUNBOUND) {
            taskSpecVRLGLUE3.addContinuousObservation(new DoubleRange(attribute.lowerLim, attribute.upperLim));
            this.numContinuousAtts++;
        } else {
            throw new RuntimeException("Cannot create RLGlue Attribute for BURLAP att type: " + attributeType);
        }
    }

    @Override // org.rlcommunity.rlglue.codec.EnvironmentInterface
    public String env_message(String str) {
        return "Messages not supportd by default BURLAP RLGlueEnvironment";
    }

    /**
     * Starts an episode. The first call reuses the state generated by the constructor; later
     * calls draw a fresh state from the state generator.
     *
     * @return the observation of the start state
     */
    @Override // org.rlcommunity.rlglue.codec.EnvironmentInterface
    public Observation env_start() {
        if (this.usedConstructorState) {
            this.curState = this.stateGenerator.generateState();
        } else {
            this.usedConstructorState = true;
        }
        return convertIntoObservation(this.curState);
    }

    /**
     * Applies the RL-Glue action (its first integer is the action index) to the current state.
     *
     * @param action the RL-Glue action
     * @return reward, next observation, and whether the next state is terminal
     * @throws IllegalArgumentException if the action index is not in the action set
     */
    @Override // org.rlcommunity.rlglue.codec.EnvironmentInterface
    public Reward_observation_terminal env_step(org.rlcommunity.rlglue.codec.types.Action action) {
        int actionIndex = action.getInt(0);
        ActionIndexParameterization parameterization = this.actionMap.get(actionIndex);
        if (parameterization == null) {
            throw new IllegalArgumentException("Unknown RL-Glue action index: " + actionIndex);
        }
        GroundedAction groundedAction = parameterization.generateGroundedActionForState(this.curState);
        State nextState = groundedAction.executeIn(this.curState);
        Observation observation = convertIntoObservation(nextState);
        double reward = this.rf.reward(this.curState, groundedAction, nextState);
        boolean isTerminal = this.tf.isTerminal(nextState);
        this.curState = nextState;
        return new Reward_observation_terminal(reward, observation, isTerminal);
    }

    /**
     * Flattens {@code state} into an RL-Glue observation. Classes are visited in the same order
     * {@link #env_init()} declared them. Within a class, objects are visited in
     * {@link State#getObjectsOfClass(String)} order, and attributes in class attribute-list order.
     * <p>
     * NOTE(review): the observation is sized from the constructor-time object counts. This
     * assumes {@code state} has no more objects of any class than that; presumably the
     * observation setters fail otherwise — confirm.
     *
     * @param state the state to convert
     * @return the observation
     */
    protected Observation convertIntoObservation(State state) {
        Observation observation = new Observation(this.numDiscreteAtts, this.numContinuousAtts);
        int intCursor = 0;
        int doubleCursor = 0;
        for (Map.Entry<String, Integer> entry : this.numObjectsOfEachClass.entrySet()) {
            List<ObjectInstance> objectsOfClass = state.getObjectsOfClass(entry.getKey());
            List<Attribute> attributes = this.domain.getObjectClass(entry.getKey()).attributeList;
            for (ObjectInstance objectInstance : objectsOfClass) {
                for (Attribute attribute : attributes) {
                    if (attribute.type == Attribute.AttributeType.DISC || attribute.type == Attribute.AttributeType.INT || attribute.type == Attribute.AttributeType.BOOLEAN) {
                        observation.setInt(intCursor, objectInstance.getIntValForAttribute(attribute.name));
                        intCursor++;
                    } else if (attribute.type == Attribute.AttributeType.REAL || attribute.type == Attribute.AttributeType.REALUNBOUND) {
                        observation.setDouble(doubleCursor, objectInstance.getRealValForAttribute(attribute.name));
                        doubleCursor++;
                    } else if (attribute.type == Attribute.AttributeType.RELATIONAL) {
                        // Relational attributes are declared as discrete observations, so they
                        // belong in the integer array at the integer cursor.
                        // NOTE(review): this encodes the index of the object that owns the
                        // attribute, not the object the relational attribute points to; confirm
                        // which is intended.
                        observation.setInt(intCursor, objectIndex(state, objectInstance.getName()));
                        intCursor++;
                    }
                }
            }
        }
        return observation;
    }

    /**
     * Returns the position of the object named {@code str} in {@link State#getAllObjects()}.
     *
     * @param state the state to search
     * @param str the object name
     * @return the object's position
     * @throws RuntimeException if no object has that name
     */
    protected int objectIndex(State state, String str) {
        int position = 0;
        for (ObjectInstance candidate : state.getAllObjects()) {
            if (candidate.getName().equals(str)) {
                return position;
            }
            position++;
        }
        throw new RuntimeException("Could not find object " + str);
    }
}
