package burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners;

import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.policy.BoltzmannQPolicy;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learnfromdemo.CustomRewardModel;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.diffvinit.DifferentiableVInit;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.diffvinit.VanillaDiffVinit;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.dpoperator.DifferentiableDPOperator;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.dpoperator.DifferentiableSoftmaxOperator;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableQFunction;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableValueFunction;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.QGradientTuple;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.singleagent.planning.stochastic.sparsesampling.SparseSampling;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.behavior.valuefunction.ValueFunction;
import burlap.debugtools.DPrint;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.SampleModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/differentiableplanners/DifferentiableSparseSampling.class */
public class DifferentiableSparseSampling extends MDPSolver implements DifferentiableQFunction, QProvider, Planner {
    /** Height assigned to the root node when planning from a state. */
    protected int h;
    /** Base number of transition samples per action; a negative value selects the full-model (exact) expansion in estimateQs. */
    protected int c;
    /** Differentiable value function queried for Q-values and gradients at leaf nodes. */
    protected DifferentiableValueFunction vinit;
    /** Differentiable reward function whose gradient is propagated through the tree. */
    protected DifferentiableRF rf;
    /** Tree nodes indexed by (hashed state, height) so nodes can be reused while planning. */
    protected Map<SparseSampling.HashedHeightState, DiffStateNode> nodesByHeight;
    /** Q-values and Q-gradients already computed for root states, keyed by hashed state. */
    protected Map<HashableState, QAndQGradient> rootLevelQValues;
    /** Parameter given to the default softmax operator; its reciprocal is passed to the Boltzmann policy returned by planFromState. */
    protected double boltzBeta;
    /** Number of reward function parameters, read from rf.numParameters() in the constructor. */
    protected int rfDim;
    /** Operator that folds a node's Q-values (and Q-gradients) into its value (and value gradient). */
    protected DifferentiableDPOperator operator;
    /** When true, getCAtHeight scales c instead of returning it unchanged. */
    protected boolean useVariableC = false;
    /** When true, cached roots are cleared before each plan and the node index is cleared after it. */
    protected boolean forgetPreviousPlanResults = false;
    /** Cumulative count of non-terminal node value estimates performed. */
    protected int numUpdates = 0;

    /* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/differentiableplanners/DifferentiableSparseSampling$DiffStateNode.class */
    /**
     * A node of the sparse sampling tree for one (state, height) pair. It computes
     * Q-values, the state value, and their gradients with respect to the reward
     * function parameters, and memoizes the state value once computed.
     */
    public class DiffStateNode {
        /** Hashed state this node represents. */
        HashableState sh;
        /** Height of this node; nodes at height 0 are evaluated with the leaf value function. */
        int height;
        /** Memoized value estimate; only meaningful once {@link #closed} is true. */
        double v;
        /** Memoized value gradient; only meaningful once {@link #closed} is true. */
        FunctionGradient vgrad;
        /** True once {@link #v} and {@link #vgrad} have been computed. */
        boolean closed = false;

        /**
         * Creates an unevaluated node.
         *
         * @param hashableState the hashed state of the node
         * @param i the height of the node
         */
        public DiffStateNode(HashableState hashableState, int i) {
            this.sh = hashableState;
            this.height = i;
        }

        /**
         * Estimates the Q-value and Q-gradient of every applicable action in this node's state.
         * At height 0, or when the sample count for this height is 0, each action receives the
         * leaf value and a {@code null} gradient entry (a sentinel that {@link #setVGrad} resolves
         * with the leaf value gradient). A positive sample count uses sampling; a negative one
         * uses the full transition model.
         *
         * @return the Q-values and Q-gradients, one entry per applicable action
         */
        public QAndQGradient estimateQs() {
            List<Action> applicableActions = DifferentiableSparseSampling.this.applicableActions(this.sh.s());
            QAndQGradient qAndQGradient = new QAndQGradient(applicableActions.size());
            int cAtHeight = DifferentiableSparseSampling.this.getCAtHeight(this.height);
            for (Action action : applicableActions) {
                if (this.height == 0 || cAtHeight == 0) {
                    qAndQGradient.add(new QValue(this.sh.s(), action, DifferentiableSparseSampling.this.vinit.value(this.sh.s())), null);
                } else if (cAtHeight > 0) {
                    sampledQEstimate(action, qAndQGradient);
                } else {
                    exactQEstimate(action, qAndQGradient);
                }
            }
            return qAndQGradient;
        }

