package burlap.domain.singleagent.cartpole;

import burlap.oomdp.auxiliary.DomainGenerator;
import burlap.oomdp.core.Attribute;
import burlap.oomdp.core.Domain;
import burlap.oomdp.core.ObjectClass;
import burlap.oomdp.core.ObjectInstance;
import burlap.oomdp.core.State;
import burlap.oomdp.core.TerminalFunction;
import burlap.oomdp.core.TransitionProbability;
import burlap.oomdp.singleagent.Action;
import burlap.oomdp.singleagent.GroundedAction;
import burlap.oomdp.singleagent.RewardFunction;
import burlap.oomdp.singleagent.SADomain;
import burlap.oomdp.singleagent.explorer.VisualExplorer;
import java.util.List;

/* loaded from: input_file:burlap/domain/singleagent/cartpole/CartPoleDomain.class */
public class CartPoleDomain implements DomainGenerator {
    /** Name of the cart-position attribute; generateDomain bounds it to +/- halfTrackLength. */
    public static final String ATTX = "xAtt";
    /** Name of the cart-velocity attribute; generateDomain bounds it to +/- maxCartSpeed. */
    public static final String ATTV = "xvAtt";
    /** Name of the pole-angle attribute; generateDomain bounds it to +/- angleRange. */
    public static final String ATTANGLE = "angleAtt";
    /** Name of the pole angular-velocity attribute; generateDomain bounds it to +/- maxAngleSpeed. */
    public static final String ATTANGLEV = "angleVAtt";
    /**
     * Name of the hidden attribute, bounded to [-1, 1], that is only created when
     * the physics parameters select the "correct" model.
     */
    public static final String ATTNORMSGN = "normalSign";
    /** Name of the single object class that carries all of the attributes above. */
    public static final String CLASSCARTPOLE = "cartPole";
    /** Name of the action created with direction -1 in generateDomain. */
    public static final String ACTIONLEFT = "left";
    /** Name of the action created with direction +1 in generateDomain. */
    public static final String ACTIONRIGHT = "right";
    /**
     * Physics configuration used by generateDomain. The generated actions receive a
     * copy, so edits made after generation do not affect an already-generated domain.
     */
    public CPPhysicsParams physParams = new CPPhysicsParams();

    /* loaded from: input_file:burlap/domain/singleagent/cartpole/CartPoleDomain$CPPhysicsParams.class */
    /**
     * Mutable bundle of the physical constants and switches consumed by the
     * cart-pole transition models. All fields are public and carry the defaults
     * shown in their initialisers.
     */
    public static class CPPhysicsParams {
        /** Cart positions are limited to the interval [-halfTrackLength, halfTrackLength]. */
        public double halfTrackLength = 2.4d;
        /** Pole angles are limited to the interval [-angleRange, angleRange]; default is pi/2. */
        public double angleRange = Math.PI / 2.0d;
        /** Gravity constant used by the dynamics equations. */
        public double gravity = 9.8d;
        /** Mass of the cart. */
        public double cartMass = 1.0d;
        /** Mass of the pole. */
        public double poleMass = 0.1d;
        /** Half of the pole's length. */
        public double halfPoleLength = 0.5d;
        /** Friction coefficient applied to the cart's motion. */
        public double cartFriction = 5.0E-4d;
        /** Friction coefficient applied to the pole's rotation. */
        public double poleFriction = 2.0E-6d;
        /** Magnitude of the force an action applies; the action's direction scales it. */
        public double movementForceMag = 10.0d;
        /** Step size used when integrating the dynamics for one action. */
        public double timeDelta = 0.02d;
        /** Cart velocities are clamped to [-maxCartSpeed, maxCartSpeed]. */
        public double maxCartSpeed = 6.81d;
        /** Angular velocities are clamped to [-maxAngleSpeed, maxAngleSpeed]. */
        public double maxAngleSpeed = 10.47d;
        /** When false, the transition models never write a new cart position. */
        public boolean isFiniteTrack = true;
        /** Selects between the "correct" model (true) and the "classic" model (false). */
        public boolean useCorrectModel = true;

        /**
         * Creates a parameter set holding the default values.
         */
        public CPPhysicsParams() {
            // Every field already holds its default via its initialiser.
        }

