package burlap.behavior.singleagent.learning.tdmethods.vfa;

import burlap.behavior.functionapproximation.DifferentiableStateActionValue;
import burlap.behavior.functionapproximation.FunctionGradient;
import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.singleagent.options.EnvironmentOptionOutcome;
import burlap.datastructures.HashedAggregator;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/vfa/GradientDescentQLearning.class */
/**
 * Q-learning that trains a {@link DifferentiableStateActionValue} approximator by batch gradient
 * descent on the temporal-difference (TD) error.
 *
 * <p>Each call to {@link #updateQFunction(List)} performs a single parameter update. First, the TD
 * error of every sample is multiplied by each non-zero partial derivative of the approximator at
 * that sample's (observation, action), and the products are summed per parameter. Then each sum is
 * averaged over the batch size, scaled by the learning rate polled for that parameter, and added to
 * the parameter's current value.
 */
public class GradientDescentQLearning extends ApproximateQLearning {

    /** Learning rate, polled once per updated parameter on every call to {@link #updateQFunction}. */
    protected LearningRate learningRate;

    /**
     * Creates a learner that uses a constant learning rate.
     *
     * @param domain the single-agent domain, forwarded to the superclass constructor
     * @param gamma the discount factor forwarded to the superclass constructor (presumably the
     *     {@code gamma} used for non-option samples in {@link #updateQFunction} — confirm)
     * @param vfa the differentiable state-action value function approximator to train
     * @param alpha the constant learning rate, wrapped in a {@link ConstantLR}
     */
    public GradientDescentQLearning(SADomain domain, double gamma, DifferentiableStateActionValue vfa, double alpha) {
        super(domain, gamma, vfa);
        this.learningRate = new ConstantLR(Double.valueOf(alpha));
    }

    /**
     * Creates a learner with an arbitrary learning-rate schedule.
     *
     * @param domain the single-agent domain, forwarded to the superclass constructor
     * @param gamma the discount factor forwarded to the superclass constructor (presumably the
     *     {@code gamma} used for non-option samples in {@link #updateQFunction} — confirm)
     * @param vfa the differentiable state-action value function approximator to train
     * @param learningRate the learning rate, polled per parameter on every update
     */
    public GradientDescentQLearning(SADomain domain, double gamma, DifferentiableStateActionValue vfa, LearningRate learningRate) {
        super(domain, gamma, vfa);
        this.learningRate = learningRate;
    }

    /**
     * Returns the learning rate used for parameter updates.
     *
     * @return the current learning rate
     */
    public LearningRate getLearningRate() {
        return this.learningRate;
    }

    /**
     * Replaces the learning rate used for subsequent parameter updates.
     *
     * @param learningRate the new learning rate
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Performs one batch gradient-descent step on the approximator's parameters.
     *
     * <p>For every sample, the TD error (see {@link #tdError}) is multiplied by each non-zero
     * partial derivative of {@code Q(o, a)}, and the products are accumulated per parameter id.
     * Each accumulated value is then divided by the batch size, scaled by the learning rate polled
     * for that parameter at {@code totalSteps}, and added to the parameter.
     *
     * @param samples the batch of transitions to learn from; an empty batch leaves all parameters
     *     unchanged
     */
    @Override // burlap.behavior.singleagent.learning.tdmethods.vfa.ApproximateQLearning
    public void updateQFunction(List<EnvironmentOutcome> samples) {
        if (samples.isEmpty()) {
            // Nothing to learn from; this also avoids computing a 1.0 / 0 batch scale.
            return;
        }

        // The constructors only ever install a DifferentiableStateActionValue, so this cast is
        // expected to succeed.
        DifferentiableStateActionValue differentiableVfa = (DifferentiableStateActionValue) this.vfa;

        // Sum of (partial derivative * TD error) over the batch, keyed by parameter id.
        HashedAggregator<Integer> gradientSums = new HashedAggregator<>();
        for (EnvironmentOutcome sample : samples) {
            double error = this.tdError(sample);
            FunctionGradient gradient = differentiableVfa.gradient(sample.o, sample.a);
            for (FunctionGradient.PartialDerivative pd : gradient.getNonZeroPartialDerivatives()) {
                gradientSums.add(pd.parameterId, pd.value * error);
            }
        }

        // Average the accumulated gradient over the batch before applying it.
        double batchScale = 1.0d / samples.size();
        for (Map.Entry<Integer, Double> entry : gradientSums.entrySet()) {
            int parameterId = entry.getKey();
            double current = this.vfa.getParameter(parameterId);
            double alpha = this.learningRate.pollLearningRate(this.totalSteps, parameterId);
            this.vfa.setParameter(parameterId, current + (alpha * batchScale * entry.getValue()));
        }
    }

    /**
     * Computes the TD error {@code r + discount * nextValue - Q(o, a)} of a single sample against
     * the current approximator.
     *
     * <p>{@code nextValue} is {@code 0} for terminated transitions and
     * {@code staleValue(sample.op)} otherwise. For an {@link EnvironmentOptionOutcome}, the
     * discount stored on the outcome is used instead of {@code gamma} (presumably
     * {@code gamma^k} for a k-step option — confirm).
     *
     * @param sample the transition to evaluate
     * @return the TD error of the sample
     */
    private double tdError(EnvironmentOutcome sample) {
        double discount = sample instanceof EnvironmentOptionOutcome
                ? ((EnvironmentOptionOutcome) sample).discount
                : this.gamma;
        // Evaluated before Q(o, a) to keep the original call order.
        double nextValue = sample.terminated ? 0.0d : staleValue(sample.op);
        return (sample.r + (discount * nextValue)) - this.vfa.evaluate(sample.o, sample.a);
    }
}
