package ai.libs.jaicore.search.algorithms.mdp.mcts.uct;

import ai.libs.jaicore.search.algorithms.mdp.mcts.EBehaviorForNotFullyExploredStates;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.NodeLabel;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for tree policies that keep, for every node seen so far, a {@link NodeLabel} holding the node's visit
 * count and, per action, the number of pulls and the accumulated rewards.
 *
 * <p>
 * Statistics are updated by {@link #updatePath(ILabeledPath, List)} using discounted returns (discount factor
 * {@link #getGamma()}). Subclasses decide how a single action is scored ({@link #getScore(Object, Object)}) and how an
 * action is picked from those scores ({@link #getActionBasedOnScores(Map)}).
 *
 * @param <N>
 *            node type
 * @param <A>
 *            action (arc) type
 */
public abstract class AUpdatingPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable {
    private final double gamma;
    private final boolean maximize;
    private Logger logger = LoggerFactory.getLogger(AUpdatingPolicy.class);
    private EBehaviorForNotFullyExploredStates behaviorWhenActionForNotFullyExploredStateIsRequested = EBehaviorForNotFullyExploredStates.EXCEPTION;

    /* statistics per node; entries are created lazily by updatePath and never map to null */
    private final Map<N, NodeLabel<A>> labels = new HashMap<>();

    /**
     * @param gamma
     *            discount factor applied to later rewards when a path is updated
     * @param maximize
     *            flag stored for subclasses; this class only exposes it via {@link #isMaximize()}
     */
    public AUpdatingPolicy(final double gamma, final boolean maximize) {
        this.gamma = gamma;
        this.maximize = maximize;
    }

    /**
     * Returns the statistics collected for the given node.
     *
     * @param node
     *            the node whose label is requested
     * @return the label of the node
     * @throws IllegalArgumentException
     *             if no label exists for the node (labels are created by {@link #updatePath(ILabeledPath, List)} for
     *             every node of an updated path except its last one)
     */
    public NodeLabel<A> getLabelOfNode(final N node) {
        NodeLabel<A> label = this.labels.get(node);
        if (label == null) {
            throw new IllegalArgumentException("No label for node " + node);
        }
        return label;
    }

    /**
     * Computes the score of taking {@code action} in {@code node}.
     *
     * <p>
     * A NaN score is treated by {@link #getAction(Object, Collection)} as "not enough information"; what happens then is
     * governed by {@link #getBehaviorWhenActionForNotFullyExploredStateIsRequested()}.
     *
     * @param node
     *            the node in which the action would be taken
     * @param action
     *            the action to score
     * @return the score of the action, possibly NaN
     */
    public abstract double getScore(N node, A action);

    /**
     * Picks one action given the (non-NaN) scores of the candidate actions.
     *
     * @param scores
     *            map from action to its score
     * @return the chosen action; must not be {@code null}, otherwise {@link #getAction(Object, Collection)} throws an
     *         {@link IllegalStateException}
     */
    public abstract A getActionBasedOnScores(Map<A, Double> scores);

    /**
     * Back-propagates the observed rewards along the given path.
     *
     * <p>
     * For arc {@code i} (leading from node {@code i} to node {@code i + 1}) the discounted return
     * {@code scores[i] + gamma * return[i + 1]} is computed, where the return after the last arc is 0. The label of node
     * {@code i} is credited with that return for the arc's action, and both the action's pull count and the node's visit
     * count are incremented. The last node of the path receives no label.
     *
     * @param path
     *            the path to update; must contain at least one arc
     * @param scores
     *            the immediate reward observed for each arc, indexed like the arcs of the path; entries beyond the number
     *            of arcs are ignored
     * @throws IllegalArgumentException
     *             if the path consists of a single node, if fewer scores than arcs are given, or if one of the used
     *             scores is {@code null}; in the latter two cases no statistics are modified
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final List<Double> scores) {
        this.logger.debug("Updating path {} with score {}", path, scores);
        if (path.isPoint()) {
            throw new IllegalArgumentException("Cannot update path consisting only of the root.");
        }
        List<N> nodes = path.getNodes();
        List<A> arcs = path.getArcs();
        int lastArcIndex = nodes.size() - 2;

        /* compute (and validate) all returns before touching any label, so invalid input cannot leave half-updated statistics */
        double[] returns = this.computeDiscountedReturns(scores, lastArcIndex + 1);

        for (int i = lastArcIndex; i >= 0; i--) {
            N node = nodes.get(i);
            A action = arcs.get(i);
            NodeLabel<A> label = this.labels.computeIfAbsent(node, unused -> new NodeLabel<A>());
            label.addRewardForAction(action, returns[i]);
            label.addPull(action);
            label.addVisit();
            this.logger.trace("Updated label of node {}. Visits now {}. Action pulls of {} now {}. Observed total rewards for this action: {}", node, label.getVisits(), action, label.getNumPulls(action),
                    label.getAccumulatedRewardsOfAction(action));
        }
        this.logger.debug("Path update completed.");
    }

    /**
     * Computes the discounted return for each of the first {@code numArcs} scores, accumulating from the last arc
     * backwards.
     *
     * @throws IllegalArgumentException
     *             if fewer than {@code numArcs} scores are given or a used score is {@code null}
     */
    private double[] computeDiscountedReturns(final List<Double> scores, final int numArcs) {
        if (scores.size() < numArcs) {
            throw new IllegalArgumentException("Path has " + numArcs + " arcs, but only " + scores.size() + " scores were given.");
        }
        double[] returns = new double[numArcs];
        double discountedReturn = 0.0;
        for (int i = numArcs - 1; i >= 0; i--) {
            Double reward = scores.get(i);
            if (reward == null) {
                throw new IllegalArgumentException("Score of arc " + i + " is null.");
            }
            discountedReturn = reward + this.gamma * discountedReturn;
            returns[i] = discountedReturn;
        }
        return returns;
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>
     * If the node has no label yet, the first of the offered actions is returned. Note that this check is per node, not
     * per action: once a node has a label, every offered action is scored with {@link #getScore(Object, Object)}. NaN
     * scores either raise an {@link IllegalStateException} (behavior {@code EXCEPTION}), are ignored while choosing among
     * the remaining scores (behavior {@code BEST}), or cause a uniformly random pick among all offered actions (any other
     * behavior).
     *
     * @param node
     *            the node for which an action is needed
     * @param possibleActions
     *            the candidate actions
     * @return the chosen action, never {@code null}
     * @throws IllegalStateException
     *             if a score is NaN under behavior {@code EXCEPTION}, or if the score-based choice yields {@code null}
     */
    @Override
    public A getAction(final N node, final Collection<A> possibleActions) {
        this.logger.debug("Deriving action for node {}. The {} options are: {}", node, possibleActions.size(), possibleActions);

        NodeLabel<A> nodeLabel = this.labels.get(node);
        if (nodeLabel == null && !possibleActions.isEmpty()) {
            A action = possibleActions.iterator().next();
            this.logger.info("Dictating action {}, because this was never played before.", action);
            return action;
        }
        this.logger.debug("All actions have been tried. Label is: {}", nodeLabel);

        Map<A, Double> scores = new HashMap<>();
        boolean someScoreUndefined = false;
        for (A action : possibleActions) {
            assert nodeLabel.getVisits() != 0 : "Visits of action " + action + " cannot be 0 if we already used this action before!";
            this.logger.trace("Considering action {}, which has {} visits and cummulative rewards {}.", action, nodeLabel.getNumPulls(action), nodeLabel.getAccumulatedRewardsOfAction(action));
            double score = this.getScore(node, action);
            if (Double.isNaN(score)) {
                if (this.behaviorWhenActionForNotFullyExploredStateIsRequested == EBehaviorForNotFullyExploredStates.EXCEPTION) {
                    throw new IllegalStateException("Score of action " + action + " is NaN, which it must not be the case!");
                }
                someScoreUndefined = true;
            } else {
                scores.put(action, score);
            }
        }

        A chosenAction;
        if (someScoreUndefined && this.behaviorWhenActionForNotFullyExploredStateIsRequested != EBehaviorForNotFullyExploredStates.BEST) {
            /* some action cannot be assessed yet, so pick uniformly at random among all offered actions */
            List<A> shuffled = new ArrayList<>(possibleActions);
            Collections.shuffle(shuffled);
            chosenAction = shuffled.get(0);
        } else {
            chosenAction = this.getActionBasedOnScores(scores);
        }
        if (chosenAction == null) {
            throw new IllegalStateException("Would return null, but this must not be the case! Check the method that chooses an action given the scores.");
        }
        this.logger.info("Recommending action {}.", chosenAction);
        return chosenAction;
    }

    /**
     * @return the {@code maximize} flag passed to the constructor
     */
    public boolean isMaximize() {
        return this.maximize;
    }

    /**
     * @return the name of the logger currently used by this policy
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches this policy to the logger with the given name.
     *
     * @param name
     *            the logger name
     */
    public void setLoggerName(final String name) {
        this.logger = LoggerFactory.getLogger(name);
        this.logger.info("Set logger of {} to {}", this, name);
    }

    /**
     * @return the discount factor used when back-propagating rewards
     */
    public double getGamma() {
        return this.gamma;
    }

    /**
     * @return how {@link #getAction(Object, Collection)} reacts to NaN scores (default {@code EXCEPTION})
     */
    public EBehaviorForNotFullyExploredStates getBehaviorWhenActionForNotFullyExploredStateIsRequested() {
        return this.behaviorWhenActionForNotFullyExploredStateIsRequested;
    }

    /**
     * @param behavior
     *            how {@link #getAction(Object, Collection)} should react to NaN scores
     */
    public void setBehaviorWhenActionForNotFullyExploredStateIsRequested(final EBehaviorForNotFullyExploredStates behavior) {
        this.behaviorWhenActionForNotFullyExploredStateIsRequested = behavior;
    }
}