        /**
         * Estimates the Q-value and Q-gradient of one action by averaging over sampled
         * transitions and appends the result to the supplied container.
         *
         * @param action the action to evaluate
         * @param qAndQGradient the container that receives the estimate
         */
        public void sampledQEstimate(Action action, QAndQGradient qAndQGradient) {
            // Use the per-height sample count so that variable-C mode actually changes the
            // number of samples drawn; with variable C disabled this is simply c.
            int numSamples = DifferentiableSparseSampling.this.getCAtHeight(this.height);
            FunctionGradient.SparseGradient gradientSum = new FunctionGradient.SparseGradient();
            double returnSum = 0.0d;
            for (int i = 0; i < numSamples; i++) {
                EnvironmentOutcome sample = DifferentiableSparseSampling.this.model.sample(this.sh.s(), action);
                State nextState = sample.op;
                double reward = sample.r;
                FunctionGradient rewardGradient = DifferentiableSparseSampling.this.rf.gradient(this.sh.s(), action, nextState);
                VAndVGradient next = DifferentiableSparseSampling.this.getStateNode(nextState, this.height - 1).estimateV();
                Set<Integer> parameterIds = DifferentiableSparseSampling.this.combinedNonZeroPDParameters(next.vGrad, rewardGradient);
                returnSum += reward + (DifferentiableSparseSampling.this.gamma * next.v);
                for (Integer parameterId : parameterIds) {
                    int id = parameterId.intValue();
                    gradientSum.put(id, gradientSum.getPartialDerivative(id) + rewardGradient.getPartialDerivative(id) + (DifferentiableSparseSampling.this.gamma * next.vGrad.getPartialDerivative(id)));
                }
            }
            // Turn the accumulated sums into sample means.
            double q = returnSum / numSamples;
            for (FunctionGradient.PartialDerivative partialDerivative : gradientSum.getNonZeroPartialDerivatives()) {
                gradientSum.put(partialDerivative.parameterId, partialDerivative.value / numSamples);
            }
            qAndQGradient.add(new QValue(this.sh.s(), action, q), new QGradientTuple(this.sh.s(), action, gradientSum));
        }

        /**
         * Computes the Q-value and Q-gradient of one action as a probability-weighted sum over
         * every transition reported by the model and appends the result to the supplied
         * container. The solver's model is cast to {@link FullModel}.
         *
         * @param action the action to evaluate
         * @param qAndQGradient the container that receives the estimate
         */
        public void exactQEstimate(Action action, QAndQGradient qAndQGradient) {
            FunctionGradient.SparseGradient gradientSum = new FunctionGradient.SparseGradient();
            double q = 0.0d;
            for (TransitionProb transitionProb : ((FullModel) DifferentiableSparseSampling.this.model).transitions(this.sh.s(), action)) {
                State nextState = transitionProb.eo.op;
                double reward = transitionProb.eo.r;
                FunctionGradient rewardGradient = DifferentiableSparseSampling.this.rf.gradient(this.sh.s(), action, nextState);
                VAndVGradient next = DifferentiableSparseSampling.this.getStateNode(nextState, this.height - 1).estimateV();
                Set<Integer> parameterIds = DifferentiableSparseSampling.this.combinedNonZeroPDParameters(next.vGrad, rewardGradient);
                q += transitionProb.p * (reward + (DifferentiableSparseSampling.this.gamma * next.v));
                for (Integer parameterId : parameterIds) {
                    int id = parameterId.intValue();
                    gradientSum.put(id, gradientSum.getPartialDerivative(id) + (transitionProb.p * (rewardGradient.getPartialDerivative(id) + (DifferentiableSparseSampling.this.gamma * next.vGrad.getPartialDerivative(id)))));
                }
            }
            qAndQGradient.add(new QValue(this.sh.s(), action, q), new QGradientTuple(this.sh.s(), action, gradientSum));
        }

