package burlap.behavior.singleagent.learning.tdmethods.vfa;

import burlap.behavior.functionapproximation.DifferentiableStateActionValue;
import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.options.EnvironmentOptionOutcome;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.environment.SimulatedEnvironment;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/vfa/GradientDescentSarsaLam.class */
/**
 * On-policy temporal-difference learner with eligibility traces (SARSA(lambda)) whose
 * Q-function is a {@link DifferentiableStateActionValue} trained by gradient descent.
 *
 * <p>Each step of a learning episode evaluates the current state-action pair, executes the
 * action in the environment, selects the next action with the learning policy, and moves
 * every traced parameter along its eligibility trace scaled by the TD error and the polled
 * learning rate. The class can also be used as a planner: {@link #planFromState(State)}
 * runs learning episodes in a {@link SimulatedEnvironment}.
 *
 * <p>Instances hold mutable per-episode state and are not written for concurrent use.
 */
public class GradientDescentSarsaLam extends MDPSolver implements QProvider, LearningAgent, Planner {

    /** The differentiable Q-function approximator whose parameters are updated. */
    protected DifferentiableStateActionValue vfa;

    /** Source of the step size polled on every update. */
    protected LearningRate learningRate;

    /** Policy used to choose actions while learning. */
    protected Policy learningPolicy;

    /** Trace decay factor; each trace is multiplied by {@code lambda * discount} per step. */
    protected double lambda;

    /** Step cap supplied at construction; applied to episodes run by {@link #planFromState(State)}. */
    protected int maxEpisodeSize;

    /** Number of environment steps counted in the most recent learning episode. */
    protected int eStepCounter;

    /** Maximum number of episodes {@link #planFromState(State)} will run. */
    protected int numEpisodesForPlanning;

    /** Planning stops once an episode's largest weight change is not above this value. */
    protected double maxWeightChangeForPlanningTermination;

    /** Largest absolute weight change observed for any traced parameter in the last episode. */
    protected double maxWeightChangeInLastEpisode = Double.POSITIVE_INFINITY;

    /** When true the learning rate is polled per parameter id; otherwise once per state-action pair. */
    protected boolean useFeatureWiseLearningRate = true;

    /** Traces whose magnitude falls below this threshold are dropped. */
    protected double minEligibityForUpdate = 0.01d;

    /** When true, traces of all features active for any applicable action are zeroed before accumulation. */
    protected boolean useReplacingTraces = false;

    /** When true, option executions are recorded in the episode as their constituent steps. */
    protected boolean shouldDecomposeOptions = true;

    /** Running count of loop iterations across all episodes; passed to the learning rate as the time index. */
    protected int totalNumberOfSteps = 0;

    /**
     * Eligibility trace bookkeeping for a single function-approximator parameter.
     */
    public static class EligibilityTraceVector {

        /** Id of the parameter (weight) this trace belongs to. */
        public int weight;

        /** Current eligibility value of the parameter. */
        public double eligibilityValue;

        /** Value the parameter had when the trace was created. */
        public double initialWeightValue;

        /**
         * @param weight             id of the traced parameter
         * @param initialWeightValue parameter value at trace creation
         * @param eligibilityValue   starting eligibility value
         */
        public EligibilityTraceVector(int weight, double initialWeightValue, double eligibilityValue) {
            this.weight = weight;
            this.eligibilityValue = eligibilityValue;
            this.initialWeightValue = initialWeightValue;
        }
    }

    /**
     * Creates a learner with a 0.1 epsilon-greedy learning policy and a step cap of
     * {@link Integer#MAX_VALUE}.
     *
     * @param domain       the domain to learn in
     * @param gamma        the discount factor
     * @param vfa          the Q-function approximator to train
     * @param learningRate the constant learning rate
     * @param lambda       the trace decay factor
     */
    public GradientDescentSarsaLam(SADomain domain, double gamma, DifferentiableStateActionValue vfa, double learningRate, double lambda) {
        GDSLInit(domain, gamma, vfa, learningRate, new EpsilonGreedy(this, 0.1d), Integer.MAX_VALUE, lambda);
    }

    /**
     * Creates a learner with a 0.1 epsilon-greedy learning policy.
     *
     * @param domain         the domain to learn in
     * @param gamma          the discount factor
     * @param vfa            the Q-function approximator to train
     * @param learningRate   the constant learning rate
     * @param maxEpisodeSize the step cap stored in {@link #maxEpisodeSize}
     * @param lambda         the trace decay factor
     */
    public GradientDescentSarsaLam(SADomain domain, double gamma, DifferentiableStateActionValue vfa, double learningRate, int maxEpisodeSize, double lambda) {
        GDSLInit(domain, gamma, vfa, learningRate, new EpsilonGreedy(this, 0.1d), maxEpisodeSize, lambda);
    }

