package burlap.behavior.singleagent.learnfromdemo.mlirl;

import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.QGradientPlannerFactory;
import burlap.behavior.singleagent.planning.Planner;
import burlap.debugtools.DPrint;
import burlap.debugtools.RandomFactory;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Random;
import org.eclipse.jetty.util.URIUtil;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/MultipleIntentionsMLIRL.class */
/**
 * Runs an EM-style loop that fits several reward functions ("clusters") to one set of
 * expert episodes.
 *
 * <p>Each cluster owns its own {@link MLIRLRequest} (with its own copy of the reward
 * function and its own planner) plus a prior probability. Every EM iteration first
 * computes, for every cluster/episode pair, the normalized probability of the cluster
 * given the episode, refreshes the cluster priors from those values, and then runs the
 * shared {@link MLIRL} instance once per cluster using the cluster's row of values as
 * episode weights.
 *
 * <p>A single {@link MLIRL} instance is reused for all clusters by swapping its request,
 * so most methods here leave that instance pointing at the last cluster they touched.
 */
public class MultipleIntentionsMLIRL {
    /** The multiple-intentions request this object was constructed from. */
    protected MultipleIntentionsMLIRLRequest request;

    /** One request per cluster, in cluster-index order. */
    protected List<MLIRLRequest> clusterRequests;

    /** Prior probability of each cluster; initialized uniform and updated every EM iteration. */
    protected double[] clusterPriors;

    /** The MLIRL instance shared by all clusters; its request is swapped before each use. */
    protected MLIRL mlirlInstance;

    /** Number of EM iterations executed by {@link #performIRL()}. */
    protected int numEMIterations;

    /** Debug code passed to {@link DPrint} for this object's messages. */
    protected int debugCode = 13435;

    /** Random source used to draw the initial reward function parameters. */
    protected Random rand = RandomFactory.getMapped(0);

    /**
     * Validates the request, builds the per-cluster requests, and creates the shared
     * {@link MLIRL} instance.
     *
     * @param request the multiple-intentions request; must report itself as valid
     * @param emIterations number of EM iterations {@link #performIRL()} will run
     * @param mlIRLLearningRate forwarded unchanged as the second {@link MLIRL} constructor
     *        argument (presumably its learning rate; confirm against {@code MLIRL})
     * @param maxMLIRLLikelihoodChange forwarded unchanged as the third {@link MLIRL}
     *        constructor argument (presumably its convergence threshold; confirm against
     *        {@code MLIRL})
     * @param maxMLIRLSteps forwarded unchanged as the fourth {@link MLIRL} constructor
     *        argument (presumably its step limit; confirm against {@code MLIRL})
     * @throws RuntimeException if {@code request.isValid()} returns false
     */
    public MultipleIntentionsMLIRL(MultipleIntentionsMLIRLRequest request, int emIterations,
            double mlIRLLearningRate, double maxMLIRLLikelihoodChange, int maxMLIRLSteps) {
        if (!request.isValid()) {
            throw new RuntimeException("Provided MultipleIntentionsMLIRLRequest object is not valid.");
        }
        this.request = request;
        initializeClusters(this.request.getK(), this.request.getPlannerFactory());
        this.numEMIterations = emIterations;
        this.mlirlInstance = new MLIRL(request, mlIRLLearningRate, maxMLIRLLikelihoodChange, maxMLIRLSteps);
    }

    /**
     * Runs {@link #numEMIterations} EM iterations. Each iteration recomputes the
     * per-cluster episode weights (which also updates the cluster priors) and then runs
     * MLIRL once per cluster with that cluster's weights.
     */
    public void performIRL() {
        int numClusters = this.clusterPriors.length;
        for (int iteration = 0; iteration < this.numEMIterations; iteration++) {
            DPrint.cl(this.debugCode, "Starting EM iteration " + (iteration + 1) + "/" + this.numEMIterations);
            double[][] weights = computePerClusterMLIRLWeights();
            for (int c = 0; c < numClusters; c++) {
                MLIRLRequest clusterRequest = this.clusterRequests.get(c);
                // Hand the request its own copy so it does not alias the weight matrix row.
                clusterRequest.setEpisodeWeights(weights[c].clone());
                this.mlirlInstance.setRequest(clusterRequest);
                this.mlirlInstance.performIRL();
            }
        }
        DPrint.cl(this.debugCode, "Finished EM");
    }

