package burlap.behavior.singleagent.learning.lspi;

import burlap.behavior.functionapproximation.dense.DenseStateActionFeatures;
import burlap.behavior.functionapproximation.dense.DenseStateActionLinearVFA;
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.policy.PolicyUtils;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.learning.lspi.SARSCollector;
import burlap.behavior.singleagent.learning.lspi.SARSData;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.debugtools.DPrint;
import burlap.mdp.auxiliary.common.ConstantStateGenerator;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
import org.ejml.simple.SimpleMatrix;

/**
 * Least-squares policy iteration (LSPI) solver whose Q-function is a linear combination of
 * dense state-action features ({@link DenseStateActionLinearVFA}).
 *
 * <p>The solver maintains a dataset of (s, a, r, s') samples. Each call to {@link #LSTDQ()}
 * fits new weights to that dataset under the current greedy policy, and
 * {@link #runPolicyIteration(int, double)} repeats the fit until the weights stop changing.
 * The samples either come from a {@link SARSCollector} ({@link #planFromState(State)}) or from
 * episodes executed in an {@link Environment} ({@link #runLearningEpisode(Environment)}).
 */
public class LSPI extends MDPSolver implements QProvider, LearningAgent, Planner {

    /** Linear Q-value approximator; its parameters are overwritten by every {@link #LSTDQ()} call. */
    protected DenseStateActionLinearVFA vfa;

    /** The (s, a, r, s') samples LSTDQ is fit to; {@code null} until set, planned or learned. */
    protected SARSData dataset;

    /** Feature generator used by {@link #LSTDQ()} to turn state-action pairs into vectors. */
    protected DenseStateActionFeatures saFeatures;

    /** Weights returned by the most recent LSTDQ call made from policy iteration, or {@code null}. */
    protected SimpleMatrix lastWeights;

    /** Collector used by {@link #planFromState(State)}; a uniform random one is created lazily. */
    protected SARSCollector planningCollector;

    /** Policy followed while running learning episodes (epsilon-greedy, epsilon 0.1, by default). */
    protected Policy learningPolicy;

    /** Number of most recent learning episodes to retain; values below 1 retain a single episode. */
    protected int numEpisodesToStore;

    /** LSTDQ starts its running matrix as this scalar times the identity matrix. */
    protected double identityScalar = 100.0d;

    /** Number of samples requested from the collector by {@link #planFromState(State)}. */
    protected int numSamplesForPlanning = 10000;

    /** Weight-change threshold (Frobenius norm) at or below which policy iteration stops. */
    protected double maxChange = 1.0E-6d;

    /** Maximum number of policy iteration rounds per planning or learning update. */
    protected int maxNumPlanningIterations = 30;

    /** Step cap applied by {@link #runLearningEpisode(Environment)}; {@code -1} means no cap. */
    protected int maxLearningSteps = Integer.MAX_VALUE;

    /** Learning steps accumulated since policy iteration was last re-run. */
    protected int numStepsSinceLastLearningPI = 0;

    /** Policy iteration is re-run once more than this many new learning steps have accumulated. */
    protected int minNewStepsForLearningPI = 100;

    /** The retained learning episodes, oldest first. */
    protected LinkedList<Episode> episodeHistory = new LinkedList<>();

    /**
     * Feature vectors for one sample: the features of (s, a) and the features of (s', a'),
     * where a' is the action the evaluated policy selects in s'.
     *
     * <p>Note: the decompiler reported that this class was originally declared
     * {@code protected}; it is left {@code public} so existing references keep compiling.
     */
    public class SSFeatures {
        /** Features of the sampled state-action pair (s, a). */
        public double[] sActionFeatures;

        /** Features of the next state paired with the policy's action, (s', a'). */
        public double[] sPrimeActionFeatures;