        /**
         * Creates a parameter set with every value given explicitly.
         *
         * @param d   value for {@link #halfTrackLength}
         * @param d2  value for {@link #angleRange}
         * @param d3  value for {@link #gravity}
         * @param d4  value for {@link #cartMass}
         * @param d5  value for {@link #poleMass}
         * @param d6  value for {@link #halfPoleLength}
         * @param d7  value for {@link #cartFriction}
         * @param d8  value for {@link #poleFriction}
         * @param d9  value for {@link #movementForceMag}
         * @param d10 value for {@link #timeDelta}
         * @param d11 value for {@link #maxCartSpeed}
         * @param d12 value for {@link #maxAngleSpeed}
         * @param z   value for {@link #isFiniteTrack}
         * @param z2  value for {@link #useCorrectModel}
         */
        public CPPhysicsParams(double d, double d2, double d3, double d4, double d5, double d6, double d7, double d8, double d9, double d10, double d11, double d12, boolean z, boolean z2) {
            this.halfTrackLength = d;
            this.angleRange = d2;
            this.gravity = d3;
            this.cartMass = d4;
            this.poleMass = d5;
            this.halfPoleLength = d6;
            this.cartFriction = d7;
            this.poleFriction = d8;
            this.movementForceMag = d9;
            this.timeDelta = d10;
            this.maxCartSpeed = d11;
            this.maxAngleSpeed = d12;
            this.isFiniteTrack = z;
            this.useCorrectModel = z2;
        }

        /**
         * Returns a new, independent parameter set with the same field values.
         *
         * @return a field-by-field copy of this object
         */
        public CPPhysicsParams copy() {
            return new CPPhysicsParams(
                    this.halfTrackLength,
                    this.angleRange,
                    this.gravity,
                    this.cartMass,
                    this.poleMass,
                    this.halfPoleLength,
                    this.cartFriction,
                    this.poleFriction,
                    this.movementForceMag,
                    this.timeDelta,
                    this.maxCartSpeed,
                    this.maxAngleSpeed,
                    this.isFiniteTrack,
                    this.useCorrectModel);
        }
    }

    /* loaded from: input_file:burlap/domain/singleagent/cartpole/CartPoleDomain$CartPoleRewardFunction.class */
    /**
     * Reward function that returns -1 when the resulting state has the cart at or
     * beyond either limit of its position attribute, or the pole at or beyond the
     * maximum absolute angle; otherwise it returns 0.
     */
    public static class CartPoleRewardFunction implements RewardFunction {
        /** Absolute pole angle (radians) at which the penalty applies; default is about 12 degrees. */
        double maxAbsoluteAngle = 0.20943951023931956d;

        /**
         * Creates a reward function using the default maximum angle.
         */
        public CartPoleRewardFunction() {
        }

        /**
         * Creates a reward function with a custom maximum angle.
         *
         * @param d the absolute pole angle at or beyond which -1 is returned
         */
        public CartPoleRewardFunction(double d) {
            this.maxAbsoluteAngle = d;
        }

        /**
         * Scores a transition by looking only at the resulting state.
         *
         * @param state the source state (not inspected)
         * @param groundedAction the action taken (not inspected)
         * @param state2 the resulting state
         * @return -1.0 on a failure condition, 0.0 otherwise
         */
        @Override // burlap.oomdp.singleagent.RewardFunction
        public double reward(State state, GroundedAction groundedAction, State state2) {
            ObjectInstance cartPole = state2.getFirstObjectOfClass(CartPoleDomain.CLASSCARTPOLE);
            double x = cartPole.getRealValForAttribute("xAtt");
            Attribute xAttribute = cartPole.getObjectClass().getAttribute("xAtt");

            // The track limits come from the attribute definition, not from the physics parameters.
            boolean atTrackEdge = x <= xAttribute.lowerLim || x >= xAttribute.upperLim;
            if (atTrackEdge) {
                return -1.0d;
            }

            double angle = cartPole.getRealValForAttribute("angleAtt");
            if (Math.abs(angle) >= this.maxAbsoluteAngle) {
                return -1.0d;
            }
            return 0.0d;
        }
    }

    /* loaded from: input_file:burlap/domain/singleagent/cartpole/CartPoleDomain$CartPoleTerminalFunction.class */
    /**
     * Terminal function that marks a state terminal when the cart is at or beyond
     * either limit of its position attribute, or the pole is at or beyond the
     * maximum absolute angle.
     */
    public static class CartPoleTerminalFunction implements TerminalFunction {
        /** Absolute pole angle (radians) at which a state is terminal; default is about 12 degrees. */
        double maxAbsoluteAngle = 0.20943951023931956d;

