package burlap.behavior.singleagent.learning.tdmethods;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.policy.EpsilonGreedy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.Episode;
import burlap.behavior.singleagent.MDPSolver;
import burlap.behavior.singleagent.learning.LearningAgent;
import burlap.behavior.singleagent.options.EnvironmentOptionOutcome;
import burlap.behavior.singleagent.options.Option;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.valuefunction.ConstantValueFunction;
import burlap.behavior.valuefunction.QFunction;
import burlap.behavior.valuefunction.QProvider;
import burlap.behavior.valuefunction.QValue;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.Environment;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.environment.SimulatedEnvironment;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.io.BufferedWriter;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.FileWriter;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import javax.management.RuntimeErrorException;
import org.yaml.snakeyaml.Yaml;

/* loaded from: input_file:burlap/behavior/singleagent/learning/tdmethods/QLearning.class */
/**
 * Tabular Q-learning agent.
 *
 * <p>Q-values are stored per hashed state in {@link #qFunction}; entries are created lazily the
 * first time a state is queried and are seeded from {@link #qInitFunction}. The class can be used
 * as a {@link LearningAgent} (via {@link #runLearningEpisode(Environment, int)}) or as a
 * {@link Planner} (via {@link #planFromState(State)}), in which case episodes are run against a
 * {@link SimulatedEnvironment} built from the solver's domain.
 */
