package burlap.behavior.singleagent.planning.stochastic.sparsesampling;

import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.options.EnvironmentOptionOutcome;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.singleagent.planning.stochastic.dpoperator.BellmanOperator;
import burlap.behavior.singleagent.planning.stochastic.dpoperator.DPOperator;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.debugtools.DPrint;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/planning/stochastic/sparsesampling/SparseSampling.class */
public class SparseSampling extends MDPSolver implements QProvider, Planner {
    /** Height given to the root node when planning starts from a state. */
    protected int h;
    /** Number of transition samples drawn per action; a negative value switches the planner to exact backups. */
    protected int c;
    /** When true, Q-values are computed from the full transition distribution instead of from samples. */
    protected boolean computeExactValueFunction;
    /** Cache of tree nodes keyed by (hashed state, height). */
    protected Map<HashedHeightState, StateNode> nodesByHeight;
    /** Q-values computed for each state that planning has been started from. */
    protected Map<HashableState, List<QValue>> rootLevelQValues;
    /** When true, the per-node sample count is derived from c, gamma and the node height instead of being c. */
    protected boolean useVariableC = false;
    /** When true, cached results are cleared around each call to planFromState. */
    protected boolean forgetPreviousPlanResults = false;
    /** Value function queried for nodes whose height is zero or less. */
    protected ValueFunction vinit = new ConstantValueFunction();
    /** Running count of node value backups performed. */
    protected int numUpdates = 0;
    /** Operator that reduces an array of Q-values to a single state value. */
    protected DPOperator operator = new BellmanOperator();

    /* loaded from: input_file:burlap/behavior/singleagent/planning/stochastic/sparsesampling/SparseSampling$HashedHeightState.class */
    /**
     * Map key that pairs a hashed state with a height, so that the same state can be
     * stored separately for every height at which it is reached.
     */
    public static class HashedHeightState {
        /** Hashed state part of the key. */
        public HashableState sh;
        /** Height part of the key. */
        public int height;

        /**
         * Creates a key for the given state/height pair.
         *
         * @param hashableState the hashed state
         * @param i the height associated with the state
         */
        public HashedHeightState(HashableState hashableState, int i) {
            sh = hashableState;
            height = i;
        }

        /**
         * Two keys are equal when they are of the same class, have the same height
         * and their hashed states are equal.
         */
        @Override
        public boolean equals(Object obj) {
            if (obj != null && obj.getClass() == getClass()) {
                HashedHeightState other = (HashedHeightState) obj;
                if (height != other.height) {
                    return false;
                }
                return sh.equals(other.sh);
            }
            return false;
        }

        /** Combines the height and the hashed state's hash code. */
        @Override
        public int hashCode() {
            return 31 * height + sh.hashCode();
        }
    }

    /* loaded from: input_file:burlap/behavior/singleagent/planning/stochastic/sparsesampling/SparseSampling$StateNode.class */
    /**
     * A node of the sparse sampling tree: a hashed state at a given height together with
     * its (lazily computed and then cached) value estimate.
     */
    public class StateNode {
        /** Hashed state represented by this node. */
        HashableState sh;
        /** Height of this node; nodes with height zero or less are treated as leaves. */
        int height;
        /** Cached value estimate; only meaningful once {@link #closed} is true. */
        double v;
        /** True once {@link #v} has been computed, so later calls return the cached value. */
        boolean closed = false;

        /**
         * Creates an unevaluated node.
         *
         * @param hashableState the hashed state of the node
         * @param i the height of the node
         */
        public StateNode(HashableState hashableState, int i) {
            this.sh = hashableState;
            this.height = i;
        }

        /**
         * Estimates a Q-value for every action applicable in this node's state.
         * Leaf nodes (height &lt;= 0) use the leaf value function for every action;
         * other nodes use either sampled or exact backups depending on the planner mode.
         *
         * @return one Q-value per applicable action
         */
        public List<QValue> estimateQs() {
            State s = this.sh.s();
            List<Action> applicableActions = SparseSampling.this.applicableActions(s);
            List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
            for (Action action : applicableActions) {
                double q;
                if (this.height <= 0) {
                    q = SparseSampling.this.vinit.value(s);
                } else if (SparseSampling.this.computeExactValueFunction) {
                    q = exactQValue(action);
                } else {
                    q = sampledQEstimate(action);
                }
                qs.add(new QValue(s, action, q));
            }
            return qs;
        }