        /**
         * Creates a terminal function using the default maximum angle.
         */
        public CartPoleTerminalFunction() {
        }

        /**
         * Creates a terminal function with a custom maximum angle.
         *
         * @param d the absolute pole angle at or beyond which a state is terminal
         */
        public CartPoleTerminalFunction(double d) {
            this.maxAbsoluteAngle = d;
        }

        /**
         * Tests whether the given state is a failure state.
         *
         * @param state the state to test
         * @return true if the cart is at a track limit or the pole angle is too large
         */
        @Override // burlap.oomdp.core.TerminalFunction
        public boolean isTerminal(State state) {
            ObjectInstance cartPole = state.getFirstObjectOfClass(CartPoleDomain.CLASSCARTPOLE);
            double x = cartPole.getRealValForAttribute("xAtt");
            Attribute xAttribute = cartPole.getObjectClass().getAttribute("xAtt");

            // The track limits come from the attribute definition, not from the physics parameters.
            if (x <= xAttribute.lowerLim) {
                return true;
            }
            if (x >= xAttribute.upperLim) {
                return true;
            }
            double angle = cartPole.getRealValForAttribute("angleAtt");
            return Math.abs(angle) >= this.maxAbsoluteAngle;
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:burlap/domain/singleagent/cartpole/CartPoleDomain$MovementAction.class */
    /**
     * Action that pushes the cart in a fixed direction and advances the state with
     * whichever transition model the supplied physics parameters select.
     */
    public static class MovementAction extends Action {
        /** Physics parameters consulted on every application of the action. */
        CPPhysicsParams physParams;
        /** Multiplier applied to the movement force magnitude (sign gives the direction). */
        double dir;

        /**
         * Creates a movement action with no parameters.
         *
         * @param str the action name
         * @param domain the domain passed on to the {@link Action} constructor
         * @param d the direction multiplier for the applied force
         * @param cPPhysicsParams the physics parameters to use; the reference is stored, not copied
         */
        public MovementAction(String str, Domain domain, double d, CPPhysicsParams cPPhysicsParams) {
            super(str, domain, "");
            this.dir = d;
            this.physParams = cPPhysicsParams;
        }

        /**
         * Advances the given state by one step using the selected model.
         *
         * @param state the state to advance
         * @param strArr action parameters (unused)
         * @return the state returned by the selected model
         */
        @Override // burlap.oomdp.singleagent.Action
        protected State performActionHelper(State state, String[] strArr) {
            if (this.physParams.useCorrectModel) {
                return CartPoleDomain.moveCorrectModel(state, this.dir, this.physParams);
            }
            return CartPoleDomain.moveClassicModel(state, this.dir, this.physParams);
        }

        /**
         * Returns the transitions for this action, delegating to the inherited
         * {@code deterministicTransition} helper.
         *
         * @param state the source state
         * @param strArr action parameters, forwarded unchanged
         * @return the list produced by {@code deterministicTransition}
         */
        @Override // burlap.oomdp.singleagent.Action
        public List<TransitionProbability> getTransitions(State state, String[] strArr) {
            List<TransitionProbability> transitions = deterministicTransition(state, strArr);
            return transitions;
        }
    }

    /**
     * Builds a cart-pole domain from the current {@link #physParams}: four bounded
     * real attributes (plus a hidden normal-sign attribute when the correct model
     * is selected), one object class holding them, and the "left" and "right"
     * movement actions.
     *
     * @return the newly created domain
     */
    @Override // burlap.oomdp.auxiliary.DomainGenerator
    public Domain generateDomain() {
        SADomain domain = new SADomain();
        final boolean correctModel = this.physParams.useCorrectModel;

        // Each attribute is bounded symmetrically around zero by its physics limit.
        Attribute xAtt = new Attribute(domain, "xAtt", Attribute.AttributeType.REAL);
        xAtt.setLims(-this.physParams.halfTrackLength, this.physParams.halfTrackLength);

        Attribute xvAtt = new Attribute(domain, ATTV, Attribute.AttributeType.REAL);
        xvAtt.setLims(-this.physParams.maxCartSpeed, this.physParams.maxCartSpeed);

        Attribute angleAtt = new Attribute(domain, "angleAtt", Attribute.AttributeType.REAL);
        angleAtt.setLims(-this.physParams.angleRange, this.physParams.angleRange);

        Attribute angleVAtt = new Attribute(domain, "angleVAtt", Attribute.AttributeType.REAL);
        angleVAtt.setLims(-this.physParams.maxAngleSpeed, this.physParams.maxAngleSpeed);

        // The normal-sign attribute only exists for the correct model and is flagged hidden.
        Attribute normSignAtt = null;
        if (correctModel) {
            normSignAtt = new Attribute(domain, ATTNORMSGN, Attribute.AttributeType.REAL);
            normSignAtt.setLims(-1.0d, 1.0d);
            normSignAtt.hidden = true;
        }

        ObjectClass cartPoleClass = new ObjectClass(domain, CLASSCARTPOLE);
        cartPoleClass.addAttribute(xAtt);
        cartPoleClass.addAttribute(xvAtt);
        cartPoleClass.addAttribute(angleAtt);
        cartPoleClass.addAttribute(angleVAtt);
        if (normSignAtt != null) {
            cartPoleClass.addAttribute(normSignAtt);
        }

        // Both actions share one snapshot of the parameters, so later edits to
        // this generator's physParams do not reach the generated actions.
        CPPhysicsParams paramsSnapshot = this.physParams.copy();
        // The instances are not kept here; presumably the Action constructor
        // registers them with the domain -- verify against Action.
        new MovementAction("left", domain, -1.0d, paramsSnapshot);
        new MovementAction("right", domain, 1.0d, paramsSnapshot);
        return domain;
    }

    /**
     * Selects the classic transition model while forcing the gravity parameter to
     * be non-negative (its absolute value).
     */
    public void setToIncorrectClassicModelWithCorrectGravity() {
        final CPPhysicsParams params = this.physParams;
        params.gravity = Math.abs(params.gravity);
        params.useCorrectModel = false;
    }

    /**
     * Selects the classic transition model and forces the gravity parameter to be
     * non-positive (the negated absolute value).
     */
    public void setToIncorrectClassicModel() {
        final CPPhysicsParams params = this.physParams;
        double gravityMagnitude = Math.abs(params.gravity);
        params.gravity = gravityMagnitude * (-1.0d);
        params.useCorrectModel = false;
    }

    /**
     * Selects the correct transition model and forces the gravity parameter to be
     * non-negative (its absolute value).
     */
    public void setToCorrectModel() {
        final CPPhysicsParams params = this.physParams;
        params.gravity = Math.abs(params.gravity);
        params.useCorrectModel = true;
    }

    /**
     * Sets {@code physParams.maxCartSpeed} to the speed the cart would reach if the
     * movement force accelerated the combined cart and pole mass uniformly from
     * rest across the full track (twice {@code halfTrackLength}), ignoring friction.
     *
     * <p>Call this before {@link #generateDomain()}: the domain reads
     * {@code maxCartSpeed} when it is generated and hands the actions a copy of
     * the parameters.
     *
     * @return the speed that was stored in {@code physParams.maxCartSpeed}
     */
    public double setMaxCartSpeedToMaxWithMovementFromOneSideToOther() {
        double acceleration = this.physParams.movementForceMag / (this.physParams.cartMass + this.physParams.poleMass);
        double trackLength = 2.0d * this.physParams.halfTrackLength;
        // v = a * t with t = sqrt(2 * distance / a) for constant acceleration from rest.
        double maxSpeed = acceleration * Math.sqrt((2.0d * trackLength) / acceleration);
        // Previously the value was only returned, so this "setter" never changed the parameters.
        this.physParams.maxCartSpeed = maxSpeed;
        return maxSpeed;
    }

    /**
     * Creates an initial state in which the cart position, cart velocity, pole
     * angle and pole angular velocity are all zero.
     *
     * @param domain the cart-pole domain the state belongs to
     * @return the new state
     */
    public static State getInitialState(Domain domain) {
        final double zero = 0.0d;
        return getInitialState(domain, zero, zero, zero, zero);
    }

    /**
     * Creates a state containing a single cart-pole object with the given values.
     * If the domain defines the normal-sign attribute, it is initialised to 1.
     *
     * @param domain the cart-pole domain the state belongs to
     * @param d the cart position
     * @param d2 the cart velocity
     * @param d3 the pole angle
     * @param d4 the pole angular velocity
     * @return the new state
     */
    public static State getInitialState(Domain domain, double d, double d2, double d3, double d4) {
        ObjectInstance cartPole = new ObjectInstance(domain.getObjectClass(CLASSCARTPOLE), CLASSCARTPOLE);
        cartPole.setValue("xAtt", d);
        cartPole.setValue(ATTV, d2);
        cartPole.setValue("angleAtt", d3);
        cartPole.setValue("angleVAtt", d4);

        // Only domains generated for the correct model define this attribute.
        boolean hasNormalSign = domain.getAttribute(ATTNORMSGN) != null;
        if (hasNormalSign) {
            cartPole.setValue(ATTNORMSGN, 1.0d);
        }

        State initial = new State();
        initial.addObject(cartPole);
        return initial;
    }

    /**
     * Advances the cart-pole object of the given state by one Euler step of size
     * {@code timeDelta} using the classic dynamics equations. The state is
     * modified in place and the same instance is returned.
     *
     * @param state the state to advance (mutated)
     * @param d direction multiplier applied to {@code movementForceMag}
     * @param cPPhysicsParams the physics parameters to use
     * @return the input state after it has been updated
     */
    public static State moveClassicModel(State state, double d, CPPhysicsParams cPPhysicsParams) {
        ObjectInstance cartPole = state.getFirstObjectOfClass(CLASSCARTPOLE);
        double x0 = cartPole.getRealValForAttribute("xAtt");
        double xv0 = cartPole.getRealValForAttribute(ATTV);
        double a0 = cartPole.getRealValForAttribute("angleAtt");
        double av0 = cartPole.getRealValForAttribute("angleVAtt");

        double force = d * cPPhysicsParams.movementForceMag;
        double totalMass = cPPhysicsParams.cartMass + cPPhysicsParams.poleMass;
        double sinA = Math.sin(a0);
        double cosA = Math.cos(a0);
        double cartFrictionTerm = cPPhysicsParams.cartFriction * Math.signum(xv0);
        double poleMassLength = cPPhysicsParams.poleMass * cPPhysicsParams.halfPoleLength;

        // Angular acceleration of the pole.
        double cosFactor = (((-force) - (((poleMassLength * av0) * av0) * sinA)) + cartFrictionTerm) / totalMass;
        double angleNumerator = ((cPPhysicsParams.gravity * sinA) + (cosA * cosFactor))
                - ((cPPhysicsParams.poleFriction * av0) / poleMassLength);
        double angleDenominator = cPPhysicsParams.halfPoleLength
                * ((4.0d / 3.0d) - ((cPPhysicsParams.poleMass * Math.pow(cosA, 2.0d)) / totalMass));
        double angleAcc = angleNumerator / angleDenominator;

        // Linear acceleration of the cart.
        double xAcc = ((force + (poleMassLength * (((av0 * av0) * sinA) - (angleAcc * cosA)))) - cartFrictionTerm) / totalMass;

        // Euler step: positions advance with the old velocities, velocities with the accelerations.
        double dt = cPPhysicsParams.timeDelta;
        double x1 = x0 + (dt * xv0);
        double xv1 = xv0 + (dt * xAcc);
        double a1 = a0 + (dt * av0);
        double av1 = av0 + (dt * angleAcc);

        // Hitting the end of the track stops the cart.
        if (Math.abs(x1) > cPPhysicsParams.halfTrackLength) {
            x1 = Math.signum(x1) * cPPhysicsParams.halfTrackLength;
            xv1 = 0.0d;
        }
        if (Math.abs(xv1) > cPPhysicsParams.maxCartSpeed) {
            xv1 = Math.signum(xv1) * cPPhysicsParams.maxCartSpeed;
        }
        // Reaching the angle limit stops the pole's rotation.
        if (Math.abs(a1) >= cPPhysicsParams.angleRange) {
            a1 = Math.signum(a1) * cPPhysicsParams.angleRange;
            av1 = 0.0d;
        }
        if (Math.abs(av1) > cPPhysicsParams.maxAngleSpeed) {
            av1 = Math.signum(av1) * cPPhysicsParams.maxAngleSpeed;
        }

        // The position is only written back when the track is finite.
        if (cPPhysicsParams.isFiniteTrack) {
            cartPole.setValue("xAtt", x1);
        }
        cartPole.setValue(ATTV, xv1);
        cartPole.setValue("angleAtt", a1);
        cartPole.setValue("angleVAtt", av1);
        return state;
    }

    /**
     * Advances the cart-pole object of the given state by one Euler step of size
     * {@code timeDelta} using the corrected dynamics, which track the sign of the
     * normal force acting on the cart. The state is modified in place and the same
     * instance is returned.
     *
     * <p>The state's cart-pole object must carry the {@code ATTNORMSGN} attribute,
     * which is only defined by domains generated with {@code useCorrectModel} set.
     *
     * @param state the state to advance (mutated)
     * @param d direction multiplier applied to {@code movementForceMag}
     * @param cPPhysicsParams the physics parameters to use
     * @return the input state after it has been updated
     */
    public static State moveCorrectModel(State state, double d, CPPhysicsParams cPPhysicsParams) {
        ObjectInstance cartPole = state.getFirstObjectOfClass(CLASSCARTPOLE);
        double x0 = cartPole.getRealValForAttribute("xAtt");
        double xv0 = cartPole.getRealValForAttribute(ATTV);
        double a0 = cartPole.getRealValForAttribute("angleAtt");
        double av0 = cartPole.getRealValForAttribute("angleVAtt");
        double normSign0 = cartPole.getRealValForAttribute(ATTNORMSGN);

        double force = d * cPPhysicsParams.movementForceMag;

        // First pass assumes the normal-force sign stored in the state.
        double angleAcc = getAngle2ndDeriv(xv0, a0, av0, normSign0, force, cPPhysicsParams);
        double normForce = getNormForce(a0, av0, angleAcc, cPPhysicsParams);
        double normSign = Math.signum(normForce);
        if (normSign != normSign0) {
            // The sign assumption was wrong, so redo the angular acceleration with the new sign.
            angleAcc = getAngle2ndDeriv(xv0, a0, av0, normSign, force, cPPhysicsParams);
        }
        double xAcc = getX2ndDeriv(xv0, a0, av0, normForce, force, angleAcc, cPPhysicsParams);

        // Euler step: positions advance with the old velocities, velocities with the accelerations.
        double dt = cPPhysicsParams.timeDelta;
        double x1 = x0 + (dt * xv0);
        double xv1 = xv0 + (dt * xAcc);
        double a1 = a0 + (dt * av0);
        double av1 = av0 + (dt * angleAcc);

        // Hitting the end of the track stops the cart.
        if (Math.abs(x1) > cPPhysicsParams.halfTrackLength) {
            x1 = Math.signum(x1) * cPPhysicsParams.halfTrackLength;
            xv1 = 0.0d;
        }
        if (Math.abs(xv1) > cPPhysicsParams.maxCartSpeed) {
            xv1 = Math.signum(xv1) * cPPhysicsParams.maxCartSpeed;
        }
        // Reaching the angle limit stops the pole's rotation.
        if (Math.abs(a1) >= cPPhysicsParams.angleRange) {
            a1 = Math.signum(a1) * cPPhysicsParams.angleRange;
            av1 = 0.0d;
        }
        if (Math.abs(av1) > cPPhysicsParams.maxAngleSpeed) {
            av1 = Math.signum(av1) * cPPhysicsParams.maxAngleSpeed;
        }

        // The position is only written back when the track is finite.
        if (cPPhysicsParams.isFiniteTrack) {
            cartPole.setValue("xAtt", x1);
        }
        cartPole.setValue(ATTV, xv1);
        cartPole.setValue("angleAtt", a1);
        cartPole.setValue("angleVAtt", av1);
        // Store the SIGN of the normal force, not its magnitude: the attribute is
        // declared with limits [-1, 1] and is compared against Math.signum(...) on
        // the next step. Writing the raw force put an out-of-range value in the state.
        cartPole.setValue(ATTNORMSGN, normSign);
        return state;
    }

    /**
     * Computes the pole's angular acceleration for the corrected model, given an
     * assumed sign of the normal force on the cart.
     *
     * @param d the cart velocity
     * @param d2 the pole angle
     * @param d3 the pole angular velocity
     * @param d4 the assumed normal-force sign (only the sign of {@code d4 * d} is used)
     * @param d5 the signed force applied to the cart
     * @param cPPhysicsParams the physics parameters to use
     * @return the angular acceleration of the pole
     */
    protected static double getAngle2ndDeriv(double d, double d2, double d3, double d4, double d5, CPPhysicsParams cPPhysicsParams) {
        double totalMass = cPPhysicsParams.cartMass + cPPhysicsParams.poleMass;
        double sinA = Math.sin(d2);
        double cosA = Math.cos(d2);
        // Direction of the cart friction: sign of (normal force * cart velocity).
        double frictionSign = Math.signum(d4 * d);

        double cosFactor = ((-d5) - ((((cPPhysicsParams.poleMass * cPPhysicsParams.halfPoleLength) * d3) * d3)
                * (sinA + ((cPPhysicsParams.cartFriction * frictionSign) * cosA)))) / totalMass;

        // NOTE(review): Florian's published equation places the mu_c*g*sgn term inside
        // the cos(theta) factor and subtracts a pole-friction term
        // (poleFriction * angularVelocity / (poleMass * halfPoleLength)); neither is
        // reflected here. Left unchanged pending confirmation against the reference.
        double numerator = ((cPPhysicsParams.gravity * sinA) + (cosA * cosFactor))
                + ((cPPhysicsParams.cartFriction * cPPhysicsParams.gravity) * frictionSign);

        // The bracket is cos(theta) minus the friction COEFFICIENT times the sign.
        // Using cartMass here mixed a mass into a dimensionless term and swamped
        // the cosine whenever the cart was moving.
        double denominator = cPPhysicsParams.halfPoleLength * ((4.0d / 3.0d)
                - (((cPPhysicsParams.poleMass * cosA) / totalMass)
                * (cosA - (cPPhysicsParams.cartFriction * frictionSign))));

        return numerator / denominator;
    }

    /**
     * Computes the normal force on the cart: the weight of cart plus pole minus
     * the contribution of the pole's rotation.
     *
     * @param d the pole angle
     * @param d2 the pole angular velocity
     * @param d3 the pole angular acceleration
     * @param cPPhysicsParams the physics parameters to use
     * @return the normal force
     */
    protected static double getNormForce(double d, double d2, double d3, CPPhysicsParams cPPhysicsParams) {
        double totalWeight = (cPPhysicsParams.cartMass + cPPhysicsParams.poleMass) * cPPhysicsParams.gravity;
        double poleMassLength = cPPhysicsParams.poleMass * cPPhysicsParams.halfPoleLength;
        double rotationTerm = (d3 * Math.sin(d)) + ((d2 * d2) * Math.cos(d));
        return totalWeight - (poleMassLength * rotationTerm);
    }

    /**
     * Computes the cart's linear acceleration for the corrected model.
     *
     * @param d the cart velocity
     * @param d2 the pole angle
     * @param d3 the pole angular velocity
     * @param d4 the normal force on the cart
     * @param d5 the signed force applied to the cart
     * @param d6 the pole angular acceleration
     * @param cPPhysicsParams the physics parameters to use
     * @return the linear acceleration of the cart
     */
    protected static double getX2ndDeriv(double d, double d2, double d3, double d4, double d5, double d6, CPPhysicsParams cPPhysicsParams) {
        double poleMassLength = cPPhysicsParams.poleMass * cPPhysicsParams.halfPoleLength;
        double poleReaction = poleMassLength * (((d3 * d3) * Math.sin(d2)) - (d6 * Math.cos(d2)));
        // Friction scales with the normal force and follows the sign of (normal force * velocity).
        double frictionForce = (cPPhysicsParams.cartFriction * d4) * Math.signum(d4 * d);
        double totalMass = cPPhysicsParams.cartMass + cPPhysicsParams.poleMass;
        return ((d5 + poleReaction) - frictionForce) / totalMass;
    }

    /**
     * Launches a visual explorer on a default cart-pole domain, binding the "a"
     * key to the left action and the "d" key to the right action.
     *
     * @param strArr command-line arguments (unused)
     */
    public static void main(String[] strArr) {
        CartPoleDomain generator = new CartPoleDomain();
        Domain domain = generator.generateDomain();

        VisualExplorer explorer = new VisualExplorer(
                domain,
                CartPoleVisualizer.getCartPoleVisualizer(),
                getInitialState(domain));

        explorer.addKeyAction("a", "left");
        explorer.addKeyAction("d", "right");
        explorer.initGUI();
    }
}
