package burlap.domain.singleagent.cartpole.model;

import burlap.domain.singleagent.cartpole.CartPoleDomain;
import burlap.domain.singleagent.cartpole.states.CartPoleFullState;
import burlap.mdp.core.StateTransitionProb;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.model.statemodel.FullStateModel;
import java.util.List;

/* loaded from: input_file:burlap/domain/singleagent/cartpole/model/CPCorrectModel.class */
/**
 * Deterministic cart-pole transition model with Coulomb cart friction.
 *
 * <p>The accelerations follow the form of the friction-aware cart-pole equations (cart friction
 * proportional to the cart's normal force, with sign {@code sgn(N * xDot)}). No pole-joint
 * friction term is modeled. Each step advances the state with one explicit Euler step of length
 * {@code physParams.timeDelta} and then clamps position, velocities and angle to the limits held
 * in {@link #physParams}.
 */
public class CPCorrectModel implements FullStateModel {

    /** Physical constants and limits used by the dynamics (held by reference, not copied). */
    protected CartPoleDomain.CPPhysicsParams physParams;

    /**
     * Creates the model.
     *
     * @param cPPhysicsParams physics parameters to read on every step; stored by reference
     */
    public CPCorrectModel(CartPoleDomain.CPPhysicsParams cPPhysicsParams) {
        this.physParams = cPPhysicsParams;
    }

    /**
     * Returns the single outcome of this deterministic model, built by
     * {@code FullStateModel.Helper.deterministicTransition}.
     */
    @Override
    public List<StateTransitionProb> stateTransitions(State state, Action action) {
        return FullStateModel.Helper.deterministicTransition(this, state, action);
    }

    /**
     * Applies {@code action} to a copy of {@code state}; the input state is not modified.
     *
     * @param state current state; expected to be a {@link CartPoleFullState}
     * @param action either {@code "right"} (push with +force) or {@code "left"} (push with -force)
     * @return the successor state
     * @throws IllegalArgumentException if the action name is neither {@code "right"} nor
     *     {@code "left"}
     */
    @Override
    public State sample(State state, Action action) {
        State copy = state.copy();
        String actionName = action.actionName();
        if (actionName.equals("right")) {
            return moveCorrectModel(copy, 1.0);
        }
        if (actionName.equals("left")) {
            return moveCorrectModel(copy, -1.0);
        }
        throw new IllegalArgumentException("Unknown action " + actionName);
    }

    /**
     * Advances {@code state} by one time step IN PLACE and returns it.
     *
     * @param state state to mutate; must be a {@link CartPoleFullState}
     * @param dir multiplier for {@code physParams.movementForceMag} (e.g. +1 or -1)
     * @return the same {@code state} instance, updated
     */
    public State moveCorrectModel(State state, double dir) {
        CartPoleFullState cps = (CartPoleFullState) state;
        double x = cps.x;
        double xVel = cps.v;
        double angle = cps.angle;
        double angleVel = cps.angleV;
        double prevNormSign = cps.normSign;
        double force = dir * physParams.movementForceMag;

        // The angular acceleration depends on the sign of the cart's normal force, which itself
        // depends on the angular acceleration: guess with the previous sign, then recompute once
        // if the resulting normal force has a different sign.
        double angleAcc = getAngle2ndDeriv(xVel, angle, angleVel, prevNormSign, force);
        double normForce = getNormForce(angle, angleVel, angleAcc);
        double normSign = Math.signum(normForce);
        if (normSign != prevNormSign) {
            angleAcc = getAngle2ndDeriv(xVel, angle, angleVel, normSign, force);
        }
        double xAcc = getX2ndDeriv(xVel, angle, angleVel, normForce, force, angleAcc);

        // Explicit Euler step: every update uses the values from the start of the step.
        double dt = physParams.timeDelta;
        double newX = x + dt * xVel;
        double newXVel = xVel + dt * xAcc;
        double newAngle = angle + dt * angleVel;
        double newAngleVel = angleVel + dt * angleAcc;

        // Walls exist only on a finite track; hitting one stops the cart.
        if (physParams.isFiniteTrack && Math.abs(newX) > physParams.halfTrackLength) {
            newX = Math.signum(newX) * physParams.halfTrackLength;
            newXVel = 0.0;
        }
        if (Math.abs(newXVel) > physParams.maxCartSpeed) {
            newXVel = Math.signum(newXVel) * physParams.maxCartSpeed;
        }
        // Reaching the angle limit pins the pole there and stops its rotation.
        if (Math.abs(newAngle) >= physParams.angleRange) {
            newAngle = Math.signum(newAngle) * physParams.angleRange;
            newAngleVel = 0.0;
        }
        if (Math.abs(newAngleVel) > physParams.maxAngleSpeed) {
            newAngleVel = Math.signum(newAngleVel) * physParams.maxAngleSpeed;
        }

        // On an infinite track the cart position is left unchanged.
        if (physParams.isFiniteTrack) {
            cps.x = newX;
        }
        cps.v = newXVel;
        cps.angle = newAngle;
        cps.angleV = newAngleVel;
        // Store the SIGN of the normal force: it is compared against Math.signum(...) above on
        // the next step and is only ever consumed through sgn(normSign * xVel).
        cps.normSign = normSign;
        return state;
    }

