package burlap.behavior.singleagent.planning.stochastic.policyiteration;

import burlap.behavior.policy.EnumerablePolicy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.singleagent.planning.Planner;
import burlap.behavior.singleagent.planning.stochastic.DynamicProgramming;
import burlap.debugtools.DPrint;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/planning/stochastic/policyiteration/PolicyIteration.class */
/**
 * Policy iteration planner.
 *
 * <p>Planning alternates two phases until the policy stabilizes or an iteration cap is hit:
 * <ol>
 *   <li><b>Evaluation</b> – repeated fixed-policy Bellman sweeps over every state found by
 *       reachability analysis, using the current {@link #evaluativePolicy}.</li>
 *   <li><b>Improvement</b> – the policy is replaced by a {@link GreedyQPolicy} built from a copy
 *       of the freshly evaluated value function.</li>
 * </ol>
 */
public class PolicyIteration extends DynamicProgramming implements Planner {
    /** An evaluation phase stops once a full sweep changes no state's value by this much or more. */
    protected double maxEvalDelta;
    /** Planning stops once an evaluation phase's largest observed value change is at most this. */
    protected double maxPIDelta;
    /** Maximum number of sweeps performed in a single evaluation phase. */
    protected int maxIterations;
    /** Maximum number of evaluate/improve rounds per planning call (at least one always runs). */
    protected int maxPolicyIterations;
    /** The policy currently being evaluated; replaced by a greedy policy after each round. */
    protected EnumerablePolicy evaluativePolicy;
    /** Whether reachability analysis has been performed since the last reset/recompute request. */
    protected boolean foundReachableStates = false;
    /** Running total of evaluate/improve rounds across planning calls. */
    protected int totalPolicyIterations = 0;
    /** Running total of evaluation sweeps across planning calls. */
    protected int totalValueIterations = 0;
    /** Whether policy iteration has completed for the current evaluative policy. */
    protected boolean hasRunPlanning = false;

    /**
     * Creates a planner that uses the same threshold for evaluation and policy convergence.
     *
     * @param domain the domain to plan in (forwarded to {@code DPPInit})
     * @param gamma the discount factor (forwarded to {@code DPPInit})
     * @param hashingFactory the state hashing factory (forwarded to {@code DPPInit})
     * @param maxDelta used as both the evaluation and the policy-iteration convergence threshold
     * @param maxEvaluationIterations maximum sweeps per evaluation phase
     * @param maxPolicyIterations maximum evaluate/improve rounds per planning call
     */
    public PolicyIteration(SADomain domain, double gamma, HashableStateFactory hashingFactory,
            double maxDelta, int maxEvaluationIterations, int maxPolicyIterations) {
        this(domain, gamma, hashingFactory, maxDelta, maxDelta, maxEvaluationIterations, maxPolicyIterations);
    }

    /**
     * Creates a planner with separate evaluation and policy-iteration thresholds.
     *
     * @param domain the domain to plan in (forwarded to {@code DPPInit})
     * @param gamma the discount factor (forwarded to {@code DPPInit})
     * @param hashingFactory the state hashing factory (forwarded to {@code DPPInit})
     * @param maxPIDelta planning stops when an evaluation phase's largest change is at most this
     * @param maxEvalDelta an evaluation phase stops when a sweep's largest change is below this
     * @param maxEvaluationIterations maximum sweeps per evaluation phase
     * @param maxPolicyIterations maximum evaluate/improve rounds per planning call
     */
    public PolicyIteration(SADomain domain, double gamma, HashableStateFactory hashingFactory,
            double maxPIDelta, double maxEvalDelta, int maxEvaluationIterations, int maxPolicyIterations) {
        DPPInit(domain, gamma, hashingFactory);
        this.maxEvalDelta = maxEvalDelta;
        this.maxPIDelta = maxPIDelta;
        this.maxIterations = maxEvaluationIterations;
        this.maxPolicyIterations = maxPolicyIterations;
        this.evaluativePolicy = new GreedyQPolicy(getCopyOfValueFunction());
    }

    /**
     * Sets the policy that the next planning call starts by evaluating.
     *
     * <p>Setting a new policy invalidates any previous planning result, so the next
     * {@link #planFromState(State)} call re-runs policy iteration. Without this, a later
     * planning call from an already-known state would skip planning and return this policy
     * unevaluated. It would also fail its cast to {@link GreedyQPolicy} if this policy is
     * not one.
     *
     * @param enumerablePolicy the initial policy to evaluate
     */
    public void setPolicyToEvaluate(EnumerablePolicy enumerablePolicy) {
        this.evaluativePolicy = enumerablePolicy;
        this.hasRunPlanning = false;
    }

    /**
     * Returns the current policy.
     *
     * @return the most recently computed (or user-set) policy
     */
    public Policy getComputedPolicy() {
        return this.evaluativePolicy;
    }

    /** Forces the next planning call to redo reachability analysis. */
    public void recomputeReachableStates() {
        this.foundReachableStates = false;
    }

    /** @return total evaluate/improve rounds performed since construction or the last reset */
    public int getTotalPolicyIterations() {
        return this.totalPolicyIterations;
    }

    /** @return total evaluation sweeps performed since construction or the last reset */
    public int getTotalValueIterations() {
        return this.totalValueIterations;
    }