public class QLearning extends MDPSolver implements QProvider, LearningAgent, Planner {
    /** Tabular Q-function: one node (list of Q-values) per hashed state. */
    protected Map<HashableState, QLearningStateNode> qFunction;
    /** Supplies the initial Q-value for a state-action pair the first time its state is seen. */
    protected QFunction qInitFunction;
    /** Learning-rate schedule polled on every Q-value update. */
    protected LearningRate learningRate;
    /** Policy used to select actions while learning. */
    protected Policy learningPolicy;
    /** Step limit passed to each episode run by {@link #planFromState(State)}. */
    protected int maxEpisodeSize;
    /** Number of steps taken in the most recent learning episode. */
    protected int eStepCounter;
    /** Upper bound on the number of episodes run by {@link #planFromState(State)}. */
    protected int numEpisodesForPlanning;
    /** Planning stops once the largest Q-value change in an episode is not above this threshold. */
    protected double maxQChangeForPlanningTermination;
    /** Largest absolute Q-value change observed during the most recent learning episode. */
    protected double maxQChangeInLastEpisode = Double.POSITIVE_INFINITY;
    /** Whether an executed option is recorded in the episode as its decomposed steps. */
    protected boolean shouldDecomposeOptions = true;
    /** Number of Q-value updates performed so far; passed to the learning-rate schedule. */
    protected int totalNumberOfSteps = 0;

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate, an
     * {@link EpsilonGreedy} learning policy constructed with 0.1, and a planning episode size of
     * {@link Integer#MAX_VALUE}.
     *
     * @param domain the domain handed to {@code solverInit}
     * @param gamma the discount factor handed to {@code solverInit}
     * @param hashingFactory the state hashing factory handed to {@code solverInit}
     * @param qInit the constant initial Q-value
     * @param learningRate the constant learning rate
     */
    public QLearning(SADomain domain, double gamma, HashableStateFactory hashingFactory, double qInit, double learningRate) {
        QLInit(domain, gamma, hashingFactory, new ConstantValueFunction(qInit), learningRate, new EpsilonGreedy(this, 0.1d), Integer.MAX_VALUE);
    }

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate and an
     * {@link EpsilonGreedy} learning policy constructed with 0.1.
     *
     * @param domain the domain handed to {@code solverInit}
     * @param gamma the discount factor handed to {@code solverInit}
     * @param hashingFactory the state hashing factory handed to {@code solverInit}
     * @param qInit the constant initial Q-value
     * @param learningRate the constant learning rate
     * @param maxEpisodeSize step limit for episodes run by {@link #planFromState(State)}
     */
    public QLearning(SADomain domain, double gamma, HashableStateFactory hashingFactory, double qInit, double learningRate, int maxEpisodeSize) {
        QLInit(domain, gamma, hashingFactory, new ConstantValueFunction(qInit), learningRate, new EpsilonGreedy(this, 0.1d), maxEpisodeSize);
    }

    /**
     * Creates an agent with a constant initial Q-value, a constant learning rate and the given
     * learning policy.
     *
     * @param domain the domain handed to {@code solverInit}
     * @param gamma the discount factor handed to {@code solverInit}
     * @param hashingFactory the state hashing factory handed to {@code solverInit}
     * @param qInit the constant initial Q-value
     * @param learningRate the constant learning rate
     * @param learningPolicy the policy used to select actions while learning
     * @param maxEpisodeSize step limit for episodes run by {@link #planFromState(State)}
     */
    public QLearning(SADomain domain, double gamma, HashableStateFactory hashingFactory, double qInit, double learningRate, Policy learningPolicy, int maxEpisodeSize) {
        QLInit(domain, gamma, hashingFactory, new ConstantValueFunction(qInit), learningRate, learningPolicy, maxEpisodeSize);
    }

    /**
     * Creates an agent with an arbitrary Q-value initialization function, a constant learning
     * rate and the given learning policy.
     *
     * @param domain the domain handed to {@code solverInit}
     * @param gamma the discount factor handed to {@code solverInit}
     * @param hashingFactory the state hashing factory handed to {@code solverInit}
     * @param qInit function providing the initial Q-value of each state-action pair
     * @param learningRate the constant learning rate
     * @param learningPolicy the policy used to select actions while learning
     * @param maxEpisodeSize step limit for episodes run by {@link #planFromState(State)}
     */
    public QLearning(SADomain domain, double gamma, HashableStateFactory hashingFactory, QFunction qInit, double learningRate, Policy learningPolicy, int maxEpisodeSize) {
        QLInit(domain, gamma, hashingFactory, qInit, learningRate, learningPolicy, maxEpisodeSize);
    }

    /**
     * Shared initialization used by every constructor. Planning defaults to a single episode and
     * a Q-change termination threshold of 0.
     *
     * @param domain the domain handed to {@code solverInit}
     * @param gamma the discount factor handed to {@code solverInit}
     * @param hashingFactory the state hashing factory handed to {@code solverInit}
     * @param qInit function providing the initial Q-value of each state-action pair
     * @param learningRate the constant learning rate (wrapped in a {@link ConstantLR})
     * @param learningPolicy the policy used to select actions while learning
     * @param maxEpisodeSize step limit for episodes run by {@link #planFromState(State)}
     */
    protected void QLInit(SADomain domain, double gamma, HashableStateFactory hashingFactory, QFunction qInit, double learningRate, Policy learningPolicy, int maxEpisodeSize) {
        solverInit(domain, gamma, hashingFactory);
        this.qFunction = new HashMap<>();
        this.learningRate = new ConstantLR(Double.valueOf(learningRate));
        this.learningPolicy = learningPolicy;
        this.maxEpisodeSize = maxEpisodeSize;
        this.qInitFunction = qInit;
        this.numEpisodesForPlanning = 1;
        this.maxQChangeForPlanningTermination = 0.0d;
    }

    /**
     * Sets the maximum number of planning episodes without any validation. Because
     * {@link #planFromState(State)} always runs at least one episode, a non-positive value still
     * results in one episode.
     *
     * @param numEpisodesForPlanning the maximum number of episodes to run when planning
     */
    public void initializeForPlanning(int numEpisodesForPlanning) {
        this.numEpisodesForPlanning = numEpisodesForPlanning;
    }

    /**
     * Replaces the learning-rate schedule.
     *
     * @param learningRate the schedule polled on each Q-value update
     */
    public void setLearningRateFunction(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /**
     * Replaces the Q-value initialization function. Only states first queried after this call
     * are affected; already-created table entries are left untouched.
     *
     * @param qInit function providing the initial Q-value of each state-action pair
     */
    public void setQInitFunction(QFunction qInit) {
        this.qInitFunction = qInit;
    }

    /**
     * Replaces the policy used to select actions while learning.
     *
     * @param policy the new learning policy
     */
    public void setLearningPolicy(Policy policy) {
        this.learningPolicy = policy;
    }

    /**
     * Sets the maximum number of planning episodes; non-positive values are replaced by 1.
     *
     * @param n the requested maximum number of episodes
     */
    public void setMaximumEpisodesForPlanning(int n) {
        this.numEpisodesForPlanning = n > 0 ? n : 1;
    }

    /**
     * Sets the Q-change threshold below (or at) which planning stops; non-positive (and NaN)
     * values are replaced by 0. The misspelled method name is kept for compatibility with
     * existing callers.
     *
     * @param m the requested threshold
     */
    public void setMaxQChangeForPlanningTerminaiton(double m) {
        this.maxQChangeForPlanningTermination = m > 0.0d ? m : 0.0d;
    }

    /**
     * @return the number of steps taken in the most recent learning episode
     */
    public int getLastNumSteps() {
        return this.eStepCounter;
    }

    /**
     * Controls how executed options are recorded in returned episodes.
     *
     * @param toggle if true, the option's own episode is merged into the returned episode;
     *               if false, the option is recorded as a single transition
     */
    public void toggleShouldDecomposeOption(boolean toggle) {
        this.shouldDecomposeOptions = toggle;
    }

    /**
     * Returns the stored Q-values of the given state, creating the table entry if needed.
     *
     * @param state the state to query
     * @return the live list of Q-values held in the table for that state
     */
    @Override
    public List<QValue> qValues(State state) {
        return getQs(stateHash(state));
    }

    /**
     * Returns the Q-value of a state-action pair.
     *
     * @param state the state to query
     * @param action the action to query
     * @return the stored Q-value
     * @throws NullPointerException if the table entry for {@code state} has no Q-value whose
     *         action equals {@code action}
     */
    @Override
    public double qValue(State state, Action action) {
        return getQ(stateHash(state), action).q;
    }

    /**
     * Returns the Q-values stored for a hashed state, creating the table entry if needed.
     *
     * @param hashedState the hashed state to query
     * @return the live list of Q-values held in the table for that state
     */
    protected List<QValue> getQs(HashableState hashedState) {
        return getStateNode(hashedState).qEntry;
    }

    /**
     * Looks up the Q-value object for a hashed state and action, creating the state's table
     * entry if needed.
     *
     * @param hashedState the hashed state to query
     * @param action the action to query
     * @return the stored (mutable) Q-value, or {@code null} if no stored Q-value of that state
     *         has an action equal to {@code action}
     */
    public QValue getQ(HashableState hashedState, Action action) {
        for (QValue candidate : getStateNode(hashedState).qEntry) {
            if (candidate.a.equals(action)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * Returns the state value, computed by {@code QProvider.Helper.maxQ} over this provider.
     *
     * @param state the state to evaluate
     * @return the value reported by {@code QProvider.Helper.maxQ}
     */
    @Override
    public double value(State state) {
        return QProvider.Helper.maxQ(this, state);
    }

    /**
     * Returns the table node for a hashed state. On a miss, a node is created with one Q-value
     * per applicable action, each seeded from {@link #qInitFunction}, and stored in the table.
     *
     * @param hashedState the hashed state to look up
     * @return the existing or newly created node
     * @throws RuntimeErrorException if the state is not yet in the table and has no applicable
     *         actions
     */
    protected QLearningStateNode getStateNode(HashableState hashedState) {
        QLearningStateNode node = this.qFunction.get(hashedState);
        if (node != null) {
            return node;
        }

        List<Action> actions = applicableActions(hashedState.s());
        if (actions.isEmpty()) {
            throw new RuntimeErrorException(new Error("No possible actions in this state, cannot continue Q-learning"));
        }

        node = new QLearningStateNode(hashedState);
        for (Action action : actions) {
            node.addQValue(action, this.qInitFunction.qValue(hashedState.s(), action));
        }
        this.qFunction.put(hashedState, node);
        return node;
    }

    /**
     * Returns the largest stored Q-value of a hashed state, creating the table entry if needed.
     *
     * @param hashedState the hashed state to query
     * @return the maximum Q-value, or {@link Double#NEGATIVE_INFINITY} if no stored value
     *         compares greater than it
     */
    protected double getMaxQ(HashableState hashedState) {
        double best = Double.NEGATIVE_INFINITY;
        for (QValue candidate : getQs(hashedState)) {
            if (candidate.q > best) {
                best = candidate.q;
            }
        }
        return best;
    }

    /**
     * Plans by running learning episodes in a {@link SimulatedEnvironment} that starts in the
     * given state. At least one episode is run; further episodes are run until
     * {@link #numEpisodesForPlanning} episodes have been performed or the largest Q-value change
     * of the last episode is no longer above {@link #maxQChangeForPlanningTermination}.
     *
     * @param state the state every planning episode starts from
     * @return a greedy policy over this agent's Q-values
     * @throws RuntimeException if no model is available ({@code this.model} is null)
     */
    @Override
    public GreedyQPolicy planFromState(State state) {
        if (this.model == null) {
            throw new RuntimeException("QLearning (and its subclasses) cannot execute planFromState because a model is not specified.");
        }
        SimulatedEnvironment env = new SimulatedEnvironment(this.domain, state);
        int episodesRun = 0;
        do {
            // Reset before every episode: without it, an environment left in a terminal state by
            // the previous episode makes the next episode take zero steps, which reports a
            // Q-change of 0 and ends planning prematurely.
            env.resetEnvironment();
            runLearningEpisode(env, this.maxEpisodeSize);
            episodesRun++;
        } while (episodesRun < this.numEpisodesForPlanning
                && this.maxQChangeInLastEpisode > this.maxQChangeForPlanningTermination);
        return new GreedyQPolicy(this);
    }

    /**
     * Runs a learning episode with no step limit.
     *
     * @param env the environment to act in
     * @return the recorded episode
     */
    @Override
    public Episode runLearningEpisode(Environment env) {
        return runLearningEpisode(env, -1);
    }

    /**
     * Runs one learning episode from the environment's current observation, updating the
     * Q-table after every executed action. The episode ends when the environment reports a
     * terminal state or the step counter reaches {@code maxSteps}.
     *
     * @param env the environment to act in
     * @param maxSteps the step limit, or -1 for no limit; an option counts for as many steps as
     *                 its outcome reports, so the limit may be overshot
     * @return the recorded episode
     */
    @Override
    public Episode runLearningEpisode(Environment env, int maxSteps) {
        State initialState = env.currentObservation();
        Episode episode = new Episode(initialState);
        HashableState curState = stateHash(initialState);
        this.eStepCounter = 0;
        this.maxQChangeInLastEpisode = 0.0d;

        while (!env.isInTerminalState() && (this.eStepCounter < maxSteps || maxSteps == -1)) {
            Action action = this.learningPolicy.action(curState.s());
            QValue curQ = getQ(curState, action);

            // Options drive the environment themselves; primitive actions are executed directly.
            EnvironmentOutcome outcome;
            if (action instanceof Option) {
                outcome = ((Option) action).control(env, this.gamma);
            } else {
                outcome = env.executeAction(action);
            }

            HashableState nextState = stateHash(outcome.op);
            // A terminated outcome contributes no future value to the update target.
            double nextMaxQ = outcome.terminated ? 0.0d : getMaxQ(nextState);

            double reward = outcome.r;
            double discount = this.gamma;
            int stepsTaken = 1;
            if (outcome instanceof EnvironmentOptionOutcome) {
                // Option outcomes carry their own discount and step count.
                EnvironmentOptionOutcome optionOutcome = (EnvironmentOptionOutcome) outcome;
                discount = optionOutcome.discount;
                stepsTaken = optionOutcome.numSteps();
            }
            this.eStepCounter += stepsTaken;

            if ((action instanceof Option) && this.shouldDecomposeOptions) {
                episode.appendAndMergeEpisodeAnalysis(((EnvironmentOptionOutcome) outcome).episode);
            } else {
                episode.transition(action, nextState.s(), reward);
            }

            // Move the stored Q-value toward the target (reward + discount * max next Q).
            double oldQ = curQ.q;
            double alpha = this.learningRate.pollLearningRate(this.totalNumberOfSteps, curState.s(), action);
            curQ.q = oldQ + alpha * ((reward + (discount * nextMaxQ)) - oldQ);

            double deltaQ = Math.abs(oldQ - curQ.q);
            if (deltaQ > this.maxQChangeInLastEpisode) {
                this.maxQChangeInLastEpisode = deltaQ;
            }

            // Re-poll the environment instead of reusing the outcome's next state.
            curState = stateHash(env.currentObservation());
            this.totalNumberOfSteps++;
        }
        return episode;
    }

    /**
     * Clears the Q-table and resets the per-episode statistics. The total step count passed to
     * the learning-rate schedule is not reset here.
     */
    @Override
    public void resetSolver() {
        this.qFunction.clear();
        this.eStepCounter = 0;
        this.maxQChangeInLastEpisode = Double.POSITIVE_INFINITY;
    }

    /**
     * Dumps the Q-table to the given file as YAML. I/O failures are reported on standard error
     * and otherwise ignored (best effort).
     *
     * @param path path of the file to write
     */
    public void writeQTable(String path) {
        // try-with-resources guarantees the buffered content is flushed and the file is closed.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(path))) {
            new Yaml().dump(this.qFunction, writer);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /**
     * Replaces the Q-table with the YAML document read from the given file. I/O failures
     * (including a missing file) are reported on standard error and leave the current table
     * unchanged (best effort).
     *
     * <p>NOTE(review): depending on the SnakeYAML version, {@code Yaml.load} with a default
     * {@code Yaml} instance may instantiate arbitrary types named in the document; only load
     * files from trusted sources.
     *
     * @param path path of the file to read
     */
    public void loadQTable(String path) {
        try (FileInputStream in = new FileInputStream(path)) {
            // The document's element types cannot be checked at runtime; they are trusted to
            // match what writeQTable produced.
            @SuppressWarnings("unchecked")
            Map<HashableState, QLearningStateNode> loaded =
                    (Map<HashableState, QLearningStateNode>) new Yaml().load(in);
            this.qFunction = loaded;
        } catch (IOException e) {
            e.printStackTrace();
        }
    }
}