    /**
     * Computes the pole's angular acceleration.
     *
     * @param xVel cart velocity
     * @param angle pole angle (radians)
     * @param angleVel pole angular velocity
     * @param normSign sign (or any value with the sign) of the cart's normal force
     * @param force horizontal force applied to the cart
     * @return angular acceleration of the pole
     */
    protected double getAngle2ndDeriv(double xVel, double angle, double angleVel, double normSign, double force) {
        double totalMass = physParams.cartMass + physParams.poleMass;
        double sin = Math.sin(angle);
        double cos = Math.cos(angle);
        double muC = physParams.cartFriction;
        // sgn(N * xDot): direction in which cart friction acts.
        double frictionSign = Math.signum(normSign * xVel);

        double inner = (-force
                - physParams.poleMass * physParams.halfPoleLength * angleVel * angleVel
                        * (sin + muC * frictionSign * cos))
                / totalMass;
        // The friction term mu_c * g * sgn comes from -cos(theta) * xDDot, so it is scaled by cos.
        double numerator = physParams.gravity * sin
                + cos * (inner + muC * physParams.gravity * frictionSign);
        // Dimensionless correction: cos(theta) - mu_c * sgn(N * xDot) (friction coefficient, not mass).
        double denominator = physParams.halfPoleLength
                * (4.0 / 3.0 - (physParams.poleMass * cos / totalMass) * (cos - muC * frictionSign));
        return numerator / denominator;
    }

    /**
     * Computes the normal force exerted on the cart by the track.
     *
     * @param angle pole angle (radians)
     * @param angleVel pole angular velocity
     * @param angleAcc pole angular acceleration
     * @return normal force (its sign selects the friction direction)
     */
    protected double getNormForce(double angle, double angleVel, double angleAcc) {
        double totalMass = physParams.cartMass + physParams.poleMass;
        double poleTerm = angleAcc * Math.sin(angle) + angleVel * angleVel * Math.cos(angle);
        return totalMass * physParams.gravity
                - physParams.poleMass * physParams.halfPoleLength * poleTerm;
    }

    /**
     * Computes the cart's linear acceleration.
     *
     * @param xVel cart velocity
     * @param angle pole angle (radians)
     * @param angleVel pole angular velocity
     * @param normForce normal force on the cart (see {@link #getNormForce})
     * @param force horizontal force applied to the cart
     * @param angleAcc pole angular acceleration
     * @return linear acceleration of the cart
     */
    protected double getX2ndDeriv(double xVel, double angle, double angleVel, double normForce, double force, double angleAcc) {
        double poleTerm = physParams.poleMass * physParams.halfPoleLength
                * (angleVel * angleVel * Math.sin(angle) - angleAcc * Math.cos(angle));
        double friction = physParams.cartFriction * normForce * Math.signum(normForce * xVel);
        return (force + poleTerm - friction) / (physParams.cartMass + physParams.poleMass);
    }
}
