package burlap.behavior.singleagent.planning.stochastic.rtdp;

import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.singleagent.planning.stochastic.DynamicProgramming;
import burlap.behavior.singleagent.planning.stochastic.dpoperator.DPOperator;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.debugtools.DPrint;
import burlap.debugtools.RandomFactory;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;

/**
 * Bounded Real-Time Dynamic Programming (BRTDP) planner.
 * <p>
 * Two value functions are maintained: a lower bound ({@link #lowerBoundV}) and an upper bound
 * ({@link #upperBoundV}). Planning repeatedly runs rollouts from the start state. At every visited state
 * both bounds receive a Bellman backup. The rollout follows the greedy action with respect to the upper
 * bound, and the successor state is chosen according to the configured {@link StateSelectionMode}.
 * Planning stops once a rollout reports a bound gap at the start state no larger than {@link #maxDiff},
 * or when the rollout budget {@link #maxRollouts} is exhausted.
 * <p>
 * The {@link DynamicProgramming} machinery is reused by swapping which bound map (and which initializer)
 * is installed as the active {@code valueFunction} / {@code valueInitializer}. After each rollout the
 * lower bound is installed by default (see {@link #setDefaultValueFunctionAfterARollout(boolean)}).
 * Custom DP operators are not supported.
 */
public class BoundedRTDP extends DynamicProgramming implements Planner {
    /** Installed as the DP {@code valueInitializer} while the lower bound is the active value function. */
    protected ValueFunction lowerVInit;
    /** Installed as the DP {@code valueInitializer} while the upper bound is the active value function. */
    protected ValueFunction upperVInit;
    /** Maximum number of rollouts per call to {@link #planFromState(State)}; -1 means no limit. */
    protected int maxRollouts;
    /**
     * Convergence threshold. Planning stops once the start-state bound gap is at most this value, and a
     * rollout ends early when the expected gap of the chosen successor falls below it.
     */
    protected double maxDiff;
    /** Lower-bound values of states backed up so far. */
    protected Map<HashableState, Double> lowerBoundV = new HashMap<HashableState, Double>();
    /** Upper-bound values of states backed up so far. */
    protected Map<HashableState, Double> upperBoundV = new HashMap<HashableState, Double>();
    /** A rollout expands at most {@code maxDepth + 1} states; -1 means no limit. */
    protected int maxDepth = -1;
    /** True while the lower bound is the active value function. */
    protected boolean currentValueFunctionIsLower = false;
    /** If true the lower bound is installed after each rollout, otherwise the upper bound. */
    protected boolean defaultToLowerValueAfterPlanning = true;
    /** How the successor state is chosen during a rollout. */
    protected StateSelectionMode selectionMode = StateSelectionMode.MODELBASED;
    /** Total number of single-bound Bellman backups performed (two per state backup). */
    protected int numBellmanUpdates = 0;
    /** Total number of forward rollout steps taken. */
    protected int numSteps = 0;
    /** If true, the states of each rollout are backed up again in reverse order after the rollout ends. */
    protected boolean runRolloutsInReverse = true;

    /**
     * Pair of a selected successor state and the expected bound gap over all possible successors.
     */
    public static class StateSelectionAndExpectedGap {
        /** The selected successor state. */
        public HashableState sh;
        /** Transition-probability-weighted bound gap over the successors (model-based mode: the gap of {@code sh}). */
        public double expectedGap;

        public StateSelectionAndExpectedGap(HashableState sh, double expectedGap) {
            this.sh = sh;
            this.expectedGap = expectedGap;
        }
    }

    /** Strategy used to pick the successor state at each rollout step. */
    public enum StateSelectionMode {
        /** Sample the successor from the model ({@code model.sample}). */
        MODELBASED,
        /** Sample a successor with probability proportional to transition probability times bound gap; requires a {@link FullModel}. */
        WEIGHTEDMARGIN,
        /** Pick the successor with the largest bound gap, ties broken uniformly at random; requires a {@link FullModel}. */
        MAXMARGIN
    }

