package burlap.behavior.singleagent.learning.modellearning.modelplanners;

import burlap.behavior.policy.EnumerablePolicy;
import burlap.behavior.policy.GreedyQPolicy;
import burlap.behavior.policy.Policy;
import burlap.behavior.policy.support.ActionProb;
import burlap.behavior.singleagent.learning.modellearning.ModelLearningPlanner;
import burlap.behavior.singleagent.planning.stochastic.valueiteration.ValueIteration;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.model.FullModel;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learning/modellearning/modelplanners/VIModelLearningPlanner.class */
/**
 * A {@link ModelLearningPlanner} built on top of {@link ValueIteration}.
 *
 * <p>The planner keeps a set of every state it has been handed, whether through
 * {@link #initializePlannerIn(State)}, {@link #modelChanged(State)}, or a query to the policy
 * returned by {@link #modelPlannedPolicy()} for a state that has no computed value yet. Each
 * re-plan (see {@link #rerunVI()}) resets the solver, performs reachability from all recorded
 * states, and then runs value iteration.
 */
public class VIModelLearningPlanner extends ValueIteration implements ModelLearningPlanner {
    /** Hashed states recorded so far; each one seeds reachability on every re-plan. */
    protected Set<HashableState> observedStates;

    /** Policy returned by {@link #modelPlannedPolicy()}; created once in the constructor. */
    protected Policy modelPolicy;

    /** State most recently passed to {@link #initializePlannerIn(State)}; not read inside this class. */
    protected State initialState;

    /**
     * Wraps another policy so that querying an action, an action probability, or the policy
     * distribution for a state without a computed value first records that state with the
     * enclosing planner and re-runs value iteration, then delegates to the wrapped policy.
     */
    class ReplanIfUnseenPolicy implements EnumerablePolicy {
        /** The wrapped policy that every query is ultimately delegated to. */
        Policy p;

        /**
         * @param policy the policy to delegate to after any required re-planning
         */
        public ReplanIfUnseenPolicy(Policy policy) {
            this.p = policy;
        }

        /**
         * Re-plans if {@code state} has no computed value, then returns the wrapped policy's action.
         */
        @Override
        public Action action(State state) {
            replanIfUnseen(state);
            return this.p.action(state);
        }

        /**
         * Re-plans if {@code state} has no computed value, then returns the wrapped policy's
         * probability for {@code action} in {@code state}.
         */
        @Override
        public double actionProb(State state, Action action) {
            replanIfUnseen(state);
            return this.p.actionProb(state, action);
        }

        /**
         * Re-plans if {@code state} has no computed value, then returns the wrapped policy's
         * action distribution.
         *
         * @throws RuntimeException if the wrapped policy is not an {@link EnumerablePolicy}
         */
        @Override
        public List<ActionProb> policyDistribution(State state) {
            // Checked before re-planning so an unsupported wrapped policy fails without
            // recording the state or triggering a value-iteration run.
            if (!(this.p instanceof EnumerablePolicy)) {
                throw new RuntimeException("Cannot return policy distribution because underlying policy is not an EnumerablePolicy");
            }
            replanIfUnseen(state);
            return ((EnumerablePolicy) this.p).policyDistribution(state);
        }

        /**
         * Delegates directly to the wrapped policy; unlike the other queries this never re-plans.
         */
        @Override
        public boolean definedFor(State state) {
            return this.p.definedFor(state);
        }

        /**
         * Records {@code state} as observed and re-runs value iteration, but only when the
         * enclosing planner reports no computed value for it.
         */
        private void replanIfUnseen(State state) {
            if (VIModelLearningPlanner.this.hasComputedValueFor(state)) {
                return;
            }
            VIModelLearningPlanner.this.observedStates.add(VIModelLearningPlanner.this.hashingFactory.hashState(state));
            VIModelLearningPlanner.this.rerunVI();
        }
    }

    /**
     * Creates the planner, installs {@code fullModel} via {@code setModel}, builds the re-planning
     * policy around a {@link GreedyQPolicy} over this planner, and calls
     * {@code toggleDebugPrinting(false)}.
     *
     * @param sADomain forwarded unchanged to the {@link ValueIteration} constructor
     * @param fullModel the model passed to {@code setModel}
     * @param d forwarded unchanged as the second {@link ValueIteration} constructor argument
     *     (presumably the discount factor; confirm against the superclass)
     * @param hashableStateFactory forwarded unchanged to the {@link ValueIteration} constructor
     * @param d2 forwarded unchanged as the fourth {@link ValueIteration} constructor argument
     *     (presumably the value-change termination threshold; confirm against the superclass)
     * @param i forwarded unchanged as the fifth {@link ValueIteration} constructor argument
     *     (presumably the iteration cap; confirm against the superclass)
     */
    public VIModelLearningPlanner(SADomain sADomain, FullModel fullModel, double d, HashableStateFactory hashableStateFactory, double d2, int i) {
        super(sADomain, d, hashableStateFactory, d2, i);
        // Explicit type arguments instead of a raw HashSet, so no unchecked conversion occurs.
        this.observedStates = new HashSet<HashableState>();
        setModel(fullModel);
        this.modelPolicy = new ReplanIfUnseenPolicy(new GreedyQPolicy(this));
        toggleDebugPrinting(false);
    }

    /**
     * Stores {@code state} as the initial state and records it as observed. Does not re-plan.
     */
    @Override
    public void initializePlannerIn(State state) {
        this.initialState = state;
        this.observedStates.add(this.hashingFactory.hashState(state));
    }

    /**
     * Records {@code state} as observed and re-runs value iteration.
     */
    @Override
    public void modelChanged(State state) {
        this.observedStates.add(this.hashingFactory.hashState(state));
        rerunVI();
    }

    /**
     * @return the policy created in the constructor; the same instance on every call
     */
    @Override
    public Policy modelPlannedPolicy() {
        return this.modelPolicy;
    }

    /**
     * Resets the solver, performs reachability from every recorded state, then runs value
     * iteration.
     */
    protected void rerunVI() {
        resetSolver();
        for (HashableState observed : this.observedStates) {
            performReachabilityFrom(observed.s());
        }
        runVI();
    }
}