        /**
         * Estimates the Q-value of an action by averaging sampled one-step returns.
         *
         * @param action the action to evaluate
         * @return the mean of reward plus discounted successor value over the samples
         */
        protected double sampledQEstimate(Action action) {
            double sum = 0.0d;
            int numSamples = SparseSampling.this.getCAtHeight(this.height);
            boolean isOption = action instanceof Option;
            for (int i = 0; i < numSamples; i++) {
                EnvironmentOutcome sample = SparseSampling.this.model.sample(this.sh.s(), action);
                int steps = 1;
                if (isOption) {
                    // The step count belongs to the sampled outcome, not to the action itself;
                    // casting the action would fail with a ClassCastException.
                    steps = ((EnvironmentOptionOutcome) sample).numSteps();
                }
                StateNode next = SparseSampling.this.getStateNode(sample.op, this.height - steps);
                sum += sample.r + (Math.pow(SparseSampling.this.gamma, steps) * next.estimateV());
            }
            return sum / numSamples;
        }

        /**
         * Computes the Q-value of an action from the model's full transition distribution.
         *
         * @param action the action to evaluate; must not be an {@link Option}
         * @return the expectation of reward plus discounted successor value
         * @throws RuntimeException if the action is an {@link Option}
         */
        protected double exactQValue(Action action) {
            double sum = 0.0d;
            List<TransitionProb> transitions = ((FullModel) SparseSampling.this.model).transitions(this.sh.s(), action);
            if (action instanceof Option) {
                throw new RuntimeException("Sparse Sampling Planner with Full Bellman updates turned on cannot work with options because it needs factored access to the depth for each option transition. Use the standard sampling mode instead.");
            }
            for (TransitionProb tp : transitions) {
                StateNode next = SparseSampling.this.getStateNode(tp.eo.op, this.height - 1);
                sum += tp.p * (tp.eo.r + (SparseSampling.this.gamma * next.estimateV()));
            }
            return sum;
        }

        /**
         * Returns this node's value, computing and caching it on the first call.
         * Terminal states have value zero; otherwise the planner's operator is applied
         * to the node's Q-value estimates and the update counter is incremented.
         *
         * @return the value estimate of this node
         */
        public double estimateV() {
            if (this.closed) {
                return this.v;
            }
            if (SparseSampling.this.model.terminal(this.sh.s())) {
                this.v = 0.0d;
                this.closed = true;
                return this.v;
            }
            List<QValue> qs = estimateQs();
            double[] qArray = new double[qs.size()];
            for (int i = 0; i < qs.size(); i++) {
                qArray[i] = qs.get(i).q;
            }
            SparseSampling.this.numUpdates++;
            this.v = SparseSampling.this.operator.apply(qArray);
            this.closed = true;
            return this.v;
        }
    }

    /**
     * Creates a planner for the given domain.
     *
     * @param sADomain the domain to plan in
     * @param d the discount factor passed to solverInit
     * @param hashableStateFactory factory used to hash states
     * @param i the height of the root node (stored in h)
     * @param i2 the number of samples per action (stored in c); a negative value enables exact backups
     */
    public SparseSampling(SADomain sADomain, double d, HashableStateFactory hashableStateFactory, int i, int i2) {
        this.computeExactValueFunction = false;
        solverInit(sADomain, d, hashableStateFactory);
        this.nodesByHeight = new HashMap<HashedHeightState, StateNode>();
        this.rootLevelQValues = new HashMap<HashableState, List<QValue>>();
        this.h = i;
        this.c = i2;
        if (i2 < 0) {
            // a negative sample count means "use the full transition distribution"
            this.computeExactValueFunction = true;
        }
        this.debugCode = 7369430;
    }