    /**
     * Runs policy iteration from the given state.
     *
     * <p>Planning is skipped when the state is already known, reachability has been computed,
     * and planning has already completed. Otherwise at least one evaluate/improve round runs.
     * Rounds stop once an evaluation's largest value change is at most {@code maxPIDelta}, or
     * after {@code maxPolicyIterations} rounds.
     *
     * @param initialState the state to plan from
     * @return the resulting greedy policy
     */
    @Override // burlap.behavior.singleagent.planning.Planner
    public GreedyQPolicy planFromState(State initialState) {
        int iterations = 0;
        if (performReachabilityFrom(initialState) || !this.hasRunPlanning) {
            boolean converged;
            do {
                double delta = evaluatePolicy();
                iterations++;
                // Policy improvement: act greedily with respect to the evaluated values.
                this.evaluativePolicy = new GreedyQPolicy(getCopyOfValueFunction());
                converged = delta <= this.maxPIDelta;
            } while (!converged && iterations < this.maxPolicyIterations);
            this.hasRunPlanning = true;
        }
        DPrint.cl(this.debugCode, "Total policy iterations: " + iterations);
        this.totalPolicyIterations += iterations;
        return (GreedyQPolicy) this.evaluativePolicy;
    }

    /** Resets the solver, including reachability, planning status and iteration counters. */
    @Override // burlap.behavior.singleagent.planning.stochastic.DynamicProgramming, burlap.behavior.singleagent.MDPSolver, burlap.behavior.singleagent.MDPSolverInterface
    public void resetSolver() {
        super.resetSolver();
        this.foundReachableStates = false;
        this.hasRunPlanning = false;
        this.totalValueIterations = 0;
        this.totalPolicyIterations = 0;
    }

    /**
     * Evaluates {@link #evaluativePolicy} with in-place fixed-policy Bellman sweeps over all
     * states currently in the value function.
     *
     * @return the largest per-sweep maximum value change observed. This is
     *     {@code Double.NEGATIVE_INFINITY} if {@code maxIterations <= 0}, because then no
     *     sweep runs.
     * @throws RuntimeException if reachability analysis has not been performed
     */
    protected double evaluatePolicy() {
        if (!this.foundReachableStates) {
            throw new RuntimeException("Cannot run VI until the reachable states have been found. Use planFromState method at least once or instead.");
        }
        double maxChange = Double.NEGATIVE_INFINITY;
        Set<HashableState> states = this.valueFunction.keySet();
        int sweeps = 0;
        while (sweeps < this.maxIterations) {
            double sweepDelta = 0.0;
            for (HashableState sh : states) {
                // The old value MUST be read before the update. performFixedPolicyBellmanUpdateOn
                // stores its result in the value function, so reading afterwards (as the
                // left-to-right evaluation of `update(..) - value(..)` did) always gives a change of 0.
                double oldValue = value(sh);
                double newValue = performFixedPolicyBellmanUpdateOn(sh, this.evaluativePolicy);
                sweepDelta = Math.max(Math.abs(newValue - oldValue), sweepDelta);
            }
            sweeps++;
            maxChange = Math.max(sweepDelta, maxChange);
            if (sweepDelta < this.maxEvalDelta) {
                break; // evaluated well enough
            }
        }
        DPrint.cl(this.debugCode, "Iterations in inner VI for policy eval: " + sweeps);
        this.totalValueIterations += sweeps;
        return maxChange;
    }

    /**
     * Breadth-first reachability analysis from {@code state}.
     *
     * <p>Every newly found non-terminal state is added to the value function with its
     * initializer value. Terminal states are neither added nor expanded. States already in the
     * value function are not re-expanded. Requires the model to be a {@link FullModel}, because
     * it enumerates transitions.
     *
     * @param state the state to search from
     * @return {@code false} if the state was already known and reachability had been computed;
     *     otherwise {@code true} after the search completes
     */
    public boolean performReachabilityFrom(State state) {
        HashableState initialHash = stateHash(state);
        if (this.valueFunction.containsKey(initialHash) && this.foundReachableStates) {
            return false;
        }
        DPrint.cl(this.debugCode, "Starting reachability analysis");
        Deque<HashableState> openQueue = new ArrayDeque<>();
        Set<HashableState> opened = new HashSet<>();
        openQueue.offer(initialHash);
        opened.add(initialHash);
        while (!openQueue.isEmpty()) {
            HashableState current = openQueue.poll();
            // Skip states that were already expanded, and terminal states (they are never expanded).
            if (this.valueFunction.containsKey(current) || this.model.terminal(current.s())) {
                continue;
            }
            this.valueFunction.put(current, this.valueInitializer.value(current.s()));
            FullModel fullModel = (FullModel) this.model;
            for (Action action : applicableActions(current.s())) {
                for (TransitionProb tp : fullModel.transitions(current.s(), action)) {
                    HashableState next = stateHash(tp.eo.op);
                    // opened.add returns false when the state was already queued.
                    if (!this.valueFunction.containsKey(next) && opened.add(next)) {
                        openQueue.offer(next);
                    }
                }
            }
        }
        DPrint.cl(this.debugCode, "Finished reachability analysis; # states: " + this.valueFunction.size());
        this.foundReachableStates = true;
        return true;
    }
}