    /**
     * Creates a learner with an explicit learning policy.
     *
     * @param domain         the domain to learn in
     * @param gamma          the discount factor
     * @param vfa            the Q-function approximator to train
     * @param learningRate   the constant learning rate
     * @param learningPolicy the policy used to pick actions while learning
     * @param maxEpisodeSize the step cap stored in {@link #maxEpisodeSize}
     * @param lambda         the trace decay factor
     */
    public GradientDescentSarsaLam(SADomain domain, double gamma, DifferentiableStateActionValue vfa, double learningRate, Policy learningPolicy, int maxEpisodeSize, double lambda) {
        GDSLInit(domain, gamma, vfa, learningRate, learningPolicy, maxEpisodeSize, lambda);
    }

    /**
     * Shared constructor body. Initializes the solver base (third argument of
     * {@code solverInit} is passed as {@code null}), wraps the learning rate in a
     * {@link ConstantLR}, and defaults planning to a single episode with a zero
     * weight-change termination threshold.
     *
     * @param domain         the domain to learn in
     * @param gamma          the discount factor
     * @param vfa            the Q-function approximator to train
     * @param learningRate   the constant learning rate
     * @param learningPolicy the policy used to pick actions while learning
     * @param maxEpisodeSize the step cap stored in {@link #maxEpisodeSize}
     * @param lambda         the trace decay factor
     */
    protected void GDSLInit(SADomain domain, double gamma, DifferentiableStateActionValue vfa, double learningRate, Policy learningPolicy, int maxEpisodeSize, double lambda) {
        solverInit(domain, gamma, null);
        this.vfa = vfa;
        this.learningRate = new ConstantLR(Double.valueOf(learningRate));
        this.learningPolicy = learningPolicy;
        this.maxEpisodeSize = maxEpisodeSize;
        this.lambda = lambda;
        this.numEpisodesForPlanning = 1;
        this.maxWeightChangeForPlanningTermination = 0.0d;
    }

    /**
     * Sets the number of planning episodes without validation.
     *
     * @param numEpisodesForPlanning the number of episodes {@link #planFromState(State)} may run
     */
    public void initializeForPlanning(int numEpisodesForPlanning) {
        this.numEpisodesForPlanning = numEpisodesForPlanning;
    }

    /**
     * @param learningRate the learning rate source to poll on each update
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * @param useFeatureWiseLearningRate true to poll the learning rate per parameter id,
     *                                   false to poll it once per step with the state-action pair
     */
    public void setUseFeatureWiseLearningRate(boolean useFeatureWiseLearningRate) {
        this.useFeatureWiseLearningRate = useFeatureWiseLearningRate;
    }

    /**
     * @param learningPolicy the policy used to pick actions while learning
     */
    public void setLearningPolicy(Policy learningPolicy) {
        this.learningPolicy = learningPolicy;
    }

    /**
     * Sets the maximum number of planning episodes; non-positive values are replaced by 1.
     *
     * @param maxEpisodes the requested maximum
     */
    public void setMaximumEpisodesForPlanning(int maxEpisodes) {
        this.numEpisodesForPlanning = maxEpisodes > 0 ? maxEpisodes : 1;
    }

    /**
     * Sets the weight-change threshold at or below which planning stops; non-positive
     * values are replaced by 0.
     *
     * @param maxChange the requested threshold
     */
    public void setMaxVFAWeightChangeForPlanningTerminaiton(double maxChange) {
        this.maxWeightChangeForPlanningTermination = maxChange > 0.0d ? maxChange : 0.0d;
    }

    /**
     * @return the number of steps counted in the most recent learning episode
     */
    public int getLastNumSteps() {
        return this.eStepCounter;
    }

    /**
     * @param useReplacingTraces true to use replacing traces, false for accumulating traces
     */
    public void setUseReplaceTraces(boolean useReplacingTraces) {
        this.useReplacingTraces = useReplacingTraces;
    }

    /**
     * @param shouldDecomposeOptions true to record an executed option as its merged sub-episode,
     *                               false to record it as a single transition
     */
    public void toggleShouldDecomposeOption(boolean shouldDecomposeOptions) {
        this.shouldDecomposeOptions = shouldDecomposeOptions;
    }

