package burlap.behavior.singleagent.learnfromdemo.mlirl;

import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.policy.BoltzmannQPolicy;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.learnfromdemo.CustomRewardModel;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.BoltzmannPolicyGradient;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableQFunction;
import burlap.behavior.singleagent.learnfromdemo.mlirl.support.DifferentiableRF;
import burlap.behavior.valuefunction.QProvider;
import burlap.datastructures.HashedAggregator;
import burlap.debugtools.DPrint;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/mlirl/MLIRL.class */
/**
 * Learns the parameters of a {@link DifferentiableRF} from expert demonstrations by gradient
 * ascent on the log likelihood of the demonstrated actions.
 * <p>
 * Everything needed for learning (the planner, domain, reward function, expert episodes, their
 * weights and the Boltzmann beta parameter) is read from an {@link MLIRLRequest}. Each ascent
 * step moves every reward function parameter along its log-likelihood partial derivative, scaled
 * by the learning rate, then resets the planner and gives it a model built from the updated
 * reward function before the likelihood is evaluated again.
 * <p>
 * The planner returned by the request is cast to both {@link QProvider} and
 * {@link DifferentiableQFunction}; a planner that is not both will cause a
 * {@link ClassCastException} when the likelihood or its gradient is computed.
 */
public class MLIRL {

    /** The request supplying the planner, domain, reward function and expert data. */
    protected MLIRLRequest request;

    /** Step size multiplied onto each partial derivative during a gradient ascent step. */
    protected double learningRate;

    /**
     * Convergence threshold: ascent stops once the absolute change in log likelihood between
     * two consecutive steps falls below this value.
     */
    protected double maxLikelihoodChange;

    /** Maximum number of gradient ascent steps; {@code -1} means no limit. */
    protected int maxSteps;

    /** Code passed to {@link DPrint} for this object's debug output. */
    protected int debugCode = 625420;

    /**
     * Creates the learner.
     *
     * @param request the request describing the IRL problem
     * @param learningRate the gradient ascent step size
     * @param maxLikelihoodChange ascent stops when the absolute log-likelihood change of a step
     *        is smaller than this value
     * @param maxSteps the maximum number of ascent steps, or {@code -1} for no limit
     * @throws RuntimeException if {@code request.isValid()} returns {@code false}
     */
    public MLIRL(MLIRLRequest request, double learningRate, double maxLikelihoodChange, int maxSteps) {
        this.request = request;
        this.learningRate = learningRate;
        this.maxLikelihoodChange = maxLikelihoodChange;
        this.maxSteps = maxSteps;
        if (!request.isValid()) {
            throw new RuntimeException("Provided MLIRLRequest object is not valid.");
        }
    }

    /**
     * Replaces the request used by this learner. Unlike the constructor, no validity check is
     * performed on the new request.
     *
     * @param request the request to use from now on
     */
    public void setRequest(MLIRLRequest request) {
        this.request = request;
    }

    /**
     * Turns debug printing on or off for this object's debug code and forwards the same setting
     * to the request's planner.
     *
     * @param printDebug {@code true} to enable debug printing, {@code false} to disable it
     */
    public void toggleDebugPrinting(boolean printDebug) {
        DPrint.toggleCode(this.debugCode, printDebug);
        this.request.getPlanner().toggleDebugPrinting(printDebug);
    }

    /**
     * Returns the debug code used for this object's {@link DPrint} output.
     *
     * @return the debug code
     */
    public int getDebugCode() {
        return this.debugCode;
    }

    /**
     * Sets the debug code used for this object's {@link DPrint} output. This only stores the
     * code; it does not toggle printing for it.
     *
     * @param debugCode the new debug code
     */
    public void setDebugCode(int debugCode) {
        this.debugCode = debugCode;
    }