        /**
         * @param dArr features of (s, a)
         * @param dArr2 features of (s', a')
         */
        public SSFeatures(double[] dArr, double[] dArr2) {
            this.sActionFeatures = dArr;
            this.sPrimeActionFeatures = dArr2;
        }
    }

    /**
     * Creates a solver with no initial dataset.
     *
     * @param sADomain the domain to solve
     * @param d the discount factor
     * @param denseStateActionFeatures the state-action feature generator
     */
    public LSPI(SADomain sADomain, double d, DenseStateActionFeatures denseStateActionFeatures) {
        this(sADomain, d, denseStateActionFeatures, null);
    }

    /**
     * Creates a solver that starts from an existing dataset.
     *
     * @param sADomain the domain to solve
     * @param d the discount factor
     * @param denseStateActionFeatures the state-action feature generator
     * @param sARSData the initial (s, a, r, s') samples; may be {@code null}
     */
    public LSPI(SADomain sADomain, double d, DenseStateActionFeatures denseStateActionFeatures, SARSData sARSData) {
        solverInit(sADomain, d, null);
        this.saFeatures = denseStateActionFeatures;
        this.vfa = new DenseStateActionLinearVFA(denseStateActionFeatures, 0.0d);
        this.learningPolicy = new EpsilonGreedy(this, 0.1d);
        this.dataset = sARSData;
    }

    /**
     * Sets how many samples {@link #planFromState(State)} collects.
     *
     * @param i the number of samples to collect
     */
    public void initializeForPlanning(int i) {
        this.numSamplesForPlanning = i;
    }

    /**
     * Sets how many samples {@link #planFromState(State)} collects and which collector gathers them.
     *
     * @param i the number of samples to collect
     * @param sARSCollector the collector to use
     */
    public void initializeForPlanning(int i, SARSCollector sARSCollector) {
        this.numSamplesForPlanning = i;
        this.planningCollector = sARSCollector;
    }

    /** @param sARSData the samples LSTDQ should be fit to */
    public void setDataset(SARSData sARSData) {
        this.dataset = sARSData;
    }

    /** @return the samples LSTDQ is fit to, or {@code null} if none exist yet */
    public SARSData getDataset() {
        return this.dataset;
    }

    /** @return the feature generator used by {@link #LSTDQ()} */
    public DenseStateActionFeatures getSaFeatures() {
        return this.saFeatures;
    }

    /**
     * Replaces the feature generator used by {@link #LSTDQ()}. Only the field is reassigned;
     * the existing value function approximator is not rebuilt.
     *
     * @param denseStateActionFeatures the new feature generator
     */
    public void setSaFeatures(DenseStateActionFeatures denseStateActionFeatures) {
        this.saFeatures = denseStateActionFeatures;
    }

    /** @return the scalar applied to the identity matrix that initialises LSTDQ */
    public double getIdentityScalar() {
        return this.identityScalar;
    }

    /** @param d the scalar applied to the identity matrix that initialises LSTDQ */
    public void setIdentityScalar(double d) {
        this.identityScalar = d;
    }

    /** @return the number of samples collected by {@link #planFromState(State)} */
    public int getNumSamplesForPlanning() {
        return this.numSamplesForPlanning;
    }

    /** @param i the number of samples collected by {@link #planFromState(State)} */
    public void setNumSamplesForPlanning(int i) {
        this.numSamplesForPlanning = i;
    }

    /** @return the collector used for planning, or {@code null} if none has been set or created */
    public SARSCollector getPlanningCollector() {
        return this.planningCollector;
    }

    /** @param sARSCollector the collector used for planning */
    public void setPlanningCollector(SARSCollector sARSCollector) {
        this.planningCollector = sARSCollector;
    }

    /** @return the maximum number of policy iteration rounds per update */
    public int getMaxNumPlanningIterations() {
        return this.maxNumPlanningIterations;
    }

    /** @param i the maximum number of policy iteration rounds per update */
    public void setMaxNumPlanningIterations(int i) {
        this.maxNumPlanningIterations = i;
    }

