package burlap.behavior.singleagent.learnfromdemo.apprenticeship;

import burlap.behavior.functionapproximation.dense.DenseStateFeatures;
import burlap.behavior.policy.EnumerablePolicy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.policy.PolicyUtils;
import burlap.behavior.policy.RandomPolicy;
import burlap.behavior.policy.support.ActionProb;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.learnfromdemo.CustomRewardModel;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.singleagent.planning.deterministic.DDPlannerPolicy;
import burlap.behavior.singleagent.planning.deterministic.DeterministicPlanner;
import burlap.behavior.valuefunction.QProvider;
import burlap.debugtools.DPrint;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionType;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.RewardFunction;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import burlap.statehashing.simple.SimpleHashableStateFactory;
import com.joptimizer.functions.ConvexMultivariateRealFunction;
import com.joptimizer.functions.LinearMultivariateRealFunction;
import com.joptimizer.functions.PSDQuadraticMultivariateRealFunction;
import com.joptimizer.optimizers.JOptimizer;
import com.joptimizer.optimizers.OptimizationRequest;
import com.joptimizer.util.Utils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.apache.commons.math3.geometry.VectorFormat;

/* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/apprenticeship/ApprenticeshipLearning.class */
public class ApprenticeshipLearning {
    public static final int DEBUG_CODE_SCORE = 746329;
    public static final int DEBUG_CODE_RF_WEIGHTS = 636392;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/apprenticeship/ApprenticeshipLearning$FeatureWeights.class */
    /**
     * Pairs a vector of reward-feature weights with the score obtained for it.
     * The weight array is copied when stored and again when handed out, and the
     * class exposes no mutators.
     */
    public static class FeatureWeights {
        private double[] weights;
        private double score;

        /**
         * Stores a private copy of the given weights together with their score.
         *
         * @param sourceWeights weights to copy; must not be null
         * @param sourceScore   score associated with the weights
         */
        private FeatureWeights(double[] sourceWeights, double sourceScore) {
            this.weights = Arrays.copyOf(sourceWeights, sourceWeights.length);
            this.score = sourceScore;
        }

        /**
         * Copy constructor.
         *
         * @param featureWeights instance whose weights and score are duplicated
         */
        public FeatureWeights(FeatureWeights featureWeights) {
            // getWeights() already hands back a fresh array, so it can be stored directly.
            double[] copiedWeights = featureWeights.getWeights();
            Double copiedScore = featureWeights.getScore();
            this.weights = copiedWeights;
            this.score = copiedScore.doubleValue();
        }

        /**
         * @return a fresh copy of the weight vector
         */
        public double[] getWeights() {
            return Arrays.copyOf(this.weights, this.weights.length);
        }

        /**
         * @return the score, boxed
         */
        public Double getScore() {
            return Double.valueOf(this.score);
        }
    }

    /* loaded from: input_file:burlap/behavior/singleagent/learnfromdemo/apprenticeship/ApprenticeshipLearning$StationaryRandomDistributionPolicy.class */
    public static class StationaryRandomDistributionPolicy implements EnumerablePolicy {
        Map<HashableState, Action> stateActionMapping;
        List<ActionType> actionTypes;
        Map<HashableState, List<ActionProb>> stateActionDistributionMapping;
        HashableStateFactory hashFactory;
        Random rando;

        private StationaryRandomDistributionPolicy(SADomain sADomain) {
            this.stateActionMapping = new HashMap();
            this.stateActionDistributionMapping = new HashMap();
            this.actionTypes = sADomain.getActionTypes();
            this.rando = new Random();
            this.hashFactory = new SimpleHashableStateFactory(true);
        }

        public static Policy generateRandomPolicy(SADomain sADomain) {
            return new RandomPolicy(sADomain);
        }

        private void addNewDistributionForState(State state) {
            HashableState hashState = this.hashFactory.hashState(state);
            List<Action> allApplicableActionsForTypes = ActionUtils.allApplicableActionsForTypes(this.actionTypes, state);
            Double[] dArr = new Double[allApplicableActionsForTypes.size()];
            Double valueOf = Double.valueOf(0.0d);
            for (int i = 0; i < dArr.length; i++) {
                dArr[i] = Double.valueOf(this.rando.nextDouble());
                valueOf = Double.valueOf(valueOf.doubleValue() + dArr[i].doubleValue());
            }
            ArrayList arrayList = new ArrayList(allApplicableActionsForTypes.size());
            for (int i2 = 0; i2 < dArr.length; i2++) {
                arrayList.add(new ActionProb(allApplicableActionsForTypes.get(i2), dArr[i2].doubleValue() / valueOf.doubleValue()));
            }
            this.stateActionDistributionMapping.put(hashState, arrayList);
        }