    /**
     * Creates a Bounded RTDP planner.
     *
     * @param domain         the planning domain
     * @param gamma          discount factor
     * @param hashingFactory factory used to hash states
     * @param lowerVInit     initializer for lower-bound values of unvisited states
     * @param upperVInit     initializer for upper-bound values of unvisited states
     * @param maxDiff        convergence threshold on the start-state bound gap
     * @param maxRollouts    maximum number of rollouts per planning call; -1 for no limit
     */
    public BoundedRTDP(SADomain domain, double gamma, HashableStateFactory hashingFactory, ValueFunction lowerVInit, ValueFunction upperVInit, double maxDiff, int maxRollouts) {
        DPPInit(domain, gamma, hashingFactory);
        this.lowerVInit = lowerVInit;
        this.upperVInit = upperVInit;
        this.maxDiff = maxDiff;
        this.maxRollouts = maxRollouts;
    }

    /**
     * Not supported; Bounded RTDP always uses its built-in max-Q backup.
     *
     * @throws RuntimeException always
     */
    @Override // burlap.behavior.singleagent.planning.stochastic.DynamicProgramming
    public void setOperator(DPOperator dPOperator) {
        throw new RuntimeException("Bounded RTDP does not currently support custom operators.");
    }

    /** @param maxRollouts maximum number of rollouts per planning call; -1 for no limit */
    public void setMaxNumberOfRollouts(int maxRollouts) {
        this.maxRollouts = maxRollouts;
    }

    /** @param maxDepth a rollout expands at most {@code maxDepth + 1} states; -1 for no limit */
    public void setMaxRolloutDepth(int maxDepth) {
        this.maxDepth = maxDepth;
    }

    /** @param maxDiff convergence threshold on the start-state bound gap */
    public void setMaxDifference(double maxDiff) {
        this.maxDiff = maxDiff;
    }

    /** @param mode strategy used to choose successor states during rollouts */
    public void setStateSelectionMode(StateSelectionMode mode) {
        this.selectionMode = mode;
    }

    /** @param useLower if true the lower bound is installed after each rollout, otherwise the upper bound */
    public void setDefaultValueFunctionAfterARollout(boolean useLower) {
        this.defaultToLowerValueAfterPlanning = useLower;
    }

    /** @param runInReverse whether visited states are backed up again in reverse order after each rollout */
    public void setRunRolloutsInReverse(boolean runInReverse) {
        this.runRolloutsInReverse = runInReverse;
    }

    /**
     * Misspelled alias of {@link #setRunRolloutsInReverse(boolean)}, kept for backward compatibility.
     *
     * @deprecated use {@link #setRunRolloutsInReverse(boolean)}
     */
    @Deprecated
    public void setRunRolloutsInRevere(boolean runInReverse) {
        setRunRolloutsInReverse(runInReverse);
    }

    /**
     * Plans from {@code state} by running rollouts until the start-state bound gap is at most
     * {@link #maxDiff} or {@link #maxRollouts} rollouts have been run. The budget is checked before each
     * rollout, so at most {@code maxRollouts} rollouts run (none when it is 0).
     *
     * @param state the start state
     * @return a greedy policy over this planner's active bound
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public GreedyQPolicy planFromState(State state) {
        DPrint.cl(this.debugCode, "Beginning Planning.");
        int numRollouts = 0;
        boolean converged = false;
        while (!converged && (this.maxRollouts == -1 || numRollouts < this.maxRollouts)) {
            // Written as !(gap > maxDiff) so a NaN gap stops planning, as the original condition did.
            converged = !(runRollout(state) > this.maxDiff);
            numRollouts++;
        }
        // runRollout already installs the default bound; repeat it so the returned policy is also backed
        // by a bound when no rollout was run.
        installDefaultBound();
        DPrint.cl(this.debugCode, "Finished planning with a total of " + this.numBellmanUpdates + " backups.");
        return new GreedyQPolicy(this);
    }

    /** Installs the upper bound (and its initializer) as the active value function. */
    public void setValueFunctionToUpperBound() {
        this.valueFunction = this.upperBoundV;
        this.valueInitializer = this.upperVInit;
        this.currentValueFunctionIsLower = false;
    }

    /** Installs the lower bound (and its initializer) as the active value function. */
    public void setValueFunctionToLowerBound() {
        this.valueFunction = this.lowerBoundV;
        this.valueInitializer = this.lowerVInit;
        this.currentValueFunctionIsLower = true;
    }

    /** @return total number of single-bound Bellman backups performed so far */
    public int getNumberOfBellmanUpdates() {
        return this.numBellmanUpdates;
    }

