package burlap.behavior.singleagent.learning.modellearning.models;

import burlap.behavior.singleagent.learning.modellearning.KWIKModel;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.mdp.singleagent.model.FullModel;
import burlap.mdp.singleagent.model.TransitionProb;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

/* loaded from: input_file:burlap/behavior/singleagent/learning/modellearning/models/TabularModel.class */
/**
 * A tabular, count-based learned model of a single-agent domain.
 *
 * <p>For every hashed state and action the model keeps how many times the pair was
 * observed, the sum of the rewards received, and how many times each resulting state
 * was reached. From those counts {@link #transitions(State, Action)} produces an
 * empirical transition distribution with the mean observed reward.
 * A state-action pair is reported as modeled by
 * {@link #transitionIsModeled(State, Action)} once it has been observed at least
 * {@link #nConfident} times.
 *
 * <p>This class performs no synchronization of its own.
 */
public class TabularModel implements KWIKModel {

    /** Domain whose action types are enumerated when a new state node is created. */
    protected SADomain sourceDomain;

    /** Factory used to hash every state before it is used as a table key. */
    protected HashableStateFactory hashingFactory;

    /** Per-state bookkeeping, keyed by hashed state. */
    protected Map<HashableState, StateNode> stateNodes = new HashMap<HashableState, StateNode>();

    /** Hashed states that were the result of an outcome flagged as terminated. */
    protected Set<HashableState> terminalStates = new HashSet<HashableState>();

    /** Minimum number of observations before a state-action pair counts as modeled. */
    protected int nConfident;

    /**
     * Counter for one observed resulting state of a state-action pair.
     * Equality and hashing are delegated to the wrapped hashed state.
     */
    public class OutcomeState {
        /** The hashed resulting state. */
        HashableState osh;

        /** Number of times this resulting state was observed; starts at 1 on creation. */
        int nTimes = 1;

        /**
         * @param hashableState the hashed resulting state this counter tracks
         */
        public OutcomeState(HashableState hashableState) {
            this.osh = hashableState;
        }

        @Override
        public int hashCode() {
            return this.osh.hashCode();
        }

        @Override
        public boolean equals(Object obj) {
            if (obj instanceof OutcomeState) {
                return this.osh.equals(((OutcomeState) obj).osh);
            }
            return false;
        }
    }

    /**
     * Statistics for one action taken from one state: number of tries, summed reward,
     * and a counter per observed resulting state.
     */
    public class StateActionNode {
        /** The action these statistics belong to. */
        Action ga;

        /** Number of times the action was observed from the owning state. */
        int nTries;

        /** Sum of all rewards observed for this state-action pair. */
        double sumR;

        /** Observed resulting states, keyed by their hashed state. */
        Map<HashableState, OutcomeState> outcomes;

        /**
         * Creates a node with no observations.
         *
         * @param action the action this node tracks
         */
        public StateActionNode(Action action) {
            this.ga = action;
            this.sumR = 0.0d;
            this.nTries = 0;
            this.outcomes = new HashMap<HashableState, OutcomeState>();
        }

        /**
         * Creates a node already holding a single observation.
         *
         * @param action        the action this node tracks
         * @param d             the reward of the first observation
         * @param hashableState the hashed resulting state of the first observation
         */
        public StateActionNode(Action action, double d, HashableState hashableState) {
            this.ga = action;
            this.sumR = d;
            this.nTries = 1;
            this.outcomes = new HashMap<HashableState, OutcomeState>();
            this.outcomes.put(hashableState, new OutcomeState(hashableState));
        }

        /**
         * Records one more observation of this state-action pair.
         *
         * @param d             the observed reward
         * @param hashableState the hashed resulting state
         */
        public void update(double d, HashableState hashableState) {
            this.nTries++;
            this.sumR += d;
            OutcomeState seen = this.outcomes.get(hashableState);
            if (seen != null) {
                seen.nTimes++;
            } else {
                // A fresh OutcomeState already starts with nTimes == 1.
                this.outcomes.put(hashableState, new OutcomeState(hashableState));
            }
        }
    }

    /**
     * Holds the {@link StateActionNode}s of a single hashed state.
     */
    public class StateNode {
        /** The hashed state this node belongs to. */
        HashableState sh;

        /** Statistics per action, keyed by the action itself. */
        Map<Action, StateActionNode> actionNodes = new HashMap<Action, StateActionNode>();

        /**
         * @param hashableState the hashed state this node belongs to
         */
        public StateNode(HashableState hashableState) {
            this.sh = hashableState;
        }

        /**
         * @param action the action to look up
         * @return the statistics node for {@code action}, or {@code null} if none is stored
         */
        public StateActionNode actionNode(Action action) {
            return this.actionNodes.get(action);
        }

        /**
         * Creates an empty statistics node for {@code action} and stores it, replacing
         * any node previously stored for an equal action.
         *
         * @param action the action to add
         * @return the newly created node
         */
        public StateActionNode addActionNode(Action action) {
            StateActionNode created = new StateActionNode(action);
            this.actionNodes.put(action, created);
            return created;
        }
    }

    /**
     * @param sADomain             the domain whose action types are used to enumerate actions
     * @param hashableStateFactory the factory used to hash states into table keys
     * @param i                    the number of observations required before a
     *                             state-action pair is considered modeled
     */
    public TabularModel(SADomain sADomain, HashableStateFactory hashableStateFactory, int i) {
        this.sourceDomain = sADomain;
        this.hashingFactory = hashableStateFactory;
        this.nConfident = i;
    }

