package burlap.behavior.singleagent.learning.actorcritic.actor;

import burlap.behavior.learningrate.ConstantLR;
import burlap.behavior.learningrate.LearningRate;
import burlap.behavior.policy.EnumerablePolicy;
import burlap.behavior.policy.PolicyUtils;
import burlap.behavior.policy.support.ActionProb;
import burlap.behavior.singleagent.learning.actorcritic.Actor;
import burlap.datastructures.BoltzmannDistribution;
import burlap.mdp.core.Domain;
import burlap.mdp.core.action.Action;
import burlap.mdp.core.action.ActionType;
import burlap.mdp.core.action.ActionUtils;
import burlap.mdp.core.state.State;
import burlap.mdp.singleagent.SADomain;
import burlap.mdp.singleagent.environment.EnvironmentOutcome;
import burlap.statehashing.HashableState;
import burlap.statehashing.HashableStateFactory;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:burlap/behavior/singleagent/learning/actorcritic/actor/BoltzmannActor.class */
/**
 * An {@link Actor} that keeps one numeric preference per (state, action) pair and exposes them as
 * an {@link EnumerablePolicy}. Action probabilities for a state are obtained by feeding that
 * state's preference values to a {@link BoltzmannDistribution}; preferences are adjusted in
 * {@link #update(EnvironmentOutcome, double)} by the critique scaled with the current learning rate.
 *
 * <p>Per-state data is created lazily the first time a state is queried and is keyed by the
 * {@link HashableState} produced by the configured {@link HashableStateFactory}.
 */
public class BoltzmannActor implements Actor, EnumerablePolicy {
    /** The domain passed to the constructor. */
    protected Domain domain;

    /** The action types used to enumerate applicable actions when a new state node is created. */
    protected List<ActionType> actionTypes;

    /** Factory used to turn states into hashable map keys. */
    protected HashableStateFactory hashingFactory;

    /** Source of the step size applied to each critique in {@link #update}. */
    protected LearningRate learningRate;

    /** Number of completed {@link #update} calls; passed to the learning rate as the time index. */
    protected int totalNumberOfSteps = 0;

    /** Lazily populated map from hashed state to that state's action preferences. */
    protected Map<HashableState, PolicyNode> preferences = new HashMap<HashableState, PolicyNode>();

    /**
     * A mutable pairing of an action with its current preference value.
     */
    public class ActionPreference {
        /** The action this preference belongs to. */
        public Action ga;

        /** The current preference value for {@link #ga}. */
        public double preference;

        /**
         * @param action the action the preference is for
         * @param d the initial preference value
         */
        public ActionPreference(Action action, double d) {
            this.ga = action;
            this.preference = d;
        }
    }

    /**
     * Holds the list of action preferences stored for a single hashed state.
     */
    public class PolicyNode {
        /** The hashed state this node belongs to. */
        public HashableState sh;

        /** The preferences stored for this state, in insertion order. */
        public List<ActionPreference> preferences = new ArrayList<ActionPreference>();

        /**
         * @param hashableState the hashed state this node represents
         */
        public PolicyNode(HashableState hashableState) {
            this.sh = hashableState;
        }

        /**
         * Appends a preference entry to this node.
         *
         * @param actionPreference the entry to append
         */
        public void addPreference(ActionPreference actionPreference) {
            this.preferences.add(actionPreference);
        }
    }

    /**
     * Creates an actor with a constant learning rate.
     *
     * @param sADomain the domain; its action types are copied into this actor's own list
     * @param hashableStateFactory factory used to hash states for preference lookup
     * @param d the constant learning rate
     */
    public BoltzmannActor(SADomain sADomain, HashableStateFactory hashableStateFactory, double d) {
        this.domain = sADomain;
        // Defensive copy so addActionType does not depend on the mutability of the domain's list.
        this.actionTypes = new ArrayList<ActionType>(sADomain.getActionTypes());
        this.hashingFactory = hashableStateFactory;
        this.learningRate = new ConstantLR(Double.valueOf(d));
    }

    /**
     * Replaces the learning rate used by subsequent updates.
     *
     * @param learningRate the new learning rate source
     */
    public void setLearningRate(LearningRate learningRate) {
        this.learningRate = learningRate;
    }

    /** No per-episode setup is performed by this actor. */
    @Override // burlap.behavior.singleagent.learning.actorcritic.Actor
    public void startEpisode(State state) {
    }

    /** No per-episode teardown is performed by this actor. */
    @Override // burlap.behavior.singleagent.learning.actorcritic.Actor
    public void endEpisode() {
    }