    /** @return total number of forward rollout steps taken so far */
    public int getNumberOfSteps() {
        return this.numSteps;
    }

    /**
     * Runs one rollout from {@code state}. Each expanded state gets both bounds backed up. The rollout
     * then follows the upper-bound greedy action to a successor chosen by {@link #getNextState}.
     * <p>
     * The rollout ends on one of three conditions:
     * <ul>
     * <li>a terminal state is reached, whose bounds are then pinned to 0;</li>
     * <li>{@link #maxDepth} is hit;</li>
     * <li>the chosen successor's expected gap drops below {@link #maxDiff}.</li>
     * </ul>
     * If reverse rollouts are enabled, the expanded states are then backed up again from last to first.
     * Afterwards the default bound is installed.
     *
     * @param state the start state
     * @return the bound gap (upper minus lower) at the start state after the rollout
     */
    public double runRollout(State state) {
        // Expanded states, most recent first; only recorded when a reverse backup pass will follow.
        LinkedList<HashableState> trajectory = new LinkedList<HashableState>();
        HashableState current = this.hashingFactory.hashState(state);
        // Counted independently of the trajectory so maxDepth is honored even when nothing is recorded.
        int numExpanded = 0;
        while (!this.model.terminal(current.s()) && (this.maxDepth == -1 || numExpanded < this.maxDepth + 1)) {
            if (this.runRolloutsInReverse) {
                trajectory.offerFirst(current);
            }
            numExpanded++;
            QValue upperGreedy = backupBounds(current);
            this.numSteps++;
            StateSelectionAndExpectedGap next = getNextState(current.s(), upperGreedy.a);
            current = next.sh;
            if (next.expectedGap < this.maxDiff) {
                break;
            }
        }
        if (this.model.terminal(current.s())) {
            this.lowerBoundV.put(current, 0.0d);
            this.upperBoundV.put(current, 0.0d);
        }
        double gap = 0.0d;
        if (this.runRolloutsInReverse) {
            // Most recent state first; the start state is backed up last, so its gap is what remains in 'gap'.
            for (HashableState sh : trajectory) {
                QValue upper = backupBounds(sh);
                gap = upper.q - this.lowerBoundV.get(sh);
            }
        } else {
            gap = getGap(this.hashingFactory.hashState(state));
        }
        installDefaultBound();
        return gap;
    }

    /**
     * Performs a Bellman (max-Q) backup of both bounds at {@code sh}, lower first, and counts two updates.
     * Leaves the upper bound installed.
     *
     * @return the greedy Q-value under the upper bound
     */
    private QValue backupBounds(HashableState sh) {
        setValueFunctionToLowerBound();
        this.lowerBoundV.put(sh, maxQ(sh.s()).q);
        setValueFunctionToUpperBound();
        QValue upper = maxQ(sh.s());
        this.upperBoundV.put(sh, upper.q);
        this.numBellmanUpdates += 2;
        return upper;
    }

    /** Installs whichever bound {@link #defaultToLowerValueAfterPlanning} selects. */
    private void installDefaultBound() {
        if (this.defaultToLowerValueAfterPlanning) {
            setValueFunctionToLowerBound();
        } else {
            setValueFunctionToUpperBound();
        }
    }

    /**
     * Chooses the successor of taking {@code action} in {@code state} according to {@link #selectionMode}.
     *
     * @throws RuntimeException if the selection mode is not recognized
     */
    protected StateSelectionAndExpectedGap getNextState(State state, Action action) {
        if (this.selectionMode == StateSelectionMode.MODELBASED) {
            HashableState hashState = this.hashingFactory.hashState(this.model.sample(state, action).op);
            return new StateSelectionAndExpectedGap(hashState, getGap(hashState));
        }
        if (this.selectionMode == StateSelectionMode.WEIGHTEDMARGIN) {
            return getNextStateBySampling(state, action);
        }
        if (this.selectionMode == StateSelectionMode.MAXMARGIN) {
            return getNextStateByMaxMargin(state, action);
        }
        throw new RuntimeException("Unknown state selection mode.");
    }