    /** @return the policy followed during learning episodes */
    public Policy getLearningPolicy() {
        return this.learningPolicy;
    }

    /** @param policy the policy to follow during learning episodes */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /** @return the step cap applied by {@link #runLearningEpisode(Environment)} */
    public int getMaxLearningSteps() {
        return this.maxLearningSteps;
    }

    /** @param i the step cap applied by {@link #runLearningEpisode(Environment)}; {@code -1} for none */
    public void setMaxLearningSteps(int i) {
        this.maxLearningSteps = i;
    }

    /** @return the number of new learning steps that must be exceeded before policy iteration re-runs */
    public int getMinNewStepsForLearningPI() {
        return this.minNewStepsForLearningPI;
    }

    /** @param i the number of new learning steps that must be exceeded before policy iteration re-runs */
    public void setMinNewStepsForLearningPI(int i) {
        this.minNewStepsForLearningPI = i;
    }

    /** @return the weight-change threshold at or below which policy iteration stops */
    public double getMaxChange() {
        return this.maxChange;
    }

    /** @param d the weight-change threshold at or below which policy iteration stops */
    public void setMaxChange(double d) {
        this.maxChange = d;
    }

    /**
     * Fits new weights to the dataset under the current greedy policy and installs them in a
     * fresh copy of the value function approximator.
     *
     * <p>With {@code phi} the features of (s, a), {@code phi'} the features of (s', greedy(s')),
     * and {@code B} initialised to {@code identityScalar * I}, every sample applies the
     * rank-one update
     * {@code B <- B - (B phi (phi - gamma phi')^T B) / (1 + (phi - gamma phi')^T B phi)}
     * and accumulates {@code b <- b + r phi}. The resulting weights are {@code B b}.
     *
     * @return the new weight vector as a column matrix
     * @throws IllegalStateException if no dataset is available
     */
    public SimpleMatrix LSTDQ() {
        if (this.dataset == null) {
            throw new IllegalStateException("LSPI cannot run LSTDQ because no SARS dataset has been set or collected.");
        }

        // The greedy policy reads Q-values from this solver, i.e. from the current weights.
        GreedyQPolicy greedy = new GreedyQPolicy(this);

        // First pass: compute all features and find the largest (s, a) feature dimension.
        List<SSFeatures> sampleFeatures = new ArrayList<>(this.dataset.size());
        int numFeatures = 0;
        for (SARSData.SARS sars : this.dataset.dataset) {
            double[] current = this.saFeatures.features(sars.s, sars.a);
            double[] next = this.saFeatures.features(sars.sp, greedy.action(sars.sp));
            sampleFeatures.add(new SSFeatures(current, next));
            numFeatures = Math.max(numFeatures, current.length);
        }

        SimpleMatrix B = SimpleMatrix.identity(numFeatures).scale(this.identityScalar);
        SimpleMatrix b = new SimpleMatrix(numFeatures, 1);

        // Second pass: sampleFeatures was filled in dataset order, so index k matches dataset.get(k).
        for (int k = 0; k < sampleFeatures.size(); k++) {
            SSFeatures features = sampleFeatures.get(k);
            SimpleMatrix phi = phiConstructor(features.sActionFeatures, numFeatures);
            SimpleMatrix phiPrime = phiConstructor(features.sPrimeActionFeatures, numFeatures);
            double reward = this.dataset.get(k).r;

            // (phi - gamma * phi')^T is needed by both numerator and denominator; compute it once.
            SimpleMatrix differenceT = phi.minus(phiPrime.scale(this.gamma)).transpose();
            SimpleMatrix numerator = B.mult(phi).mult(differenceT).mult(B);
            double denominator = differenceT.mult(B).mult(phi).get(0) + 1.0d;

            B = B.minus(numerator.scale(1.0d / denominator));
            b = b.plus(phi.scale(reward));
        }

        SimpleMatrix weights = B.mult(b);

        // Write into a copy so previously handed-out references to the old approximator are untouched.
        this.vfa = this.vfa.copy();
        for (int p = 0; p < numFeatures; p++) {
            this.vfa.setParameter(p, weights.get(p, 0));
        }
        return weights;
    }

