package burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners;

import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.dpoperator.DifferentiableDPOperator;
import burlap.behavior.singleagent.learnfromdemo.mlirl.differentiableplanners.dpoperator.DifferentiableSoftmaxOperator;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableQFunction;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableValueFunction;
import burlap.behavior.singleagent.planning.stochastic.DynamicProgramming;
import burlap.behavior.singleagent.planning.stochastic.dpoperator.DPOperator;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/differentiableplanners/DifferentiableDP.class */
/**
 * Dynamic-programming base class that, in addition to the value function maintained by
 * {@link DynamicProgramming}, keeps a per-state gradient of the value function and can
 * compute Q-value gradients from it.
 *
 * <p>Gradients are stored in {@link #valueGradient}, keyed by hashed state, and are combined
 * across actions by the configured {@link DifferentiableDPOperator}.
 */
public abstract class DifferentiableDP extends DynamicProgramming implements DifferentiableQFunction, DifferentiableValueFunction {

    /**
     * Value-function gradients keyed by hashed state. Entries are written by
     * {@link #performDPValueGradientUpdateOn(HashableState)} and discarded by
     * {@link #resetSolver()}.
     */
    protected Map<HashableState, FunctionGradient> valueGradient = new HashMap<>();

    /**
     * Differentiable reward function whose gradient feeds {@link #computeQGradient(State, Action)}.
     * It is not assigned anywhere in this class; NOTE(review): presumably subclasses set it
     * before any gradient is requested -- confirm against the concrete planners.
     */
    protected DifferentiableRF rf;

    /**
     * Delegates to the superclass initialisation and then installs a
     * {@link DifferentiableSoftmaxOperator} as the operator.
     *
     * @param sADomain forwarded unchanged to the superclass
     * @param d forwarded unchanged to the superclass (NOTE(review): presumably the discount factor -- confirm)
     * @param hashableStateFactory forwarded unchanged to the superclass
     */
    @Override
    public void DPPInit(SADomain sADomain, double d, HashableStateFactory hashableStateFactory) {
        super.DPPInit(sADomain, d, hashableStateFactory);
        this.operator = new DifferentiableSoftmaxOperator();
    }

    /**
     * Resets the superclass solver state and additionally drops every stored value gradient.
     */
    @Override
    public void resetSolver() {
        super.resetSolver();
        this.valueGradient.clear();
    }

    /**
     * Sets the operator, accepting only differentiable operators.
     *
     * @param dPOperator the operator to install
     * @throws IllegalArgumentException if {@code dPOperator} is not a {@link DifferentiableDPOperator}
     *         (this includes {@code null})
     */
    @Override
    public void setOperator(DPOperator dPOperator) {
        if (!(dPOperator instanceof DifferentiableDPOperator)) {
            // IllegalArgumentException is a RuntimeException, so existing catch sites keep working.
            throw new IllegalArgumentException("DPOperator must be a DifferentiableDPOperator");
        }
        this.operator = dPOperator;
    }

    /**
     * Returns the current operator viewed as a {@link DifferentiableDPOperator}.
     *
     * @return the operator field cast to {@link DifferentiableDPOperator}
     */
    @Override
    public DifferentiableDPOperator getOperator() {
        return (DifferentiableDPOperator) this.operator;
    }

    /**
     * Recomputes the value gradient for one state and stores it in {@link #valueGradient}.
     *
     * <p>The Q-value and Q-gradient of every action returned by {@code qValues} for the state
     * are collected and handed to the operator's {@code gradient} method; the result replaces
     * any previously stored gradient for that state.
     *
     * <p>Visibility is kept {@code public} as found in this source; the decompiler noted that
     * the access modifier had been changed from {@code protected}.
     *
     * @param hashableState the hashed state to update
     * @return the newly computed gradient, which is also the value now stored for the state
     */
    public FunctionGradient performDPValueGradientUpdateOn(HashableState hashableState) {
        State state = hashableState.s();
        List<QValue> actionValues = qValues(state);

        int numActions = actionValues.size();
        double[] qs = new double[numActions];
        FunctionGradient[] qGradients = new FunctionGradient[numActions];
        for (int i = 0; i < numActions; i++) {
            QValue qValue = actionValues.get(i);
            qs[i] = qValue.q;
            qGradients[i] = qGradient(state, qValue.a);
        }

        FunctionGradient stateGradient = getOperator().gradient(qs, qGradients);
        this.valueGradient.put(hashableState, stateGradient);
        return stateGradient;
    }

    /**
     * Looks up the stored value gradient for a state.
     *
     * @param state the state to query; it is hashed with {@code hashingFactory} for the lookup
     * @return the stored gradient, or a new empty {@link FunctionGradient.SparseGradient} when
     *         nothing is stored for the state (the empty gradient is not added to the map)
     */
    @Override
    public FunctionGradient valueGradient(State state) {
        FunctionGradient stored = this.valueGradient.get(this.hashingFactory.hashState(state));
        return stored != null ? stored : new FunctionGradient.SparseGradient();
    }

    /**
     * Returns the Q-value gradient for a state-action pair by delegating to
     * {@link #computeQGradient(State, Action)}.
     *
     * @param state the state
     * @param action the action
     * @return the gradient computed by {@link #computeQGradient(State, Action)}
     */
    @Override
    public FunctionGradient qGradient(State state, Action action) {
        return computeQGradient(state, action);
    }

    /**
     * Computes the Q-value gradient for a state-action pair from the model's transitions.
     *
     * <p>For every transition returned by the model (which is cast to {@link FullModel}), and
     * for every parameter with a non-zero partial derivative in either the reward gradient or
     * the next state's value gradient, the result accumulates
     * {@code p * (rewardPD + gamma * nextValuePD)}.
     *
     * @param state the state
     * @param action the action
     * @return a new sparse gradient holding the accumulated partial derivatives
     */
    protected FunctionGradient computeQGradient(State state, Action action) {
        FunctionGradient.SparseGradient qGrad = new FunctionGradient.SparseGradient();
        for (TransitionProb tp : ((FullModel) this.model).transitions(state, action)) {
            State nextState = tp.eo.op;
            FunctionGradient nextValueGrad = valueGradient(nextState);
            FunctionGradient rewardGrad = this.rf.gradient(state, action, nextState);
            for (int paramId : combinedNonZeroPDParameters(nextValueGrad, rewardGrad)) {
                double contribution = tp.p
                        * (rewardGrad.getPartialDerivative(paramId)
                        + (this.gamma * nextValueGrad.getPartialDerivative(paramId)));
                qGrad.put(paramId, qGrad.getPartialDerivative(paramId) + contribution);
            }
        }
        return qGrad;
    }

    /**
     * Collects the ids of all parameters that have a non-zero partial derivative in at least
     * one of the supplied gradients.
     *
     * @param functionGradientArr the gradients to inspect
     * @return a new set containing the union of the parameter ids
     */
    protected Set<Integer> combinedNonZeroPDParameters(FunctionGradient... functionGradientArr) {
        Set<Integer> parameterIds = new HashSet<>();
        for (FunctionGradient functionGradient : functionGradientArr) {
            for (FunctionGradient.PartialDerivative pd : functionGradient.getNonZeroPartialDerivatives()) {
                parameterIds.add(pd.parameterId);
            }
        }
        return parameterIds;
    }
}