    /**
     * Picks the successor with the largest bound gap, breaking ties uniformly at random. The model is
     * cast to {@link FullModel}.
     *
     * @return the selected successor and the probability-weighted gap over all successors
     */
    protected StateSelectionAndExpectedGap getNextStateByMaxMargin(State state, Action action) {
        List<TransitionProb> transitions = ((FullModel) this.model).transitions(state, action);
        double expectedGap = 0.0d;
        double maxGap = Double.NEGATIVE_INFINITY;
        List<HashableState> candidates = new ArrayList<HashableState>(transitions.size());
        for (TransitionProb tp : transitions) {
            HashableState successor = this.hashingFactory.hashState(tp.eo.op);
            double gap = getGap(successor);
            expectedGap += tp.p * gap;
            if (gap == maxGap) {
                candidates.add(successor);
            } else if (gap > maxGap) {
                candidates.clear();
                candidates.add(successor);
                maxGap = gap;
            }
        }
        HashableState selected = candidates.get(RandomFactory.getMapped(0).nextInt(candidates.size()));
        return new StateSelectionAndExpectedGap(selected, expectedGap);
    }

    /**
     * Samples a successor with probability proportional to transition probability times bound gap. The
     * model is cast to {@link FullModel}.
     * <p>
     * If the weighted gaps do not sum to a positive value (for example, every successor has converged
     * bounds), the successor is sampled by transition probability alone instead of dividing by zero.
     *
     * @return the selected successor and the probability-weighted gap over all successors
     * @throws RuntimeException if no successor has positive weight
     */
    protected StateSelectionAndExpectedGap getNextStateBySampling(State state, Action action) {
        List<TransitionProb> transitions = ((FullModel) this.model).transitions(state, action);
        int numOutcomes = transitions.size();
        double[] weightedGaps = new double[numOutcomes];
        double[] probs = new double[numOutcomes];
        HashableState[] successors = new HashableState[numOutcomes];
        double expectedGap = 0.0d;
        for (int i = 0; i < numOutcomes; i++) {
            TransitionProb tp = transitions.get(i);
            successors[i] = this.hashingFactory.hashState(tp.eo.op);
            probs[i] = tp.p;
            weightedGaps[i] = tp.p * getGap(successors[i]);
            expectedGap += weightedGaps[i];
        }
        double[] weights = expectedGap > 0.0d ? weightedGaps : probs;
        int selected = sampleIndex(weights);
        if (selected < 0) {
            throw new RuntimeException("Error: probabilities in state selection did not sum to 1.");
        }
        return new StateSelectionAndExpectedGap(successors[selected], expectedGap);
    }

    /**
     * Samples an index with probability proportional to its weight; non-positive weights are never chosen.
     *
     * @return the sampled index, or -1 if no weight is positive
     */
    private static int sampleIndex(double[] weights) {
        double total = 0.0d;
        int lastPositive = -1;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] > 0.0d) {
                total += weights[i];
                lastPositive = i;
            }
        }
        if (lastPositive < 0) {
            return -1;
        }
        double roll = RandomFactory.getMapped(0).nextDouble() * total;
        double cumulative = 0.0d;
        for (int i = 0; i < weights.length; i++) {
            if (weights[i] > 0.0d) {
                cumulative += weights[i];
                if (roll < cumulative) {
                    return i;
                }
            }
        }
        // Floating-point round-off can leave the cumulative sum marginally below the roll.
        return lastPositive;
    }

    /**
     * Returns the upper-bound value minus the lower-bound value of {@code hashableState}. As a side effect,
     * the upper bound is left installed as the active value function.
     */
    protected double getGap(HashableState hashableState) {
        setValueFunctionToLowerBound();
        double lower = value(hashableState);
        setValueFunctionToUpperBound();
        return value(hashableState) - lower;
    }

    /**
     * Returns a maximal Q-value of {@code state} under the active value function, breaking ties uniformly
     * at random.
     */
    protected QValue maxQ(State state) {
        List<QValue> qValues = qValues(state);
        double best = Double.NEGATIVE_INFINITY;
        List<QValue> ties = new ArrayList<QValue>(qValues.size());
        for (QValue qValue : qValues) {
            if (qValue.q == best) {
                ties.add(qValue);
            } else if (qValue.q > best) {
                best = qValue.q;
                ties.clear();
                ties.add(qValue);
            }
        }
        return ties.get(RandomFactory.getMapped(0).nextInt(ties.size()));
    }
}