    /**
     * Runs a learning episode with no step limit.
     *
     * @param env the environment to act in
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public Episode runLearningEpisode(Environment env) {
        return runLearningEpisode(env, -1);
    }

    /**
     * Runs one learning episode, updating the approximator's parameters after every action.
     * The episode ends when the environment reports a terminal state or the step counter
     * reaches {@code maxSteps}.
     *
     * @param env      the environment to act in
     * @param maxSteps the step limit, or -1 for no limit
     * @return the recorded episode
     */
    @Override // burlap.behavior.singleagent.learning.LearningAgent
    public Episode runLearningEpisode(Environment env, int maxSteps) {
        State initialState = env.currentObservation();
        Episode episode = new Episode(initialState);
        this.maxWeightChangeInLastEpisode = 0.0d;
        this.eStepCounter = 0;

        // Traces are local to the episode, keyed by parameter id.
        Map<Integer, EligibilityTraceVector> traces = new HashMap<Integer, EligibilityTraceVector>();

        State curState = initialState;
        Action curAction = this.learningPolicy.action(curState);

        while (!env.isInTerminalState() && (this.eStepCounter < maxSteps || maxSteps == -1)) {
            // Value and gradient are taken before acting, so they reflect the pre-update weights.
            double curQ = this.vfa.evaluate(curState, curAction);
            FunctionGradient gradient = this.vfa.gradient(curState, curAction);

            EnvironmentOutcome outcome = curAction instanceof Option
                    ? ((Option) curAction).control(env, this.gamma)
                    : env.executeAction(curAction);

            State nextState = outcome.op;
            Action nextAction = this.learningPolicy.action(nextState);
            double nextQ = outcome.terminated ? 0.0d : this.vfa.evaluate(nextState, nextAction);

            // Option outcomes carry their own cumulative discount and step count.
            double reward = outcome.r;
            double discount = this.gamma;
            int stepIncrement = 1;
            if (outcome instanceof EnvironmentOptionOutcome) {
                EnvironmentOptionOutcome optionOutcome = (EnvironmentOptionOutcome) outcome;
                discount = optionOutcome.discount;
                stepIncrement = optionOutcome.numSteps();
            }
            this.eStepCounter += stepIncrement;

            if (curAction instanceof Option && this.shouldDecomposeOptions) {
                episode.appendAndMergeEpisodeAnalysis(((EnvironmentOptionOutcome) outcome).episode);
            } else {
                episode.transition(curAction, nextState, reward);
            }

            double tdError = (reward + (discount * nextQ)) - curQ;

            prepareTraces(traces, curState, gradient);
            updateWeightsAndDecayTraces(traces, gradient, tdError, discount, curState, curAction);

            curState = nextState;
            curAction = nextAction;
            this.totalNumberOfSteps++;
        }
        return episode;
    }

    /**
     * Ensures a trace exists for every parameter about to be updated. With replacing traces,
     * the traces of all parameters with a non-zero gradient for any applicable action in
     * {@code state} are zeroed (or created at zero); otherwise only missing traces for the
     * selected action's gradient are created.
     */
    private void prepareTraces(Map<Integer, EligibilityTraceVector> traces, State state, FunctionGradient gradient) {
        if (this.useReplacingTraces) {
            for (Action candidate : applicableActions(state)) {
                // Result unused; the evaluate-then-gradient call order is kept as-is.
                this.vfa.evaluate(state, candidate);
                FunctionGradient candidateGradient = this.vfa.gradient(state, candidate);
                for (FunctionGradient.PartialDerivative pd : candidateGradient.getNonZeroPartialDerivatives()) {
                    EligibilityTraceVector trace = traces.get(pd.parameterId);
                    if (trace != null) {
                        trace.eligibilityValue = 0.0d;
                    } else {
                        traces.put(pd.parameterId, new EligibilityTraceVector(pd.parameterId, this.vfa.getParameter(pd.parameterId), 0.0d));
                    }
                }
            }
            return;
        }
        for (FunctionGradient.PartialDerivative pd : gradient.getNonZeroPartialDerivatives()) {
            if (!traces.containsKey(pd.parameterId)) {
                traces.put(pd.parameterId, new EligibilityTraceVector(pd.parameterId, this.vfa.getParameter(pd.parameterId), 0.0d));
            }
        }
    }