    /**
     * Computes, for the given episode, the normalized probability of each cluster using
     * the current priors and each cluster's current request.
     *
     * <p>Side effect: the shared MLIRL instance is left set to the last cluster's request.
     *
     * @param episode the episode to evaluate
     * @return a new array, indexed by cluster, of normalized probabilities
     */
    public double[] computeProbabilityOfClustersGivenTrajectory(Episode episode) {
        int numClusters = this.clusterPriors.length;
        double[] result = new double[numClusters];
        double maxLogTerm = Double.NEGATIVE_INFINITY;
        for (int c = 0; c < numClusters; c++) {
            double logPrior = Math.log(this.clusterPriors[c]);
            this.mlirlInstance.setRequest(this.clusterRequests.get(c));
            double logJoint = this.mlirlInstance.logLikelihoodOfTrajectory(episode, 1.0d) + logPrior;
            result[c] = logJoint;
            maxLogTerm = Math.max(maxLogTerm, logJoint);
        }

        // Log-sum-exp with the maximum term factored out before exponentiating.
        double shiftedSum = 0.0d;
        for (int c = 0; c < numClusters; c++) {
            shiftedSum += Math.exp(result[c] - maxLogTerm);
        }
        double logNormalizer = maxLogTerm + Math.log(shiftedSum);

        for (int c = 0; c < numClusters; c++) {
            result[c] = Math.exp(result[c] - logNormalizer);
        }
        return result;
    }

    /**
     * Collects the reward function currently held by each cluster request.
     *
     * @return a new list, in cluster-index order, of the cluster reward functions
     */
    public List<DifferentiableRF> getClusterRFs() {
        List<DifferentiableRF> rfs = new ArrayList<DifferentiableRF>(this.clusterPriors.length);
        for (MLIRLRequest clusterRequest : this.clusterRequests) {
            rfs.add(clusterRequest.getRf());
        }
        return rfs;
    }

    /**
     * Returns the cluster priors.
     *
     * @return the internal prior array itself (not a copy); writes to it change this
     *         object's state
     */
    public double[] getClusterPriors() {
        return this.clusterPriors;
    }

    /**
     * Toggles debug printing for this object's debug code and for the shared MLIRL instance.
     *
     * @param printDebug the value forwarded to {@link DPrint#toggleCode} and to the MLIRL instance
     */
    public void toggleDebugPrinting(boolean printDebug) {
        DPrint.toggleCode(this.debugCode, printDebug);
        this.mlirlInstance.toggleDebugPrinting(printDebug);
    }

    /**
     * Returns the debug code used for this object's messages.
     *
     * @return the debug code
     */
    public int getDebugCode() {
        return this.debugCode;
    }

    /**
     * Sets the debug code used for this object's messages.
     *
     * @param debugCode the new debug code
     */
    public void setDebugCode(int debugCode) {
        this.debugCode = debugCode;
    }

    /**
     * Computes the normalized probability of every cluster for every expert episode and
     * overwrites {@link #clusterPriors} with each cluster's share of the total.
     *
     * <p>Side effect: the shared MLIRL instance is left set to the last cluster's request.
     *
     * @return a new matrix indexed {@code [cluster][episode]}; for each episode the
     *         entries are normalized across clusters
     */
    protected double[][] computePerClusterMLIRLWeights() {
        int numClusters = this.clusterPriors.length;
        List<Episode> episodes = this.request.getExpertEpisodes();
        int numEpisodes = episodes.size();
        double[][] weights = new double[numClusters][numEpisodes];

        // First pass: log prior plus trajectory log-likelihood for every pair.
        for (int c = 0; c < numClusters; c++) {
            double logPrior = Math.log(this.clusterPriors[c]);
            this.mlirlInstance.setRequest(this.clusterRequests.get(c));
            for (int e = 0; e < numEpisodes; e++) {
                weights[c][e] = logPrior + this.mlirlInstance.logLikelihoodOfTrajectory(episodes.get(e), 1.0d);
            }
        }

        // Second pass: normalize each episode's column across clusters, tracking the grand total.
        double total = 0.0d;
        for (int e = 0; e < numEpisodes; e++) {
            double logNormalizer = computeClusterTrajectoryLoggedNormalization(e, weights);
            for (int c = 0; c < numClusters; c++) {
                double probability = Math.exp(weights[c][e] - logNormalizer);
                weights[c][e] = probability;
                total += probability;
            }
        }

        // Third pass: each cluster's new prior is its row sum divided by the grand total.
        for (int c = 0; c < numClusters; c++) {
            double rowSum = 0.0d;
            for (int e = 0; e < numEpisodes; e++) {
                rowSum += weights[c][e];
            }
            this.clusterPriors[c] = rowSum / total;
        }
        return weights;
    }