        @Override // burlap.behavior.policy.Policy
        public Action action(State state) {
            HashableState hashState = this.hashFactory.hashState(state);
            if (!this.stateActionDistributionMapping.containsKey(hashState)) {
                addNewDistributionForState(state);
            }
            List<ActionProb> list = this.stateActionDistributionMapping.get(hashState);
            Double valueOf = Double.valueOf(this.rando.nextDouble());
            Double valueOf2 = Double.valueOf(0.0d);
            for (ActionProb actionProb : list) {
                valueOf2 = Double.valueOf(valueOf2.doubleValue() + actionProb.pSelection);
                if (valueOf2.doubleValue() >= valueOf.doubleValue()) {
                    return actionProb.ga;
                }
            }
            return null;
        }

        @Override // burlap.behavior.policy.Policy
        public double actionProb(State state, Action action) {
            return PolicyUtils.actionProbFromEnum(this, state, action);
        }

        @Override // burlap.behavior.policy.EnumerablePolicy
        public List<ActionProb> policyDistribution(State state) {
            HashableState hashState = this.hashFactory.hashState(state);
            if (!this.stateActionDistributionMapping.containsKey(hashState)) {
                addNewDistributionForState(state);
            }
            return new ArrayList(this.stateActionDistributionMapping.get(hashState));
        }

        @Override // burlap.behavior.policy.Policy
        public boolean definedFor(State state) {
            return true;
        }
    }

    // Private no-arg constructor: code outside this class cannot instantiate it.
    // Every method of the class visible in this file is static.
    private ApprenticeshipLearning() {
    }

    /**
     * Single-episode convenience overload: wraps the episode in a one-element list
     * and delegates to the list-based overload.
     *
     * @param episode            the episode to evaluate
     * @param denseStateFeatures feature generator applied to each state
     * @param d                  discount factor
     * @return the discounted feature sums for the episode
     */
    public static double[] estimateFeatureExpectation(Episode episode, DenseStateFeatures denseStateFeatures, Double d) {
        List<Episode> singleEpisode = Arrays.asList(episode);
        return estimateFeatureExpectation(singleEpisode, denseStateFeatures, d);
    }

    /**
     * Estimates discounted feature expectations from a set of episodes. For every
     * state at index {@code t} of every episode, each feature value is scaled by
     * {@code d^t} and accumulated; the totals are then divided by the number of episodes.
     *
     * @param list               episodes to average over
     * @param denseStateFeatures feature generator applied to each state
     * @param d                  discount factor; must not be null
     * @return one averaged, discounted sum per feature; the length is taken from the
     *         first feature vector produced
     * @throws IllegalArgumentException if the episodes contain no states at all, so that
     *                                  no feature vector (and hence no length) is available
     */
    public static double[] estimateFeatureExpectation(List<Episode> list, DenseStateFeatures denseStateFeatures, Double d) {
        final double gamma = d.doubleValue();
        double[] expectation = null;
        for (Episode episode : list) {
            for (int t = 0; t < episode.stateSequence.size(); t++) {
                double[] features = denseStateFeatures.features(episode.stateSequence.get(t));
                if (expectation == null) {
                    expectation = new double[features.length];
                }
                // The discount depends only on the time step, so compute it once per state.
                final double discount = Math.pow(gamma, t);
                for (int f = 0; f < expectation.length; f++) {
                    if (features[f] != 0.0d) {
                        expectation[f] += features[f] * discount;
                    }
                }
            }
        }
        if (expectation == null) {
            // Previously this surfaced as a NullPointerException in the averaging loop.
            throw new IllegalArgumentException(
                    "Cannot estimate feature expectations: the supplied episodes contain no states");
        }
        final int episodeCount = list.size();
        for (int f = 0; f < expectation.length; f++) {
            expectation[f] /= episodeCount;
        }
        return expectation;
    }