    /**
     * Sets h and c from a reward bound, a desired value-function error and the number of
     * actions, and prints the resulting values on the debug channel.
     * <p>
     * With lambda = d2 * (1 - gamma)^2 / 4 and vmax = d / (1 - gamma), this sets
     * h = (int) log_gamma(lambda / vmax) + 1 and
     * c = (vmax^2 / lambda^2) * (2 * h * ln(i * h * vmax^2 / lambda^2) + ln(d / lambda)).
     *
     * @param d the maximum reward magnitude (Rmax)
     * @param d2 the desired error bound (epsilon)
     * @param i the number of actions
     */
    public void setHAndCByMDPError(double d, double d2, int i) {
        double lambda = ((d2 * (1.0d - this.gamma)) * (1.0d - this.gamma)) / 4.0d;
        double vmax = d / (1.0d - this.gamma);
        this.h = ((int) logbase(this.gamma, lambda / vmax)) + 1;

        double vmaxOverLambdaSq = (vmax * vmax) / (lambda * lambda);
        // The two logarithms are separate terms of a sum; only the first is scaled by 2 * h.
        double widthTerm = 2 * this.h * Math.log((i * this.h * vmax * vmax) / (lambda * lambda));
        double rewardTerm = Math.log(d / lambda);
        this.c = (int) (vmaxOverLambdaSq * (widthTerm + rewardTerm));

        DPrint.cl(this.debugCode, "H = " + this.h);
        DPrint.cl(this.debugCode, "C = " + this.c);
    }

    /**
     * Chooses how the per-node sample count is obtained.
     *
     * @param z if true, getCAtHeight derives the count from c, gamma and the node height;
     *          if false, it returns c unchanged
     */
    public void setUseVariableCSize(boolean z) {
        useVariableC = z;
    }

    /**
     * Sets the number of samples per action and updates the exact-backup flag to match.
     *
     * @param i the new sample count; a negative value turns exact backups on,
     *          any other value turns them off
     */
    public void setC(int i) {
        c = i;
        computeExactValueFunction = (i < 0);
    }

    /**
     * Sets the height used for the root node of subsequent planning calls.
     *
     * @param i the new height
     */
    public void setH(int i) {
        h = i;
    }

    /**
     * @return the configured number of samples per action
     */
    public int getC() {
        return c;
    }

    /**
     * @return the current value of the height field
     */
    public int getH() {
        return h;
    }

    /**
     * Switches between exact and sampled Q-value computation without touching c.
     *
     * @param z true to use the full transition distribution, false to use samples
     */
    public void setComputeExactValueFunction(boolean z) {
        computeExactValueFunction = z;
    }

    /**
     * @return true if Q-values are computed from the full transition distribution
     */
    public boolean computesExactValueFunction() {
        return computeExactValueFunction;
    }

    /**
     * Sets whether cached results are discarded around each planning call.
     * Enabling the flag also clears the node cache immediately.
     *
     * @param z true to forget previous planning results
     */
    public void setForgetPreviousPlanResults(boolean z) {
        forgetPreviousPlanResults = z;
        if (z) {
            nodesByHeight.clear();
        }
    }

    /**
     * Sets the value function queried for nodes whose height is zero or less.
     *
     * @param valueFunction the leaf value function
     */
    public void setValueForLeafNodes(ValueFunction valueFunction) {
        vinit = valueFunction;
    }

    /**
     * @return the operator applied to Q-values to obtain a state value
     */
    public DPOperator getOperator() {
        return operator;
    }

    /**
     * Sets the operator applied to Q-values to obtain a state value.
     *
     * @param dPOperator the operator to use
     */
    public void setOperator(DPOperator dPOperator) {
        operator = dPOperator;
    }

    /**
     * @return the code passed to DPrint for this planner's debug messages
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public int getDebugCode() {
        return debugCode;
    }

    /**
     * Sets the code passed to DPrint for this planner's debug messages.
     *
     * @param i the new debug code
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void setDebugCode(int i) {
        debugCode = i;
    }

    /**
     * @return the running count of node value backups (the numUpdates counter)
     */
    public int getNumberOfValueEsitmates() {
        return numUpdates;
    }

    /**
     * @return the number of cached tree nodes plus the number of root states with stored Q-values
     */
    public int getNumberOfStateNodesCreated() {
        final int cachedNodes = nodesByHeight.size();
        final int rootEntries = rootLevelQValues.size();
        return cachedNodes + rootEntries;
    }