        /**
         * Returns this node's value and value gradient, computing and memoizing them on first
         * use. Terminal states get value 0 and an empty gradient; they do not count toward the
         * solver's update counter.
         *
         * @return the value and value gradient of this node
         */
        public VAndVGradient estimateV() {
            if (this.closed) {
                return new VAndVGradient(this.v, this.vgrad);
            }
            if (DifferentiableSparseSampling.this.model.terminal(this.sh.s())) {
                this.v = 0.0d;
                this.vgrad = new FunctionGradient.SparseGradient();
                this.closed = true;
                return new VAndVGradient(this.v, this.vgrad);
            }
            QAndQGradient estimateQs = estimateQs();
            setV(estimateQs);
            setVGrad(estimateQs);
            this.closed = true;
            DifferentiableSparseSampling.this.numUpdates++;
            return new VAndVGradient(this.v, this.vgrad);
        }

        /**
         * Sets this node's value by applying the solver's operator to the given Q-values.
         *
         * @param qAndQGradient the Q estimates for this node
         */
        protected void setV(QAndQGradient qAndQGradient) {
            double[] qValues = new double[qAndQGradient.qs.size()];
            for (int i = 0; i < qAndQGradient.qs.size(); i++) {
                qValues[i] = qAndQGradient.qs.get(i).q;
            }
            this.v = DifferentiableSparseSampling.this.operator.apply(qValues);
        }