    /**
     * Builds a reward function that is linear in the state features: the reward is
     * the dot product of the given weights with the features of the source state.
     * The action and the resulting state are not consulted.
     *
     * @param denseStateFeatures feature generator applied to the source state
     * @param featureWeights     weights to apply; a snapshot is taken when this method is called
     * @return the linear reward function
     */
    public static RewardFunction generateRewardFunction(final DenseStateFeatures denseStateFeatures, FeatureWeights featureWeights) {
        // getWeights() returns a private copy, so this snapshot cannot be changed from outside.
        // Taking it once here also avoids re-cloning the array on every reward() call.
        final double[] weights = featureWeights.getWeights();
        return new RewardFunction() {
            @Override
            public double reward(State state, Action action, State state2) {
                double[] features = denseStateFeatures.features(state);
                double total = 0.0d;
                for (int i = 0; i < features.length; i++) {
                    total += weights[i] * features[i];
                }
                return total;
            }
        };
    }

    /**
     * Picks one of the given episodes uniformly at random and returns its state at index 0.
     *
     * @param list episodes to choose from
     * @return the first state of the randomly chosen episode
     */
    public static State getInitialState(List<Episode> list) {
        Random rng = new Random();
        int chosenIndex = rng.nextInt(list.size());
        Episode chosenEpisode = list.get(chosenIndex);
        return chosenEpisode.state(0);
    }

    /**
     * Runs apprenticeship learning for the given request, using the max-margin
     * method when the request asks for it and the projection method otherwise.
     *
     * @param apprenticeshipLearningRequest the learning request
     * @return the learned policy, or {@code null} when the request is not valid
     */
    public static Policy getLearnedPolicy(ApprenticeshipLearningRequest apprenticeshipLearningRequest) {
        if (!apprenticeshipLearningRequest.isValid()) {
            return null;
        }
        if (apprenticeshipLearningRequest.getUsingMaxMargin()) {
            return maxMarginMethod(apprenticeshipLearningRequest);
        }
        return projectionMethod(apprenticeshipLearningRequest);
    }

    /**
     * Max-margin apprenticeship learning. Starting from a random policy, each
     * iteration solves for feature weights against the history of feature
     * expectations, stops when the absolute score is within the request's epsilon,
     * and otherwise plans under the induced reward and records the new policy's
     * feature expectations.
     *
     * <p>The score history is stored on the request via {@code setTHistory} on every exit path.
     *
     * @param apprenticeshipLearningRequest the learning request
     * @return the most recently computed policy
     */
    private static Policy maxMarginMethod(ApprenticeshipLearningRequest apprenticeshipLearningRequest) {
        // Upper bound on solver retries per iteration; the unbounded retry loop this replaces
        // would spin forever whenever the solver failed on every attempt.
        final int maxSolverAttempts = 3;

        // Rollouts are capped at the length of the longest expert episode.
        int maxExpertEpisodeLength = 0;
        List<Episode> expertEpisodes = apprenticeshipLearningRequest.getExpertEpisodes();
        for (Episode expertEpisode : expertEpisodes) {
            maxExpertEpisodeLength = Math.max(maxExpertEpisodeLength, expertEpisode.numTimeSteps());
        }

        Planner planner = apprenticeshipLearningRequest.getPlanner();
        HashableStateFactory hashingFactory = planner.getHashingFactory();
        SADomain domain = apprenticeshipLearningRequest.getDomain();
        Policy policy = new StationaryRandomDistributionPolicy(domain);
        DenseStateFeatures featureGenerator = apprenticeshipLearningRequest.getFeatureGenerator();

        List<double[]> featureExpectationHistory = new ArrayList<double[]>();
        double[] expertExpectations = estimateFeatureExpectation(expertEpisodes, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma()));