    /**
     * Plans from the given state, storing its Q-values, unless they are already stored.
     * When previous results are to be forgotten, the stored root Q-values are cleared
     * before planning and the node cache is cleared afterwards.
     *
     * @param state the state to plan from
     * @return a greedy policy backed by this planner
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public GreedyQPolicy planFromState(State state) {
        if (forgetPreviousPlanResults) {
            rootLevelQValues.clear();
        }
        final HashableState rootKey = hashingFactory.hashState(state);
        if (!rootLevelQValues.containsKey(rootKey)) {
            DPrint.cl(debugCode, "Beginning Planning.");
            final int updatesBefore = numUpdates;
            final StateNode root = getStateNode(state, h);
            rootLevelQValues.put(rootKey, root.estimateQs());
            DPrint.cl(debugCode, "Finished Planning with " + (numUpdates - updatesBefore) + " value esitmates; for a cumulative total of: " + numUpdates);
            if (forgetPreviousPlanResults) {
                nodesByHeight.clear();
            }
        }
        return new GreedyQPolicy(this);
    }

    /**
     * Discards all cached nodes and root Q-values and resets the update counter to zero.
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void resetSolver() {
        numUpdates = 0;
        rootLevelQValues.clear();
        nodesByHeight.clear();
    }

    /**
     * Returns the Q-values stored for a state, planning from it first if none are stored.
     *
     * @param state the state to query
     * @return the Q-values stored for the state after any required planning
     */
    @Override // burlap.behavior.valuefunction.QProvider
    public List<QValue> qValues(State state) {
        final HashableState key = hashingFactory.hashState(state);
        final List<QValue> stored = rootLevelQValues.get(key);
        if (stored != null) {
            return stored;
        }
        planFromState(state);
        return rootLevelQValues.get(key);
    }

    /**
     * Returns the Q-value of one action in a state, planning from the state first if needed.
     *
     * @param state the state to query
     * @param action the action whose Q-value is wanted
     * @return the stored Q-value of the action
     * @throws RuntimeException if no stored Q-value matches the action
     */
    @Override // burlap.behavior.valuefunction.QFunction
    public double qValue(State state, Action action) {
        final HashableState key = hashingFactory.hashState(state);
        List<QValue> candidates = rootLevelQValues.get(key);
        if (candidates == null) {
            planFromState(state);
            candidates = rootLevelQValues.get(key);
        }
        for (QValue candidate : candidates) {
            if (!candidate.a.equals(action)) {
                continue;
            }
            return candidate.q;
        }
        throw new RuntimeException("Q-value for action " + action.toString() + " does not exist.");
    }

    /**
     * Returns the value of a state: zero for terminal states, otherwise the planner's
     * operator applied to the state's Q-values.
     *
     * @param state the state to evaluate
     * @return the state value
     */
    @Override // burlap.behavior.valuefunction.ValueFunction
    public double value(State state) {
        if (model.terminal(state)) {
            return 0.0d;
        }
        final List<QValue> qs = qValues(state);
        final double[] qArray = new double[qs.size()];
        int idx = 0;
        for (QValue qv : qs) {
            qArray[idx] = qv.q;
            idx++;
        }
        return operator.apply(qArray);
    }

    /**
     * Returns the number of samples to draw per action for a node at the given height.
     * With a fixed sample size this is c. With a variable sample size it is
     * c * gamma^(2 * depth), where depth = h - height, and never zero.
     *
     * @param i the height of the node being expanded
     * @return the number of samples to draw
     */
    protected int getCAtHeight(int i) {
        if (!this.useVariableC) {
            return this.c;
        }
        // Depth is measured from the root, which sits at height h. The field h is only
        // read here: it is the planning horizon and must stay unchanged between nodes.
        int depth = this.h - i;
        int scaled = (int) (this.c * Math.pow(this.gamma, 2 * depth));
        if (scaled == 0) {
            // always draw at least one sample
            scaled = 1;
        }
        return scaled;
    }

    /**
     * Returns the cached node for a state at a height, creating and caching it if absent.
     *
     * @param state the state of the node
     * @param i the height of the node
     * @return the node for the (state, height) pair
     */
    protected StateNode getStateNode(State state, int i) {
        final HashableState hashed = hashingFactory.hashState(state);
        final HashedHeightState key = new HashedHeightState(hashed, i);
        final StateNode existing = nodesByHeight.get(key);
        if (existing != null) {
            return existing;
        }
        final StateNode created = new StateNode(hashed, i);
        nodesByHeight.put(key, created);
        return created;
    }

    /**
     * Computes the logarithm of a value in an arbitrary base via the change-of-base rule.
     *
     * @param d the base
     * @param d2 the value whose logarithm is taken
     * @return ln(d2) / ln(d)
     */
    protected static double logbase(double d, double d2) {
        final double lnValue = Math.log(d2);
        final double lnBase = Math.log(d);
        return lnValue / lnBase;
    }
}
