package burlap.behavior.singleagent.learning.tdmethods.vfa;

import burlap.behavior.functionapproximation.ParametricFunction;
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.learning.experiencereplay.ExperienceMemory;
import burlap.behavior.singleagent.learning.experiencereplay.FixedSizeMemory;
import burlap.behavior.singleagent.options.EnvironmentOptionOutcome;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.auxiliary.StateMapping;
import burlap.mdp.auxiliary.common.ShallowIdentityStateMapping;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/vfa/ApproximateQLearning.class */
/**
 * Abstract base for Q-learning agents whose Q-function is a
 * {@link ParametricFunction.ParametricStateActionFunction}.
 *
 * <p>Each learning step executes the action chosen by the learning policy, stores the outcome in
 * an {@link ExperienceMemory}, samples {@code numReplay} experiences from that memory and passes
 * them to {@link #updateQFunction(List)}, which subclasses implement. Optionally, a "stale" copy
 * of the function is kept as a separate target and refreshed every {@code staleDuration} steps
 * (see {@link #useStaleTarget(int)}).
 *
 * <p>The public query methods ({@link #qValues(State)}, {@link #qValue(State, Action)},
 * {@link #value(State)}, {@link #getStaleQs(State)}, {@link #getStaleQ(State, Action)},
 * {@link #staleValue(State)}) each apply the {@link StateMapping} to their input once.
 */
public abstract class ApproximateQLearning extends MDPSolver implements LearningAgent, QProvider {
    /** The function being learned. */
    protected ParametricFunction.ParametricStateActionFunction vfa;
    /**
     * Target function used by stale queries. Aliases {@link #vfa} when {@code staleDuration <= 1};
     * otherwise a periodically refreshed copy of it.
     */
    protected ParametricFunction.ParametricStateActionFunction staleVfa;
    /** Number of steps between refreshes of {@link #staleVfa}; values {@code <= 1} disable copying. */
    protected int staleDuration;
    /** Steps taken since {@link #staleVfa} was last refreshed. */
    protected int stepsSinceStale;
    /** Experience store that learning updates are sampled from. */
    protected ExperienceMemory memory;
    /** Mapping applied to states before they are evaluated by the function approximator. */
    protected StateMapping stateMapping;
    /** Number of experiences sampled from {@link #memory} for each update. */
    protected int numReplay;
    /** Policy used to select actions while learning. */
    protected Policy learningPolicy;
    /** Primitive steps taken across all learning episodes (options count all their steps). */
    protected int totalSteps;
    /** Number of completed learning episodes. */
    protected int totalEpisodes;

    /**
     * Creates a learner that uses a {@link ShallowIdentityStateMapping}.
     *
     * @param sADomain the domain to learn in
     * @param d value forwarded to {@code solverInit} (presumably the discount factor — confirm)
     * @param parametricStateActionFunction the function approximator to learn
     */
    public ApproximateQLearning(SADomain sADomain, double d, ParametricFunction.ParametricStateActionFunction parametricStateActionFunction) {
        this(sADomain, d, parametricStateActionFunction, new ShallowIdentityStateMapping());
    }

    /**
     * Creates a learner. By default: no stale target (duration 1), a {@code FixedSizeMemory} of
     * size 1, one replayed experience per step, and an epsilon-greedy (0.1) learning policy.
     *
     * @param sADomain the domain to learn in
     * @param d value forwarded to {@code solverInit} (presumably the discount factor — confirm)
     * @param parametricStateActionFunction the function approximator to learn
     * @param stateMapping mapping applied to states before evaluation
     */
    public ApproximateQLearning(SADomain sADomain, double d, ParametricFunction.ParametricStateActionFunction parametricStateActionFunction, StateMapping stateMapping) {
        this.staleDuration = 1;
        this.stepsSinceStale = 0;
        this.memory = new FixedSizeMemory(1, true);
        this.numReplay = 1;
        this.totalSteps = 0;
        this.totalEpisodes = 0;
        this.vfa = parametricStateActionFunction;
        // With staleDuration == 1 the target simply aliases the learned function.
        this.staleVfa = parametricStateActionFunction;
        this.learningPolicy = new EpsilonGreedy(this, 0.1d);
        this.stateMapping = stateMapping;
        solverInit(sADomain, d, null);
    }