    /**
     * Computes the log of the sum, over clusters, of the exponentiated entries in one
     * episode column, factoring out the column maximum before exponentiating.
     *
     * @param t the episode (column) index
     * @param logWeightedLikelihoods matrix indexed {@code [cluster][episode]} of log values
     * @return the log of the summed exponentials of column {@code t}
     */
    protected double computeClusterTrajectoryLoggedNormalization(int t, double[][] logWeightedLikelihoods) {
        double maxLogTerm = Double.NEGATIVE_INFINITY;
        for (double[] clusterRow : logWeightedLikelihoods) {
            maxLogTerm = Math.max(maxLogTerm, clusterRow[t]);
        }
        double shiftedSum = 0.0d;
        for (double[] clusterRow : logWeightedLikelihoods) {
            shiftedSum += Math.exp(clusterRow[t] - maxLogTerm);
        }
        return maxLogTerm + Math.log(shiftedSum);
    }

    /**
     * Creates {@code k} cluster requests with uniform priors. Each cluster receives a
     * randomly parameterized copy of the source request's reward function and a planner
     * generated by the supplied factory; domain, expert episodes, gamma and Boltzmann
     * beta are taken from the source request.
     *
     * @param k the number of clusters
     * @param plannerFactory factory used to generate each cluster's planner
     */
    protected void initializeClusters(int k, QGradientPlannerFactory plannerFactory) {
        List<DifferentiableRF> rfs = new ArrayList<DifferentiableRF>(k);
        for (int c = 0; c < k; c++) {
            rfs.add((DifferentiableRF) this.request.getRf().copy());
        }
        initializeClusterRFParameters(rfs);

        this.clusterRequests = new ArrayList<MLIRLRequest>(k);
        this.clusterPriors = new double[k];
        double uniformPrior = 1.0d / k;
        for (int c = 0; c < k; c++) {
            this.clusterPriors[c] = uniformPrior;
            // The planner is attached after construction because the factory takes the request as input.
            MLIRLRequest clusterRequest = new MLIRLRequest(this.request.getDomain(), (Planner) null,
                    this.request.getExpertEpisodes(), rfs.get(c));
            clusterRequest.setGamma(this.request.getGamma());
            clusterRequest.setBoltzmannBeta(this.request.getBoltzmannBeta());
            clusterRequest.setPlanner((Planner) plannerFactory.generateDifferentiablePlannerForRequest(clusterRequest));
            this.clusterRequests.add(clusterRequest);
        }
    }

    /**
     * Randomizes the parameters of every reward function in the list, in list order.
     *
     * @param rfs the reward functions to initialize
     */
    protected void initializeClusterRFParameters(List<DifferentiableRF> rfs) {
        for (DifferentiableRF rf : rfs) {
            randomizeParameters(rf);
        }
    }

    /**
     * Sets every parameter of the reward function to a value drawn as
     * {@code rand.nextDouble() * 2 - 1}.
     *
     * @param rf the reward function whose parameters are overwritten
     */
    protected void randomizeParameters(DifferentiableRF rf) {
        for (int p = 0; p < rf.numParameters(); p++) {
            rf.setParameter(p, (this.rand.nextDouble() * 2.0d) - 1.0d);
        }
    }

    /**
     * Overwrites every element of the array with a value drawn as
     * {@code rand.nextDouble() * 2 - 1}.
     *
     * @param paramValues the array to fill in place
     */
    protected void randomizeParameters(double[] paramValues) {
        for (int p = 0; p < paramValues.length; p++) {
            paramValues[p] = (this.rand.nextDouble() * 2.0d) - 1.0d;
        }
    }
}