        /**
         * Sets this node's value gradient. When the first gradient entry is {@code null}
         * (the leaf sentinel written by {@link #estimateQs}), the leaf value function's
         * gradient is used; otherwise the solver's operator combines the Q-values and
         * Q-gradients.
         *
         * @param qAndQGradient the Q estimates for this node
         */
        protected void setVGrad(QAndQGradient qAndQGradient) {
            if (qAndQGradient.qGrads.get(0) == null) {
                this.vgrad = DifferentiableSparseSampling.this.vinit.valueGradient(this.sh.s());
                return;
            }
            double[] qValues = new double[qAndQGradient.qs.size()];
            for (int i = 0; i < qValues.length; i++) {
                qValues[i] = qAndQGradient.qs.get(i).q;
            }
            FunctionGradient[] qGradients = new FunctionGradient[qValues.length];
            for (int i = 0; i < qValues.length; i++) {
                qGradients[i] = qAndQGradient.qGrads.get(i).gradient;
            }
            this.vgrad = DifferentiableSparseSampling.this.operator.gradient(qValues, qGradients);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/differentiableplanners/DifferentiableSparseSampling$QAndQGradient.class */
    /**
     * Pairs the Q-values of a state with the matching Q-gradients. Entries added through
     * {@link #add} share the same index in both lists.
     */
    public static class QAndQGradient {
        /** Q-values, one per action. */
        List<QValue> qs;
        /** Q-gradients, index-aligned with {@link #qs} when populated through {@link #add}; entries may be {@code null}. */
        List<QGradientTuple> qGrads;

        /**
         * Wraps existing lists without copying them.
         *
         * @param list the Q-values
         * @param list2 the Q-gradients
         */
        public QAndQGradient(List<QValue> list, List<QGradientTuple> list2) {
            this.qs = list;
            this.qGrads = list2;
        }

        /**
         * Creates empty lists with the given initial capacity.
         *
         * @param i the initial capacity of both lists
         */
        public QAndQGradient(int i) {
            // Explicit type arguments avoid the raw-type construction and its unchecked warnings.
            this.qs = new ArrayList<QValue>(i);
            this.qGrads = new ArrayList<QGradientTuple>(i);
        }

        /**
         * Appends a Q-value and its gradient at the same index.
         *
         * @param qValue the Q-value to append
         * @param qGradientTuple the matching gradient, which may be {@code null}
         */
        public void add(QValue qValue, QGradientTuple qGradientTuple) {
            this.qs.add(qValue);
            this.qGrads.add(qGradientTuple);
        }
    }

    /* JADX INFO: Access modifiers changed from: protected */
    /* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/differentiableplanners/DifferentiableSparseSampling$VAndVGradient.class */
    /** Simple holder for a state value together with its gradient. */
    public static class VAndVGradient {
        /** The state value. */
        double v;
        /** The gradient of the state value. */
        FunctionGradient vGrad;

        /**
         * Stores the given value and gradient.
         *
         * @param d the state value
         * @param functionGradient the gradient of the state value
         */
        public VAndVGradient(double d, FunctionGradient functionGradient) {
            v = d;
            vGrad = functionGradient;
        }
    }

    /**
     * Initializes the solver. The leaf value function defaults to a
     * {@link VanillaDiffVinit} around a {@link ConstantValueFunction}, the model is the
     * domain's model wrapped in a {@link CustomRewardModel} using the supplied reward
     * function, and the operator defaults to a {@link DifferentiableSoftmaxOperator}.
     *
     * @param sADomain the domain to plan in
     * @param differentiableRF the differentiable reward function
     * @param d forwarded to {@code solverInit} (presumably the discount factor; confirm against MDPSolver)
     * @param hashableStateFactory the factory used to hash states
     * @param i the height given to the root node when planning
     * @param i2 the base number of transition samples per action
     * @param d2 the parameter for the softmax operator; its reciprocal is passed to the returned Boltzmann policy
     */
    public DifferentiableSparseSampling(SADomain sADomain, DifferentiableRF differentiableRF, double d, HashableStateFactory hashableStateFactory, int i, int i2, double d2) {
        solverInit(sADomain, d, hashableStateFactory);
        this.h = i;
        this.c = i2;
        this.rf = differentiableRF;
        this.boltzBeta = d2;
        // Explicit type arguments avoid the raw-type construction and its unchecked warnings.
        this.nodesByHeight = new HashMap<SparseSampling.HashedHeightState, DiffStateNode>();
        this.rootLevelQValues = new HashMap<HashableState, QAndQGradient>();
        this.rfDim = differentiableRF.numParameters();
        this.vinit = new VanillaDiffVinit(new ConstantValueFunction(), differentiableRF);
        this.model = new CustomRewardModel(sADomain.getModel(), differentiableRF);
        this.operator = new DifferentiableSoftmaxOperator(d2);
        this.debugCode = 6368290;
    }

    /**
     * Enables or disables variable sample counts; see {@code getCAtHeight} for how the
     * count is derived when enabled.
     *
     * @param z true to enable variable sample counts
     */
    public void setUseVariableCSize(boolean z) {
        useVariableC = z;
    }

    /**
     * Sets the base number of transition samples per action.
     *
     * @param i the new sample count
     */
    public void setC(int i) {
        c = i;
    }

    /**
     * Sets the height given to the root node when planning.
     *
     * @param i the new height
     */
    public void setH(int i) {
        h = i;
    }

    /**
     * Returns the base number of transition samples per action.
     *
     * @return the sample count
     */
    public int getC() {
        return c;
    }

    /**
     * Returns the height given to the root node when planning.
     *
     * @return the height
     */
    public int getH() {
        return h;
    }

    /**
     * Returns the model this solver plans with.
     *
     * @return the solver's model
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public SampleModel getModel() {
        return model;
    }

    /**
     * Returns the operator used to turn Q-values into state values.
     *
     * @return the current operator
     */
    public DifferentiableDPOperator getOperator() {
        return operator;
    }

    /**
     * Replaces the operator used to turn Q-values into state values.
     *
     * @param differentiableDPOperator the operator to use from now on
     */
    public void setOperator(DifferentiableDPOperator differentiableDPOperator) {
        operator = differentiableDPOperator;
    }

    /**
     * Sets whether planning results are discarded between plans. Enabling the flag
     * immediately clears the stored tree nodes.
     *
     * @param z true to forget previous planning results
     */
    public void setForgetPreviousPlanResults(boolean z) {
        forgetPreviousPlanResults = z;
        if (z) {
            nodesByHeight.clear();
        }
    }

    /**
     * Sets the value function used at leaf nodes. A {@link DifferentiableVInit} is used
     * directly; any other argument is cast to {@link QFunction} and wrapped in a
     * {@link VanillaDiffVinit} together with this solver's reward function.
     *
     * @param valueFunction the leaf value function
     */
    public void setValueForLeafNodes(ValueFunction valueFunction) {
        if (!(valueFunction instanceof DifferentiableVInit)) {
            this.vinit = new VanillaDiffVinit((QFunction) valueFunction, this.rf);
            return;
        }
        this.vinit = (DifferentiableVInit) valueFunction;
    }

    /**
     * Returns the debug code used for this solver's debug printing.
     *
     * @return the debug code
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public int getDebugCode() {
        return debugCode;
    }

    /**
     * Sets the debug code used for this solver's debug printing.
     *
     * @param i the new debug code
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void setDebugCode(int i) {
        debugCode = i;
    }

    /**
     * Returns the cumulative number of node value estimates performed; the counter is
     * zeroed by {@code resetSolver}.
     *
     * @return the number of value estimates so far
     */
    public int getNumberOfValueEsitmates() {
        return numUpdates;
    }

    /**
     * Returns the Q-values of the given state, planning from it first when no result is
     * cached for its hashed form.
     *
     * @param state the state to query
     * @return the stored list of Q-values for the state
     */
    @Override // burlap.behavior.valuefunction.QProvider
    public List<QValue> qValues(State state) {
        HashableState key = hashingFactory.hashState(state);
        QAndQGradient cached = rootLevelQValues.get(key);
        if (cached != null) {
            return cached.qs;
        }
        // Nothing cached yet: plan, then read back what planning stored.
        planFromState(state);
        return rootLevelQValues.get(key).qs;
    }

    /**
     * Returns the Q-value of one action in the given state, planning from the state first
     * when no result is cached for it.
     *
     * @param state the state to query
     * @param action the action to query
     * @return the Q-value of the action
     * @throws RuntimeException if the state has no Q-value for the action
     */
    @Override // burlap.behavior.valuefunction.QFunction
    public double qValue(State state, Action action) {
        HashableState key = hashingFactory.hashState(state);
        QAndQGradient entry = rootLevelQValues.get(key);
        if (entry == null) {
            planFromState(state);
            entry = rootLevelQValues.get(key);
        }
        List<QValue> candidates = entry.qs;
        for (int idx = 0; idx < candidates.size(); idx++) {
            QValue candidate = candidates.get(idx);
            if (candidate.a.equals(action)) {
                return candidate.q;
            }
        }
        throw new RuntimeException("Could not find a Q-value for: " + action.toString());
    }

    /**
     * Returns the value of a state: 0 for terminal states, otherwise the maximum Q-value
     * as computed by {@code QProvider.Helper.maxQ}.
     *
     * @param state the state to evaluate
     * @return the state value
     */
    @Override // burlap.behavior.valuefunction.ValueFunction
    public double value(State state) {
        return model.terminal(state) ? 0.0d : QProvider.Helper.maxQ(this, state);
    }

    /**
     * Returns the gradient of the Q-value of one action in the given state, planning from
     * the state first when no result is cached for it.
     *
     * <p>Root nodes evaluated directly with the leaf value function store {@code null}
     * gradient entries. For those, the action is matched through the index-aligned Q-value
     * list and the leaf value function's gradient is returned, mirroring how node value
     * gradients treat the same sentinel.
     *
     * @param state the state to query
     * @param action the action to query
     * @return the gradient of the action's Q-value
     * @throws RuntimeException if the state has no gradient for the action
     */
    @Override // burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableQFunction
    public FunctionGradient qGradient(State state, Action action) {
        HashableState hashState = this.hashingFactory.hashState(state);
        QAndQGradient qAndQGradient = this.rootLevelQValues.get(hashState);
        if (qAndQGradient == null) {
            planFromState(state);
            qAndQGradient = this.rootLevelQValues.get(hashState);
        }
        for (int idx = 0; idx < qAndQGradient.qGrads.size(); idx++) {
            QGradientTuple qGradientTuple = qAndQGradient.qGrads.get(idx);
            if (qGradientTuple == null) {
                // Leaf sentinel: the Q-value came from vinit, so its gradient does too.
                if (idx < qAndQGradient.qs.size() && qAndQGradient.qs.get(idx).a.equals(action)) {
                    return this.vinit.valueGradient(state);
                }
                continue;
            }
            if (qGradientTuple.a.equals(action)) {
                return qGradientTuple.gradient;
            }
        }
        // Fail loudly, consistent with qValue, rather than handing back a null gradient.
        throw new RuntimeException("Could not find a Q-gradient for: " + action.toString());
    }

    /**
     * Plans from the given state unless a result for it is already cached, and returns a
     * Boltzmann policy over this solver's Q-values. When previous results are to be
     * forgotten, the cached root results are cleared before planning and the tree nodes
     * are cleared afterwards.
     *
     * @param state the state to plan from
     * @return a Boltzmann policy constructed with {@code 1 / boltzBeta}
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public BoltzmannQPolicy planFromState(State state) {
        if (forgetPreviousPlanResults) {
            rootLevelQValues.clear();
        }
        HashableState rootKey = hashingFactory.hashState(state);
        if (!rootLevelQValues.containsKey(rootKey)) {
            DPrint.cl(debugCode, "Beginning Planning.");
            int updatesBefore = numUpdates;
            DiffStateNode root = getStateNode(state, h);
            rootLevelQValues.put(rootKey, root.estimateQs());
            int performed = numUpdates - updatesBefore;
            DPrint.cl(debugCode, "Finished Planning with " + performed + " value esitmates; for a cumulative total of: " + numUpdates);
            if (forgetPreviousPlanResults) {
                nodesByHeight.clear();
            }
        }
        return new BoltzmannQPolicy(this, 1.0d / boltzBeta);
    }

    /**
     * Discards all planning results: zeroes the update counter and clears both the cached
     * root results and the stored tree nodes.
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void resetSolver() {
        numUpdates = 0;
        rootLevelQValues.clear();
        nodesByHeight.clear();
    }

    /**
     * Returns the number of transition samples to use for a node at the given height.
     * With variable C disabled this is simply {@code c}. Otherwise it is
     * {@code c * gamma^(2 * depth)}, truncated to an int, where depth is the node's
     * distance from the root ({@code h - height}); a result of 0 is raised to 1.
     *
     * <p>This is a pure lookup: it does not modify the configured tree height.
     *
     * @param i the height of the node being expanded
     * @return the sample count for that height
     */
    protected int getCAtHeight(int i) {
        if (!this.useVariableC) {
            return this.c;
        }
        // Convert height above the leaves into depth below the root. The previous code
        // assigned the argument to this.h, which corrupted the configured tree height
        // and scaled by height instead of depth.
        int depth = this.h - i;
        int scaled = (int) (this.c * Math.pow(this.gamma, 2 * depth));
        if (scaled == 0) {
            scaled = 1;
        }
        return scaled;
    }

    /**
     * Returns the tree node for the given state and height, creating and indexing a new
     * node when none is stored yet.
     *
     * @param state the state of the node
     * @param i the height of the node
     * @return the existing or newly created node
     */
    protected DiffStateNode getStateNode(State state, int i) {
        HashableState hashed = hashingFactory.hashState(state);
        SparseSampling.HashedHeightState key = new SparseSampling.HashedHeightState(hashed, i);
        DiffStateNode existing = nodesByHeight.get(key);
        if (existing != null) {
            return existing;
        }
        DiffStateNode created = new DiffStateNode(hashed, i);
        nodesByHeight.put(key, created);
        return created;
    }

    /**
     * Collects the parameter ids of every partial derivative reported by
     * {@code getNonZeroPartialDerivatives()} across the supplied gradients.
     *
     * @param functionGradientArr the gradients to inspect
     * @return the union of their reported parameter ids
     */
    protected Set<Integer> combinedNonZeroPDParameters(FunctionGradient... functionGradientArr) {
        // A typed set replaces the raw HashSet and its unchecked conversion on return.
        Set<Integer> parameterIds = new HashSet<Integer>();
        for (FunctionGradient functionGradient : functionGradientArr) {
            for (FunctionGradient.PartialDerivative partialDerivative : functionGradient.getNonZeroPartialDerivatives()) {
                parameterIds.add(Integer.valueOf(partialDerivative.parameterId));
            }
        }
        return parameterIds;
    }
}