    /**
     * Repeatedly runs {@link #LSTDQ()} until the Frobenius norm of the weight change is at most
     * {@code d} or {@code i} rounds have been performed. The first round after construction or
     * reset has no previous weights to compare with and therefore never terminates the loop.
     *
     * @param i the maximum number of rounds
     * @param d the weight-change threshold at or below which iteration stops
     * @return a greedy policy over this solver's Q-values
     */
    public GreedyQPolicy runPolicyIteration(int i, double d) {
        for (int iteration = 0; iteration < i; iteration++) {
            SimpleMatrix weights = LSTDQ();

            double change = Double.POSITIVE_INFINITY;
            boolean converged = false;
            if (this.lastWeights != null) {
                change = this.lastWeights.minus(weights).normF();
                converged = change <= d;
            }
            this.lastWeights = weights;

            DPrint.cl(0, "Finished iteration: " + iteration + ". Weight change: " + change);
            if (converged) {
                break;
            }
        }
        DPrint.cl(0, "Finished Policy Iteration.");
        return new GreedyQPolicy(this);
    }

    /**
     * Wraps a feature array in an {@code i x 1} column matrix.
     *
     * @param dArr the feature values
     * @param i the number of rows of the resulting column vector
     * @return the column vector
     */
    protected SimpleMatrix phiConstructor(double[] dArr, int i) {
        return new SimpleMatrix(i, 1, true, dArr);
    }

    /**
     * Returns the approximated Q-value of every action applicable in the given state.
     *
     * @param state the state to evaluate
     * @return one {@link QValue} per applicable action
     */
    @Override // burlap.behavior.valuefunction.QProvider
    public List<QValue> qValues(State state) {
        List<Action> actions = applicableActions(state);
        List<QValue> result = new ArrayList<>(actions.size());
        for (Action action : actions) {
            result.add(new QValue(state, action, this.vfa.evaluate(state, action)));
        }
        return result;
    }

    /**
     * @param state the state to evaluate
     * @param action the action to evaluate
     * @return the approximated Q-value of the pair
     */
    @Override // burlap.behavior.valuefunction.QFunction
    public double qValue(State state, Action action) {
        return this.vfa.evaluate(state, action);
    }

    /**
     * @param state the state to evaluate
     * @return the maximum Q-value over the state's actions, as computed by {@code QProvider.Helper.maxQ}
     */
    @Override // burlap.behavior.valuefunction.ValueFunction
    public double value(State state) {
        return QProvider.Helper.maxQ(this, state);
    }