    /**
     * Runs gradient ascent on the log likelihood of the expert episodes, updating the parameters
     * of the request's reward function in place.
     * <p>
     * Ascent ends after {@code maxSteps} steps (never, if {@code maxSteps} is {@code -1}) or as
     * soon as a step changes the log likelihood by less than {@code maxLikelihoodChange} in
     * absolute value, whichever happens first. A {@code maxSteps} of {@code 0}, or any negative
     * value other than {@code -1}, performs no steps.
     */
    public void performIRL() {
        DifferentiableRF rf = this.request.getRf();

        rebuildPlannerModel(rf);
        double lastLikelihood = logLikelihood();
        DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
        DPrint.cl(this.debugCode, "Log likelihood: " + lastLikelihood);

        int steps = 0;
        while (this.maxSteps == -1 || steps < this.maxSteps) {
            // Move every parameter with a non-zero partial derivative up the gradient.
            FunctionGradient gradient = logLikelihoodGradient();
            for (FunctionGradient.PartialDerivative pd : gradient.getNonZeroPartialDerivatives()) {
                double currentValue = rf.getParameter(pd.parameterId);
                rf.setParameter(pd.parameterId, currentValue + (this.learningRate * pd.value));
            }

            // The reward function changed, so the planner's previous results are stale.
            rebuildPlannerModel(rf);

            double newLikelihood = logLikelihood();
            double likelihoodChange = newLikelihood - lastLikelihood;
            lastLikelihood = newLikelihood;
            DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
            DPrint.cl(this.debugCode, "Log likelihood: " + lastLikelihood + " (change: " + likelihoodChange + ")");

            // Count the step before testing convergence so the reported total includes it.
            steps++;
            if (Math.abs(likelihoodChange) < this.maxLikelihoodChange) {
                break;
            }
        }

        DPrint.cl(this.debugCode, "\nNum gradient ascent steps: " + steps);
        DPrint.cl(this.debugCode, "RF: " + this.request.getRf().toString());
    }

    /**
     * Computes the weighted log likelihood of all expert episodes in the request.
     *
     * @return the sum over episodes of
     *         {@link #logLikelihoodOfTrajectory(Episode, double)} using the episode weight at
     *         the same index
     */
    public double logLikelihood() {
        double[] weights = this.request.getEpisodeWeights();
        List<Episode> episodes = this.request.getExpertEpisodes();
        double sum = 0.0d;
        for (int i = 0; i < episodes.size(); i++) {
            sum += logLikelihoodOfTrajectory(episodes.get(i), weights[i]);
        }
        return sum;
    }

    /**
     * Computes the weighted log likelihood of a single episode: the sum, over every time step
     * except the last, of the log of the probability the Boltzmann policy assigns to the action
     * taken at that step, multiplied by {@code weight}. The planner is asked to plan from each
     * visited state before the probability is queried.
     *
     * @param episode the episode to evaluate
     * @param weight the factor the summed log probability is multiplied by
     * @return the weighted log likelihood of the episode
     */
    public double logLikelihoodOfTrajectory(Episode episode, double weight) {
        double logLike = 0.0d;
        BoltzmannQPolicy policy = newBoltzmannPolicy();
        // NOTE(review): the last time step is skipped, presumably because no action is
        // recorded for it - confirm against Episode.
        int numActions = episode.numTimeSteps() - 1;
        for (int t = 0; t < numActions; t++) {
            this.request.getPlanner().planFromState(episode.state(t));
            logLike += Math.log(policy.actionProb(episode.state(t), episode.action(t)));
        }
        return logLike * weight;
    }

