package burlap.behavior.singleagent.learnfromdemo.mlirl.support;

import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/support/BoltzmannPolicyGradient.class */
/**
 * Static helpers for computing the gradient of a Boltzmann (softmax) action-selection probability with respect to
 * the parameters of a differentiable Q-function.
 * <p>
 * For Q-values {@code q[0..n-1]}, their gradients {@code g[0..n-1]}, a scaling factor {@code beta} and a query
 * action index {@code a}, the value produced for each parameter {@code p} is
 * <pre>
 *   beta * exp(beta*q[a] + shift - 2*logSum) * sum_j (g[a][p] - g[j][p]) * exp(beta*q[j] - shift)
 * </pre>
 * where {@code shift} is the value returned by {@link #maxBetaScaled(double[], double)} and {@code logSum} is the
 * value returned by {@link #logSum(double[], double, double)}. The {@code shift} terms cancel algebraically; they
 * only keep the arguments passed to {@link Math#exp(double)} small.
 */
public class BoltzmannPolicyGradient {

    /** Not instantiable: this class only exposes static methods. */
    private BoltzmannPolicyGradient() {
    }

    /**
     * Computes the gradient of the Boltzmann policy probability of {@code action} in {@code state}.
     * <p>
     * The Q-values are read from {@code differentiableQFunction} (which is cast to {@link QProvider}), the
     * Q-gradient of every action that has a Q-value is obtained through {@code qGradient}, and the combination is
     * delegated to {@link #computePolicyGradient(double, double[], double, double, FunctionGradient[], int)}.
     *
     * @param state the state in which the policy is evaluated
     * @param action the action whose selection probability is differentiated; it is matched against the actions of
     *               the returned Q-values with {@code equals}, and the first match is used
     * @param differentiableQFunction source of the Q-values and Q-gradients; it must also be a {@link QProvider}
     *                                because it is cast unconditionally
     * @param d the Boltzmann scaling factor (beta) that multiplies each Q-value
     * @return a new gradient holding the partial derivatives of the policy probability
     * @throws RuntimeException if no Q-value returned for {@code state} belongs to {@code action}
     */
    public static FunctionGradient computeBoltzmannPolicyGradient(State state, Action action, DifferentiableQFunction differentiableQFunction, double d) {
        List<QValue> qValues = ((QProvider) differentiableQFunction).qValues(state);
        int numActions = qValues.size();

        // Copy the raw Q-values and remember the position of the first entry that belongs to the query action.
        double[] qs = new double[numActions];
        int queryIndex = -1;
        for (int idx = 0; idx < numActions; idx++) {
            QValue qValue = qValues.get(idx);
            qs[idx] = qValue.q;
            if (queryIndex == -1 && qValue.a.equals(action)) {
                queryIndex = idx;
            }
        }
        if (queryIndex == -1) {
            throw new RuntimeException("Error in computing BoltzmannPolicyGradient: Could not find query action in Q-value list.");
        }

        // Gradients are requested only after the query action is known to exist.
        FunctionGradient[] qGradients = new FunctionGradient[numActions];
        for (int idx = 0; idx < numActions; idx++) {
            qGradients[idx] = differentiableQFunction.qGradient(state, qValues.get(idx).a);
        }

        double shift = maxBetaScaled(qs, d);
        double logNormalizer = logSum(qs, shift, d);
        return computePolicyGradient(d, qs, shift, logNormalizer, qGradients, queryIndex);
    }