    /**
     * Accumulates the gradient into each trace, applies the TD update to the corresponding
     * parameter, decays the trace by {@code lambda * discount}, and removes traces whose
     * magnitude has fallen below {@link #minEligibityForUpdate}.
     */
    private void updateWeightsAndDecayTraces(Map<Integer, EligibilityTraceVector> traces, FunctionGradient gradient,
                                             double tdError, double discount, State state, Action action) {
        double stepSize = 0.0d;
        if (!this.useFeatureWiseLearningRate) {
            stepSize = this.learningRate.pollLearningRate(this.totalNumberOfSteps, state, action);
        }

        Iterator<EligibilityTraceVector> it = traces.values().iterator();
        while (it.hasNext()) {
            EligibilityTraceVector trace = it.next();
            if (this.useFeatureWiseLearningRate) {
                stepSize = this.learningRate.pollLearningRate(this.totalNumberOfSteps, trace.weight);
            }

            trace.eligibilityValue += gradient.getPartialDerivative(trace.weight);
            double updated = this.vfa.getParameter(trace.weight) + (stepSize * tdError * trace.eligibilityValue);
            this.vfa.setParameter(trace.weight, updated);

            double change = Math.abs(trace.initialWeightValue - updated);
            if (change > this.maxWeightChangeInLastEpisode) {
                this.maxWeightChangeInLastEpisode = change;
            }

            trace.eligibilityValue *= this.lambda * discount;
            // Compare magnitudes: a signed comparison would discard every negative trace,
            // which signed gradients produce, after a single step.
            if (Math.abs(trace.eligibilityValue) < this.minEligibityForUpdate) {
                it.remove();
            }
        }
    }

    /**
     * Returns the approximated Q-value of every action applicable in {@code state}.
     *
     * @param state the state to query
     * @return one {@link QValue} per applicable action
     */
    @Override // burlap.behavior.valuefunction.QProvider
    public List<QValue> qValues(State state) {
        List<Action> actions = applicableActions(state);
        List<QValue> result = new ArrayList<QValue>(actions.size());
        for (Action action : actions) {
            result.add(new QValue(state, action, this.vfa.evaluate(state, action)));
        }
        return result;
    }

    /**
     * @param state  the state to query
     * @param action the action to query
     * @return the approximator's value for the pair
     */
    @Override // burlap.behavior.valuefunction.QFunction
    public double qValue(State state, Action action) {
        return this.vfa.evaluate(state, action);
    }

    /**
     * @param state the state to query
     * @return the value computed by {@code QProvider.Helper.maxQ} for this provider and state
     */
    @Override // burlap.behavior.valuefunction.ValueFunction
    public double value(State state) {
        return QProvider.Helper.maxQ(this, state);
    }

    /**
     * Plans by running learning episodes in a simulated environment that starts from
     * {@code initialState}. At least one episode is run; further episodes run until
     * {@link #numEpisodesForPlanning} is reached or the last episode's largest weight
     * change is not above {@link #maxWeightChangeForPlanningTermination}. The environment
     * is reset before every episode after the first, and each episode is capped at
     * {@link #maxEpisodeSize} steps when that value is positive.
     *
     * @param initialState the state planning episodes start from
     * @return a greedy policy over this solver's Q-values
     * @throws RuntimeException if no model is available
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public GreedyQPolicy planFromState(State initialState) {
        if (this.model == null) {
            throw new RuntimeException("Planning requires a model, but none is provided.");
        }
        SimulatedEnvironment env = new SimulatedEnvironment(this.domain, initialState);
        // Non-positive caps keep the previous unbounded behaviour.
        int stepCap = this.maxEpisodeSize > 0 ? this.maxEpisodeSize : -1;
        int episodesRun = 0;
        do {
            if (episodesRun > 0) {
                // Without a reset the environment stays where the previous episode ended,
                // so follow-up episodes would terminate immediately.
                env.resetEnvironment();
            }
            runLearningEpisode(env, stepCap);
            episodesRun++;
        } while (episodesRun < this.numEpisodesForPlanning
                && this.maxWeightChangeInLastEpisode > this.maxWeightChangeForPlanningTermination);
        return new GreedyQPolicy(this);
    }

    /**
     * Resets the approximator's parameters and this solver's episode and step counters.
     */
    @Override // burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void resetSolver() {
        this.vfa.resetParameters();
        this.eStepCounter = 0;
        // Restart the time index handed to the learning rate so schedules begin afresh.
        this.totalNumberOfSteps = 0;
        this.maxWeightChangeInLastEpisode = Double.POSITIVE_INFINITY;
    }
}