    /**
     * Collects {@link #getNumSamplesForPlanning()} samples starting from the given state using
     * the model, adds them to the dataset, and runs policy iteration. A uniform random
     * collector is created if none has been configured.
     *
     * @param state the state sample collection starts from
     * @return a greedy policy over the resulting Q-values
     * @throws RuntimeException if this solver has no model
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public GreedyQPolicy planFromState(State state) {
        if (this.model == null) {
            throw new RuntimeException("LSPI cannot execute planFromState because this solver has no model to sample transitions from. A model must be available before planning.");
        }
        if (this.planningCollector == null) {
            this.planningCollector = new SARSCollector.UniformRandomSARSCollector(this.actionTypes);
        }
        // NOTE(review): Integer.MAX_VALUE is presumably the per-episode step limit of the collector - confirm.
        this.dataset = this.planningCollector.collectNInstances(new ConstantStateGenerator(state), this.model, this.numSamplesForPlanning, Integer.MAX_VALUE, this.dataset);
        return runPolicyIteration(this.maxNumPlanningIterations, this.maxChange);
    }

    /**
     * Discards the collected samples and learned weights so the solver can start over.
     * Safe to call before any dataset exists. Stored learning episodes are left untouched.
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void resetSolver() {
        if (this.dataset != null) {
            this.dataset.clear();
        }
        this.vfa.resetParameters();
        // Drop convergence bookkeeping as well so the next run does not compare against stale weights.
        this.lastWeights = null;
        this.numStepsSinceLastLearningPI = 0;
    }

    /**
     * Runs one learning episode capped at {@link #getMaxLearningSteps()} steps.
     *
     * @param environment the environment to act in
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public Episode runLearningEpisode(Environment environment) {
        return runLearningEpisode(environment, this.maxLearningSteps);
    }

    /**
     * Follows the learning policy for one episode, adds its transitions to the dataset, and
     * re-runs policy iteration once enough new steps have accumulated.
     *
     * @param environment the environment to act in
     * @param i the maximum number of steps, or {@code -1} for no limit
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public Episode runLearningEpisode(Environment environment, int i) {
        Episode episode;
        if (i != -1) {
            episode = PolicyUtils.rollout(this.learningPolicy, environment, i);
        } else {
            episode = PolicyUtils.rollout(this.learningPolicy, environment);
        }

        updateDatasetWithLearningEpisode(episode);

        if (shouldRereunPolicyIteration(episode)) {
            runPolicyIteration(this.maxNumPlanningIterations, this.maxChange);
            this.numStepsSinceLastLearningPI = 0;
        } else {
            this.numStepsSinceLastLearningPI += episode.numTimeSteps() - 1;
        }

        // Evict until there is room, so lowering numEpisodesToStore actually shrinks the history.
        while (!this.episodeHistory.isEmpty() && this.episodeHistory.size() >= this.numEpisodesToStore) {
            this.episodeHistory.poll();
        }
        this.episodeHistory.offer(episode);
        return episode;
    }

    /**
     * Appends every transition of the episode to the dataset, creating the dataset on first use.
     * The reward at time step {@code t + 1} is paired with the state and action at step {@code t}.
     *
     * @param episode the episode whose transitions are added
     */
    protected void updateDatasetWithLearningEpisode(Episode episode) {
        int numTransitions = episode.numTimeSteps() - 1;
        if (this.dataset == null) {
            this.dataset = new SARSData(numTransitions);
        }
        for (int t = 0; t < numTransitions; t++) {
            this.dataset.add(episode.state(t), episode.action(t), episode.reward(t + 1), episode.state(t + 1));
        }
    }

    /**
     * Decides whether policy iteration should be re-run after the given episode.
     *
     * @param episode the episode that was just completed
     * @return {@code true} if the steps accumulated since the last run, plus this episode's
     *         steps, exceed {@link #getMinNewStepsForLearningPI()}
     */
    protected boolean shouldRereunPolicyIteration(Episode episode) {
        int stepsInEpisode = episode.numTimeSteps() - 1;
        return this.numStepsSinceLastLearningPI + stepsInEpisode > this.minNewStepsForLearningPI;
    }

    /**
     * @return the most recently stored learning episode
     * @throws java.util.NoSuchElementException if no learning episode has been stored yet
     */
    public Episode getLastLearningEpisode() {
        return this.episodeHistory.getLast();
    }

    /**
     * Sets how many of the most recent learning episodes are retained. The history is trimmed
     * the next time a learning episode is run; values below 1 retain a single episode.
     *
     * @param i the number of episodes to retain
     */
    public void setNumEpisodesToStore(int i) {
        this.numEpisodesToStore = i;
    }

    /**
     * @return the live list of stored learning episodes, oldest first; it is not a copy, so
     *         modifications affect this solver
     */
    public List<Episode> getAllStoredLearningEpisodes() {
        return this.episodeHistory;
    }
}