    /**
     * Sets the policy used to choose actions during learning.
     *
     * @param policy the learning policy
     */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /**
     * Sets the experience memory and how many experiences are sampled from it per update.
     *
     * @param experienceMemory the memory to store and sample experiences from
     * @param i number of experiences to sample per learning step
     */
    public void setExperienceReplay(ExperienceMemory experienceMemory, int i) {
        this.memory = experienceMemory;
        this.numReplay = i;
    }

    /**
     * Enables or disables a stale target function.
     *
     * @param i refresh period in steps; values {@code > 1} keep a separate copy of the function
     *     that is refreshed every {@code i} steps, values {@code <= 1} make the target alias the
     *     learned function
     */
    public void useStaleTarget(int i) {
        if (this.staleDuration <= 1 && i > 1) {
            // Switching on: take a fresh copy, and give it a full refresh period.
            this.staleVfa = (ParametricFunction.ParametricStateActionFunction) this.vfa.copy();
            this.stepsSinceStale = 0;
        } else if (this.staleDuration > 1 && i <= 1) {
            // Switching off: the target becomes the learned function itself.
            this.staleVfa = this.vfa;
        }
        this.staleDuration = i;
    }

    /** @return the state mapping applied before evaluation */
    public StateMapping getStateMapping() {
        return this.stateMapping;
    }

    /** @param stateMapping the state mapping to apply before evaluation */
    public void setStateMapping(StateMapping stateMapping) {
        this.stateMapping = stateMapping;
    }

    /** Runs a learning episode until the environment reaches a terminal state. */
    @Override
    public Episode runLearningEpisode(Environment environment) {
        return runLearningEpisode(environment, -1);
    }

    /**
     * Runs a learning episode until the environment is terminal or the step budget is used up.
     *
     * @param environment the environment to act in
     * @param i maximum number of primitive steps, or {@code -1} for no limit
     * @return the recorded episode
     */
    @Override
    public Episode runLearningEpisode(Environment environment, int i) {
        Episode episode = new Episode(environment.currentObservation());
        int stepsTaken = 0;
        while (!environment.isInTerminalState() && (i == -1 || stepsTaken < i)) {
            // NOTE(review): the policy receives the already-mapped state; if it queries this
            // agent's qValues(...) the mapping is applied again, which assumes the mapping is
            // idempotent — confirm for non-identity mappings.
            Action action = this.learningPolicy.action(this.stateMapping.mapState(environment.currentObservation()));
            EnvironmentOutcome outcome = environment.executeAction(action);
            this.memory.addExperience(outcome);

            // An option outcome reports how many primitive steps it consumed.
            int numSteps = outcome instanceof EnvironmentOptionOutcome
                    ? ((EnvironmentOptionOutcome) outcome).numSteps()
                    : 1;
            stepsTaken += numSteps;
            this.totalSteps += numSteps;

            episode.transition(action, outcome.op, outcome.r);
            updateQFunction(this.memory.sampleExperiences(this.numReplay));

            this.stepsSinceStale++;
            if (this.stepsSinceStale >= this.staleDuration) {
                updateStaleFunction();
            }
        }
        this.totalEpisodes++;
        return episode;
    }

    /**
     * Resets the function parameters, the experience memory, the step/episode counters and the
     * stale target.
     */
    @Override
    public void resetSolver() {
        this.vfa.resetParameters();
        this.memory.resetMemory();
        this.totalSteps = 0;
        this.totalEpisodes = 0;
        // Re-sync the target with the reset parameters (and restart the staleness counter);
        // otherwise a stale copy would keep the pre-reset parameters.
        updateStaleFunction();
    }