    /**
     * Adds {@code learningRate * d} to the preference of the action taken in the outcome's
     * source state ({@code environmentOutcome.o}) and advances the step counter.
     *
     * @param environmentOutcome the observed transition; its {@code o} and {@code a} fields are read
     * @param d the critique to apply
     * @throws IllegalStateException if the stored node for the state has no preference entry equal
     *         to the taken action (previously this surfaced as a bare NullPointerException)
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.Actor
    public void update(EnvironmentOutcome environmentOutcome, double d) {
        HashableState hashState = this.hashingFactory.hashState(environmentOutcome.o);
        PolicyNode node = getNode(hashState);

        // Resolve the target entry before polling the learning rate so that a failed update
        // does not touch the learning rate or the step counter.
        ActionPreference target = getMatchingPreference(hashState, environmentOutcome.a, node);
        if (target == null) {
            throw new IllegalStateException(
                    "No preference entry for action " + environmentOutcome.a
                            + " in the node stored for the observed state; the node may have been"
                            + " created before the action's type was added to this actor.");
        }

        double stepSize = this.learningRate.pollLearningRate(
                this.totalNumberOfSteps, hashState.s(), environmentOutcome.a);
        target.preference += stepSize * d;
        this.totalNumberOfSteps++;
    }

    /**
     * Registers an additional action type if it is not already present. Nodes that were created
     * earlier are not modified by this call.
     *
     * @param actionType the action type to add
     */
    public void addActionType(ActionType actionType) {
        if (!this.actionTypes.contains(actionType)) {
            this.actionTypes.add(actionType);
        }
    }

    /**
     * Samples an action from {@link #policyDistribution(State)}.
     *
     * @param state the state to act in
     * @return the sampled action
     */
    @Override // burlap.behavior.policy.Policy
    public Action action(State state) {
        return PolicyUtils.sampleFromActionDistribution(this, state);
    }

    /**
     * Returns the probability of an action, derived from {@link #policyDistribution(State)}.
     *
     * @param state the state to query
     * @param action the action whose probability is requested
     * @return the probability reported by {@code PolicyUtils.actionProbFromEnum}
     */
    @Override // burlap.behavior.policy.Policy
    public double actionProb(State state, Action action) {
        return PolicyUtils.actionProbFromEnum(this, state, action);
    }

    /**
     * Builds the action distribution for a state from its stored preferences.
     *
     * @param state the state to query; a node is created for it if none exists yet
     * @return one {@link ActionProb} per stored preference, in the node's order, with the
     *         probabilities returned by a {@link BoltzmannDistribution} over the preference values
     */
    @Override // burlap.behavior.policy.EnumerablePolicy
    public List<ActionProb> policyDistribution(State state) {
        PolicyNode node = getNode(this.hashingFactory.hashState(state));
        List<ActionPreference> entries = node.preferences;

        double[] values = new double[entries.size()];
        for (int i = 0; i < values.length; i++) {
            values[i] = entries.get(i).preference;
        }

        double[] probabilities = new BoltzmannDistribution(values).getProbabilities();
        List<ActionProb> distribution = new ArrayList<ActionProb>(probabilities.length);
        for (int i = 0; i < probabilities.length; i++) {
            distribution.add(new ActionProb(entries.get(i).ga, probabilities[i]));
        }
        return distribution;
    }

    /**
     * Returns the node stored for a hashed state, creating and caching it on first access with a
     * zero preference for every action applicable under the current action types.
     *
     * @param hashableState the hashed state to look up
     * @return the existing or newly created node
     */
    protected PolicyNode getNode(HashableState hashableState) {
        PolicyNode policyNode = this.preferences.get(hashableState);
        if (policyNode == null) {
            policyNode = new PolicyNode(hashableState);
            // Applicable actions are only needed when a node is first built, so enumerate them
            // here rather than on every lookup.
            List<Action> applicable =
                    ActionUtils.allApplicableActionsForTypes(this.actionTypes, hashableState.s());
            for (Action applicableAction : applicable) {
                policyNode.addPreference(new ActionPreference(applicableAction, 0.0d));
            }
            this.preferences.put(hashableState, policyNode);
        }
        return policyNode;
    }

    /**
     * This policy reports itself as defined for every state.
     *
     * @param state ignored
     * @return always {@code true}
     */
    @Override // burlap.behavior.policy.Policy
    public boolean definedFor(State state) {
        return true;
    }

    /**
     * Discards all stored preferences, resets the learning rate's decay, and restarts the step
     * counter so that the time index passed to the learning rate begins again from zero.
     */
    @Override // burlap.behavior.singleagent.learning.actorcritic.Actor
    public void reset() {
        this.preferences.clear();
        this.learningRate.resetDecay();
        this.totalNumberOfSteps = 0;
    }

    /**
     * Finds the preference entry in a node whose action equals the given action.
     *
     * @param hashableState the hashed state the node belongs to (not used by this implementation)
     * @param action the action to match with {@code equals}
     * @param policyNode the node to search
     * @return the first matching entry, or {@code null} if the node has none
     */
    protected ActionPreference getMatchingPreference(HashableState hashableState, Action action, PolicyNode policyNode) {
        for (ActionPreference candidate : policyNode.preferences) {
            if (candidate.ga.equals(action)) {
                return candidate;
            }
        }
        return null;
    }
}