    /**
     * Computes the gradient of the weighted log likelihood of all expert episodes with respect
     * to the reward function parameters.
     * <p>
     * For every time step except the last of every episode, the planner plans from the visited
     * state and the non-zero partial derivatives of {@link #logPolicyGrad(State, Action)} are
     * multiplied by the episode's weight and summed per parameter id.
     *
     * @return a sparse gradient with one entry per parameter id that received a contribution
     */
    public FunctionGradient logLikelihoodGradient() {
        // NOTE(review): assumes HashedAggregator is generic in its key type with Double
        // totals, as the Integer/Double casts in the decompiled original imply - confirm.
        HashedAggregator<Integer> gradientSum = new HashedAggregator<Integer>();
        double[] weights = this.request.getEpisodeWeights();
        List<Episode> episodes = this.request.getExpertEpisodes();

        for (int e = 0; e < episodes.size(); e++) {
            Episode episode = episodes.get(e);
            double weight = weights[e];
            int numActions = episode.numTimeSteps() - 1;
            for (int t = 0; t < numActions; t++) {
                this.request.getPlanner().planFromState(episode.state(t));
                FunctionGradient stepGradient = logPolicyGrad(episode.state(t), episode.action(t));
                for (FunctionGradient.PartialDerivative pd : stepGradient.getNonZeroPartialDerivatives()) {
                    gradientSum.add(pd.parameterId, pd.value * weight);
                }
            }
        }

        FunctionGradient.SparseGradient gradient = new FunctionGradient.SparseGradient(gradientSum.size());
        for (Map.Entry<Integer, Double> entry : gradientSum.entrySet()) {
            gradient.put(entry.getKey(), entry.getValue());
        }
        return gradient;
    }

    /**
     * Computes the gradient of the log of the Boltzmann policy's probability of taking
     * {@code action} in {@code state}.
     * <p>
     * The gradient obtained from
     * {@link BoltzmannPolicyGradient#computeBoltzmannPolicyGradient} is scaled by the reciprocal
     * of the action's probability. If that probability is zero the scale factor is infinite and
     * the resulting partial derivatives are infinite or NaN.
     *
     * @param state the state in which the action was taken
     * @param action the action whose log probability gradient is wanted
     * @return the scaled gradient; this is the same object the policy gradient helper returned,
     *         modified in place
     */
    public FunctionGradient logPolicyGrad(State state, Action action) {
        double invActionProb = 1.0d / newBoltzmannPolicy().actionProb(state, action);
        FunctionGradient gradient = BoltzmannPolicyGradient.computeBoltzmannPolicyGradient(
                state, action, (DifferentiableQFunction) this.request.getPlanner(), this.request.getBoltzmannBeta());
        // NOTE(review): entries are overwritten while iterating the gradient's own non-zero
        // partial derivatives; this relies on that collection not being a live view - confirm.
        for (FunctionGradient.PartialDerivative pd : gradient.getNonZeroPartialDerivatives()) {
            gradient.put(pd.parameterId, pd.value * invActionProb);
        }
        return gradient;
    }

    /**
     * Adds {@code deltaVector} to {@code sumVector} element by element, storing the result in
     * {@code sumVector}. Only the first {@code sumVector.length} elements of {@code deltaVector}
     * are read.
     *
     * @param sumVector the vector that is modified in place
     * @param deltaVector the vector to add; must be at least as long as {@code sumVector}
     * @throws ArrayIndexOutOfBoundsException if {@code deltaVector} is shorter than
     *         {@code sumVector}
     */
    protected static void addToVector(double[] sumVector, double[] deltaVector) {
        for (int i = 0; i < sumVector.length; i++) {
            sumVector[i] += deltaVector[i];
        }
    }

    /**
     * Resets the request's planner and gives it a model that wraps the domain's model with the
     * supplied reward function.
     */
    private void rebuildPlannerModel(DifferentiableRF rf) {
        this.request.getPlanner().resetSolver();
        this.request.getPlanner().setModel(new CustomRewardModel(this.request.getDomain().getModel(), rf));
    }

    /**
     * Creates a Boltzmann policy over the request's planner, passing the reciprocal of the
     * request's beta as the policy's second constructor argument.
     */
    private BoltzmannQPolicy newBoltzmannPolicy() {
        // NOTE(review): 1/beta is presumably the policy temperature - confirm against
        // BoltzmannQPolicy's constructor.
        return new BoltzmannQPolicy((QProvider) this.request.getPlanner(), 1.0d / this.request.getBoltzmannBeta());
    }
}