    /**
     * Combines Q-values and their gradients into the Boltzmann policy gradient for one action.
     * <p>
     * For every parameter id reported as non-zero by at least one entry of {@code functionGradientArr}, the result
     * holds
     * {@code d * exp(d*dArr[i] + d2 - 2*d3) * sum_j (g[i] - g[j]) * exp(d*dArr[j] - d2)}, where {@code g[k]} is the
     * partial derivative of {@code functionGradientArr[k]} for that parameter.
     *
     * @param d the Boltzmann scaling factor (beta)
     * @param dArr the Q-values, one per action
     * @param d2 the shift subtracted inside the exponentials, as produced by {@link #maxBetaScaled(double[], double)}
     * @param d3 the log of the sum of exponentiated scaled Q-values, as produced by
     *           {@link #logSum(double[], double, double)}
     * @param functionGradientArr the Q-value gradients, indexed like {@code dArr}
     * @param i index (into {@code dArr} and {@code functionGradientArr}) of the action being differentiated
     * @return a new sparse gradient with one entry per accumulated parameter
     */
    public static FunctionGradient computePolicyGradient(double d, double[] dArr, double d2, double d3, FunctionGradient[] functionGradientArr, int i) {
        FunctionGradient.SparseGradient accumulated = new FunctionGradient.SparseGradient();

        // Factor shared by every parameter; applied once at the end instead of inside the accumulation.
        double constantPart = d * Math.exp((((d * dArr[i]) + d2) - d3) - d3);

        Set<Integer> parameterIds = combinedNonZeroPDParameters(functionGradientArr);
        for (int j = 0; j < dArr.length; j++) {
            // The exponential weight depends only on the action, so compute it once per action
            // rather than once per (action, parameter) pair.
            double weight = Math.exp((d * dArr[j]) - d2);
            for (int parameterId : parameterIds) {
                // The j == i term contributes zero but is still written, so every parameter id gets an entry.
                accumulated.put(parameterId, accumulated.getPartialDerivative(parameterId)
                        + ((functionGradientArr[i].getPartialDerivative(parameterId) - functionGradientArr[j].getPartialDerivative(parameterId)) * weight));
            }
        }

        FunctionGradient.SparseGradient scaled = new FunctionGradient.SparseGradient(accumulated.numNonZeroPDs());
        for (FunctionGradient.PartialDerivative partialDerivative : accumulated.getNonZeroPartialDerivatives()) {
            scaled.put(partialDerivative.parameterId, partialDerivative.value * constantPart);
        }
        return scaled;
    }

    /**
     * Returns the largest value in {@code dArr} multiplied by {@code d}.
     * <p>
     * The result is used as the shift inside the exponentials of the other methods. For a negative {@code d} the
     * product is the <em>smallest</em> scaled value, so the shift no longer bounds the exponents from above; the
     * final gradient is algebraically unaffected because the shift cancels. For an empty array the result is
     * {@code d * Double.NEGATIVE_INFINITY}. {@code NaN} entries are ignored by the comparison.
     *
     * @param dArr the Q-values
     * @param d the Boltzmann scaling factor (beta)
     * @return {@code d} times the maximum entry of {@code dArr}
     */
    public static double maxBetaScaled(double[] dArr, double d) {
        double largest = Double.NEGATIVE_INFINITY;
        for (double q : dArr) {
            if (q > largest) {
                largest = q;
            }
        }
        return d * largest;
    }

    /**
     * Computes {@code log(sum_k exp(d2 * dArr[k]))} as {@code d + log(sum_k exp(d2 * dArr[k] - d))}.
     *
     * @param dArr the Q-values
     * @param d the shift subtracted from every scaled value before exponentiating and added back afterwards
     * @param d2 the Boltzmann scaling factor (beta)
     * @return the shifted log-sum-exp of the scaled Q-values
     */
    public static double logSum(double[] dArr, double d, double d2) {
        double expSum = 0.0d;
        for (double q : dArr) {
            expSum += Math.exp((d2 * q) - d);
        }
        return d + Math.log(expSum);
    }

    /**
     * Collects the ids of all parameters that any of the given gradients reports through
     * {@code getNonZeroPartialDerivatives()}.
     *
     * @param functionGradientArr the gradients to inspect
     * @return a new, mutable set with the union of the reported parameter ids
     */
    protected static Set<Integer> combinedNonZeroPDParameters(FunctionGradient... functionGradientArr) {
        Set<Integer> parameterIds = new HashSet<Integer>();
        for (FunctionGradient gradient : functionGradientArr) {
            for (FunctionGradient.PartialDerivative partialDerivative : gradient.getNonZeroPartialDerivatives()) {
                parameterIds.add(partialDerivative.parameterId);
            }
        }
        return parameterIds;
    }
}