        // Seed the history with a single rollout of the initial random policy.
        Episode initialRollout = PolicyUtils.rollout(policy, apprenticeshipLearningRequest.getStartStateGenerator().generateState(), apprenticeshipLearningRequest.getPlanner().getModel(), maxExpertEpisodeLength);
        featureExpectationHistory.add(estimateFeatureExpectation(initialRollout, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma())));

        int maxIterations = apprenticeshipLearningRequest.getMaxIterations();
        double[] tHistory = new double[maxIterations];
        int policyCount = apprenticeshipLearningRequest.getPolicyCount();

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            FeatureWeights featureWeights = null;
            for (int attempt = 0; attempt < maxSolverAttempts && featureWeights == null; attempt++) {
                featureWeights = solveFeatureWeights(expertExpectations, featureExpectationHistory);
            }
            if (featureWeights == null) {
                // The solver could not produce weights; stop with the best policy so far.
                apprenticeshipLearningRequest.setTHistory(tHistory);
                return policy;
            }

            for (int f = 0; f < featureWeights.weights.length; f++) {
                DPrint.c(DEBUG_CODE_RF_WEIGHTS, f + ": " + featureWeights.weights[f] + VectorFormat.DEFAULT_SEPARATOR);
            }
            DPrint.cl(DEBUG_CODE_RF_WEIGHTS, "");

            if (Math.abs(featureWeights.getScore().doubleValue()) <= apprenticeshipLearningRequest.getEpsilon()) {
                apprenticeshipLearningRequest.setTHistory(tHistory);
                return policy;
            }
            tHistory[iteration] = featureWeights.getScore().doubleValue();
            DPrint.cl(DEBUG_CODE_SCORE, "Score: " + tHistory[iteration]);

            // Re-plan under the reward induced by the new weights.
            CustomRewardModel customRewardModel = new CustomRewardModel(domain.getModel(), generateRewardFunction(featureGenerator, featureWeights));
            planner.resetSolver();
            planner.solverInit(domain, apprenticeshipLearningRequest.getGamma(), hashingFactory);
            planner.setModel(customRewardModel);
            planner.planFromState(apprenticeshipLearningRequest.getStartStateGenerator().generateState());
            // A planner that is neither deterministic nor a Q-provider leaves the previous policy in place.
            if (planner instanceof DeterministicPlanner) {
                policy = new DDPlannerPolicy((DeterministicPlanner) planner);
            } else if (planner instanceof QProvider) {
                policy = new GreedyQPolicy((QProvider) planner);
            }

            List<Episode> rollouts = new ArrayList<Episode>();
            for (int r = 0; r < policyCount; r++) {
                rollouts.add(PolicyUtils.rollout(policy, apprenticeshipLearningRequest.getStartStateGenerator().generateState(), customRewardModel, maxExpertEpisodeLength));
            }
            featureExpectationHistory.add(estimateFeatureExpectation(rollouts, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma())));
        }
        apprenticeshipLearningRequest.setTHistory(tHistory);
        return policy;
    }

    /**
     * Projection-method apprenticeship learning. Starting from a random policy, each
     * iteration projects the expert's feature expectations using the current
     * policy's expectations and the previous projection, derives weights and a
     * score from the difference, stops when the score is within the request's
     * epsilon, and otherwise plans under the induced reward and re-estimates the
     * new policy's feature expectations.
     *
     * <p>The score history is stored on the request via {@code setTHistory} on every exit path.
     *
     * @param apprenticeshipLearningRequest the learning request
     * @return the most recently computed policy
     */
    private static Policy projectionMethod(ApprenticeshipLearningRequest apprenticeshipLearningRequest) {
        // Rollouts are capped at the length of the longest expert episode.
        int maxExpertEpisodeLength = 0;
        List<Episode> expertEpisodes = apprenticeshipLearningRequest.getExpertEpisodes();
        for (Episode expertEpisode : expertEpisodes) {
            maxExpertEpisodeLength = Math.max(maxExpertEpisodeLength, expertEpisode.numTimeSteps());
        }

        Planner planner = apprenticeshipLearningRequest.getPlanner();
        HashableStateFactory hashingFactory = planner.getHashingFactory();
        DenseStateFeatures featureGenerator = apprenticeshipLearningRequest.getFeatureGenerator();
        double[] expertExpectations = estimateFeatureExpectation(expertEpisodes, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma()));
        SADomain domain = apprenticeshipLearningRequest.getDomain();

        // Declared as Policy because later iterations replace it with planner-derived policies.
        Policy policy = new StationaryRandomDistributionPolicy(domain);

        List<Episode> initialRollouts = new ArrayList<Episode>();
        for (int r = 0; r < apprenticeshipLearningRequest.getPolicyCount(); r++) {
            initialRollouts.add(PolicyUtils.rollout(policy, apprenticeshipLearningRequest.getStartStateGenerator().generateState(), domain.getModel(), maxExpertEpisodeLength));
        }
        double[] currentExpectations = estimateFeatureExpectation(initialRollouts, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma()));

        double[] previousProjection = null;
        int maxIterations = apprenticeshipLearningRequest.getMaxIterations();
        double[] tHistory = new double[maxIterations];
        int policyCount = apprenticeshipLearningRequest.getPolicyCount();

        for (int iteration = 0; iteration < maxIterations; iteration++) {
            // The first iteration has no previous projection, so the current expectations stand in for it.
            double[] projection = previousProjection == null
                    ? currentExpectations.clone()
                    : projectExpertFE(expertExpectations, currentExpectations, previousProjection);
            FeatureWeights featureWeights = getWeightsProjectionMethod(expertExpectations, projection);
            tHistory[iteration] = featureWeights.getScore().doubleValue();
            DPrint.cl(DEBUG_CODE_SCORE, "Score: " + tHistory[iteration]);
            previousProjection = projection;

            if (featureWeights.getScore().doubleValue() <= apprenticeshipLearningRequest.getEpsilon()) {
                // Record the history on convergence too, matching the max-margin method.
                apprenticeshipLearningRequest.setTHistory(tHistory);
                return policy;
            }

            for (int f = 0; f < featureWeights.weights.length; f++) {
                DPrint.c(DEBUG_CODE_RF_WEIGHTS, f + ": " + featureWeights.weights[f] + VectorFormat.DEFAULT_SEPARATOR);
            }
            DPrint.cl(DEBUG_CODE_RF_WEIGHTS, "");

            // Re-plan under the reward induced by the new weights.
            CustomRewardModel customRewardModel = new CustomRewardModel(domain.getModel(), generateRewardFunction(featureGenerator, featureWeights));
            planner.resetSolver();
            planner.solverInit(domain, apprenticeshipLearningRequest.getGamma(), hashingFactory);
            planner.setModel(customRewardModel);
            planner.planFromState(apprenticeshipLearningRequest.getStartStateGenerator().generateState());
            // A planner that is neither deterministic nor a Q-provider leaves the previous policy in place.
            if (planner instanceof DeterministicPlanner) {
                policy = new DDPlannerPolicy((DeterministicPlanner) planner);
            } else if (planner instanceof QProvider) {
                policy = new GreedyQPolicy((QProvider) planner);
            }

            List<Episode> rollouts = new ArrayList<Episode>();
            for (int r = 0; r < policyCount; r++) {
                rollouts.add(PolicyUtils.rollout(policy, apprenticeshipLearningRequest.getStartStateGenerator().generateState(), customRewardModel, maxExpertEpisodeLength));
            }
            currentExpectations = estimateFeatureExpectation(rollouts, featureGenerator, Double.valueOf(apprenticeshipLearningRequest.getGamma()));
        }
        apprenticeshipLearningRequest.setTHistory(tHistory);
        return policy;
    }

    /**
     * Sets up and solves an optimisation problem over {@code n + 1} variables, where
     * {@code n} is the number of features: the first {@code n} are the feature
     * weights and the last one is the score.
     *
     * <p>The objective has coefficient {@code -1} on the score variable and zero
     * elsewhere. Each entry of {@code list} contributes one linear constraint whose
     * coefficients are that entry minus the expert expectations, with coefficient
     * {@code 1} on the score variable and constant term {@code 1}. One further
     * quadratic constraint uses an identity matrix whose score-variable diagonal
     * entry is zeroed, together with a constant of {@code -0.5}.
     * NOTE(review): this presumably bounds the weight vector's norm by 1 while
     * maximising the score — confirm against the JOptimizer function definitions.
     *
     * @param dArr expert feature expectations
     * @param list feature expectations of the policies generated so far
     * @return the solved weights and score, or {@code null} if the optimiser threw
     */
    private static FeatureWeights solveFeatureWeights(double[] dArr, List<double[]> list) {
        final int featureCount = dArr.length;
        final int scoreIndex = featureCount;
        final int variableCount = featureCount + 1;

        double[] objectiveCoefficients = new double[variableCount];
        objectiveCoefficients[scoreIndex] = -1.0d;
        LinearMultivariateRealFunction objective = new LinearMultivariateRealFunction(objectiveCoefficients, 0.0d);

        List<ConvexMultivariateRealFunction> constraints = new ArrayList<ConvexMultivariateRealFunction>();
        for (double[] policyExpectations : list) {
            double[] coefficients = new double[variableCount];
            for (int f = 0; f < policyExpectations.length; f++) {
                coefficients[f] = policyExpectations[f] - dArr[f];
            }
            coefficients[scoreIndex] = 1.0d;
            constraints.add(new LinearMultivariateRealFunction(coefficients, 1.0d));
        }

        double[][] quadraticMatrix = Utils.createConstantDiagonalMatrix(variableCount, 1.0d);
        // The score variable takes no part in the quadratic term.
        quadraticMatrix[scoreIndex][scoreIndex] = 0.0d;
        constraints.add(new PSDQuadraticMultivariateRealFunction(quadraticMatrix, null, -0.5d));

        OptimizationRequest request = new OptimizationRequest();
        request.setF0(objective);
        request.setFi(constraints.toArray(new ConvexMultivariateRealFunction[constraints.size()]));
        request.setCheckKKTSolutionAccuracy(false);
        request.setTolerance(1.0E-12d);
        request.setToleranceFeas(1.0E-12d);

        JOptimizer optimizer = new JOptimizer();
        optimizer.setOptimizationRequest(request);
        try {
            optimizer.optimize();
            double[] solution = optimizer.getOptimizationResponse().getSolution();
            double[] solvedWeights = Arrays.copyOfRange(solution, 0, featureCount);
            double solvedScore = solution[scoreIndex];
            return new FeatureWeights(solvedWeights, solvedScore);
        } catch (Exception e) {
            // A failed solve is reported on stdout and signalled to the caller with null.
            System.out.println(e);
            return null;
        }
    }

    /**
     * Computes the point on the line through {@code dArr3} and {@code dArr2}
     * given by {@code dArr3 + s * (dArr2 - dArr3)}, where
     * {@code s = ((dArr2 - dArr3) . (dArr - dArr3)) / ((dArr2 - dArr3) . (dArr2 - dArr3))}.
     *
     * @param dArr  expert feature expectations
     * @param dArr2 feature expectations of the latest policy
     * @param dArr3 previous projection; its length determines the result length
     * @return a new array holding the projection; when {@code dArr2} equals
     *         {@code dArr3} (so there is no direction to project along) a copy of
     *         {@code dArr3} is returned instead of dividing by zero
     */
    private static double[] projectExpertFE(double[] dArr, double[] dArr2, double[] dArr3) {
        final int n = dArr3.length;
        double numerator = 0.0d;
        double denominator = 0.0d;
        for (int i = 0; i < n; i++) {
            double direction = dArr2[i] - dArr3[i];
            numerator += direction * (dArr[i] - dArr3[i]);
            denominator += direction * direction;
        }
        if (denominator == 0.0d) {
            // Zero-length direction: 0/0 would otherwise fill the result with NaN.
            return dArr3.clone();
        }
        final double scale = numerator / denominator;
        double[] projection = new double[n];
        for (int i = 0; i < n; i++) {
            projection[i] = dArr3[i] + ((dArr2[i] - dArr3[i]) * scale);
        }
        return projection;
    }

    /**
     * Derives weights as the element-wise difference {@code dArr - dArr2} and a
     * score equal to the Euclidean norm of that difference.
     *
     * @param dArr  expert feature expectations
     * @param dArr2 projected feature expectations; its length determines the weight count
     * @return the difference vector paired with its norm
     */
    private static FeatureWeights getWeightsProjectionMethod(double[] dArr, double[] dArr2) {
        final int n = dArr2.length;
        double[] difference = new double[n];
        double squaredNorm = 0.0d;
        for (int i = 0; i < n; i++) {
            double delta = dArr[i] - dArr2[i];
            difference[i] = delta;
            squaredNorm += delta * delta;
        }
        double norm = Math.sqrt(squaredNorm);
        return new FeatureWeights(difference, norm);
    }
}