    /**
     * Returns the Q-values of all actions applicable in the (mapped) state.
     *
     * @param state the state to query; mapped once by the state mapping
     * @return one {@link QValue} per applicable action, each holding the mapped state
     */
    @Override
    public List<QValue> qValues(State state) {
        State mapState = this.stateMapping.mapState(state);
        List<Action> applicableActions = applicableActions(mapState);
        List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
        for (Action action : applicableActions) {
            // Evaluate directly: qValue(...) would apply the state mapping a second time.
            qs.add(new QValue(mapState, action, this.vfa.evaluate(mapState, action)));
        }
        return qs;
    }

    /**
     * Evaluates the learned function for a state-action pair.
     *
     * @param state the state; mapped once by the state mapping
     * @param action the action
     * @return the estimated Q-value
     */
    @Override
    public double qValue(State state, Action action) {
        return this.vfa.evaluate(this.stateMapping.mapState(state), action);
    }

    /**
     * Returns the maximum Q-value over applicable actions.
     *
     * @param state the state; mapped once (inside {@link #qValues(State)})
     * @return the max Q-value, or {@link Double#NEGATIVE_INFINITY} if no action is applicable
     */
    @Override
    public double value(State state) {
        double max = Double.NEGATIVE_INFINITY;
        for (QValue q : qValues(state)) {
            max = Math.max(max, q.q);
        }
        return max;
    }

    /**
     * Returns the stale-target Q-values of all actions applicable in the (mapped) state.
     *
     * @param state the state to query; mapped once by the state mapping
     * @return one {@link QValue} per applicable action, evaluated with the stale function
     */
    public List<QValue> getStaleQs(State state) {
        State mapState = this.stateMapping.mapState(state);
        List<Action> applicableActions = applicableActions(mapState);
        List<QValue> qs = new ArrayList<QValue>(applicableActions.size());
        for (Action action : applicableActions) {
            // Evaluate directly: getStaleQ(...) would apply the state mapping a second time.
            qs.add(new QValue(mapState, action, this.staleVfa.evaluate(mapState, action)));
        }
        return qs;
    }

    /**
     * Evaluates the stale target function for a state-action pair.
     *
     * @param state the state; mapped once by the state mapping
     * @param action the action
     * @return a {@link QValue} holding the mapped state
     */
    public QValue getStaleQ(State state, Action action) {
        State mapState = this.stateMapping.mapState(state);
        return new QValue(mapState, action, this.staleVfa.evaluate(mapState, action));
    }

    /**
     * Returns the maximum stale-target Q-value over applicable actions.
     *
     * @param state the state; mapped once (inside {@link #getStaleQs(State)})
     * @return the max stale Q-value, or {@link Double#NEGATIVE_INFINITY} if no action is applicable
     */
    public double staleValue(State state) {
        double max = Double.NEGATIVE_INFINITY;
        for (QValue q : getStaleQs(state)) {
            max = Math.max(max, q.q);
        }
        return max;
    }

    /**
     * Refreshes the stale target: a copy of the learned function when {@code staleDuration > 1},
     * otherwise the learned function itself. Also resets the staleness counter.
     */
    public void updateStaleFunction() {
        if (this.staleDuration > 1) {
            this.staleVfa = (ParametricFunction.ParametricStateActionFunction) this.vfa.copy();
        } else {
            this.staleVfa = this.vfa;
        }
        this.stepsSinceStale = 0;
    }

    /**
     * Restores the step/episode counters (for example after loading saved parameters) and
     * refreshes the stale target.
     *
     * @param i total steps already taken
     * @param i2 total episodes already completed
     */
    public void resumeFrom(int i, int i2) {
        this.totalSteps = i;
        this.totalEpisodes = i2;
        updateStaleFunction();
    }

    /**
     * Updates the learned function from a batch of sampled experiences.
     *
     * @param list experiences sampled from the memory for this step
     */
    public abstract void updateQFunction(List<EnvironmentOutcome> list);
}