    /**
     * @return {@code true} if a statistics node exists for the pair and it has been
     *         observed at least {@link #nConfident} times
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.KWIKModel
    public boolean transitionIsModeled(State state, Action action) {
        StateActionNode node = getStateActionNode(this.hashingFactory.hashState(state), action);
        return node != null && node.nTries >= this.nConfident;
    }

    /**
     * Returns the empirical transition distribution for {@code action} in {@code state}.
     *
     * <p>If the pair has never been observed (no node, or a node with zero tries), a
     * single self-transition with probability 1, reward 0 and a non-terminal flag is
     * returned. Otherwise each observed resulting state gets probability
     * {@code nTimes / nTries}, and every outcome carries the mean observed reward
     * {@code sumR / nTries}.
     *
     * @param state  the source state
     * @param action the action taken
     * @return the list of outcomes with their estimated probabilities
     */
    @Override // burlap.mdp.singleagent.model.FullModel
    public List<TransitionProb> transitions(State state, Action action) {
        List<TransitionProb> result = new ArrayList<TransitionProb>();
        StateActionNode node = getStateActionNode(this.hashingFactory.hashState(state), action);

        // Nodes are pre-created for every applicable action when a state is first seen,
        // so a node with zero tries carries no data; without this check the reward
        // would be 0/0 = NaN and the returned distribution would be empty.
        if (node == null || node.nTries == 0) {
            result.add(new TransitionProb(1.0d, new EnvironmentOutcome(state, action, state, 0.0d, false)));
            return result;
        }

        double meanReward = node.sumR / node.nTries;
        for (OutcomeState outcome : node.outcomes.values()) {
            State next = outcome.osh.s();
            // Both counters are ints: promote before dividing to avoid integer division.
            double probability = (double) outcome.nTimes / node.nTries;
            // terminalStates is keyed by hashed states, so query with the hashed state.
            boolean isTerminal = this.terminalStates.contains(outcome.osh);
            result.add(new TransitionProb(probability,
                    new EnvironmentOutcome(state, action, next, meanReward, isTerminal)));
        }
        return result;
    }

    /**
     * Delegates to {@code FullModel.Helper.sampleByEnumeration} using this model's
     * {@link #transitions(State, Action)}.
     */
    @Override // burlap.mdp.singleagent.model.SampleModel
    public EnvironmentOutcome sample(State state, Action action) {
        return FullModel.Helper.sampleByEnumeration(this, state, action);
    }

    /**
     * @return {@code true} if the hashed form of {@code state} was previously recorded
     *         as the result of a terminated outcome
     */
    @Override // burlap.mdp.singleagent.model.SampleModel
    public boolean terminal(State state) {
        return this.terminalStates.contains(this.hashingFactory.hashState(state));
    }

    /**
     * Folds one observed environment outcome into the model: records the resulting
     * state as terminal when the outcome is flagged terminated, then updates the
     * counters of the source state-action pair.
     *
     * @param environmentOutcome the observed outcome
     * @throws RuntimeException if no statistics node matching the outcome's action can
     *                          be found or created for the source state
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.LearnedModel
    public void updateModel(EnvironmentOutcome environmentOutcome) {
        HashableState source = this.hashingFactory.hashState(environmentOutcome.o);
        HashableState result = this.hashingFactory.hashState(environmentOutcome.op);
        if (environmentOutcome.terminated) {
            this.terminalStates.add(result);
        }
        getOrCreateActionNode(source, environmentOutcome.a).update(environmentOutcome.r, result);
    }

    /**
     * Looks up the statistics node for a state-action pair without creating anything.
     *
     * @param hashableState the hashed source state
     * @param action        the action
     * @return the node, or {@code null} if the state or the action is not stored
     */
    protected StateActionNode getStateActionNode(HashableState hashableState, Action action) {
        StateNode stateNode = this.stateNodes.get(hashableState);
        if (stateNode == null) {
            return null;
        }
        return stateNode.actionNode(action);
    }

    /**
     * Returns the statistics node for a state-action pair, creating the state node
     * first if the state has not been seen. A new state node is populated with an
     * empty action node for every action returned by
     * {@code ActionUtils.allApplicableActionsForTypes} for the domain's action types.
     *
     * @param hashableState the hashed source state
     * @param action        the action
     * @return the matching node, never {@code null}
     * @throws RuntimeException if no stored action equals {@code action}
     */
    protected StateActionNode getOrCreateActionNode(HashableState hashableState, Action action) {
        StateNode existing = this.stateNodes.get(hashableState);
        StateActionNode match = null;
        if (existing == null) {
            StateNode created = new StateNode(hashableState);
            this.stateNodes.put(hashableState, created);
            List<Action> applicable = ActionUtils.allApplicableActionsForTypes(
                    this.sourceDomain.getActionTypes(), hashableState.s());
            for (Action candidate : applicable) {
                StateActionNode added = created.addActionNode(candidate);
                if (candidate.equals(action)) {
                    match = added;
                }
            }
        } else {
            match = existing.actionNode(action);
        }
        if (match == null) {
            throw new RuntimeException("Could not finding matching grounded action in model for action: " + action.toString());
        }
        return match;
    }

    /**
     * Discards everything learned so far: all state nodes and all recorded terminal states.
     */
    @Override // burlap.behavior.singleagent.learning.modellearning.LearnedModel
    public void resetModel() {
        this.stateNodes.clear();
        this.terminalStates.clear();
    }
}
