package ai.libs.jaicore.search.algorithms.mdp.mcts.ensemble;

import ai.libs.jaicore.graph.LabeledGraph;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IGraphDependentPolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPolicy;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/**
 * A tree policy that delegates every action choice to one of several member tree policies.
 *
 * <p>On each call to {@link #getAction(Object, Collection)} a uniform random number in {@code [0, 1)} is drawn.
 * <ul>
 * <li>If it is below the configured random-selection probability, a member policy is picked uniformly at random.</li>
 * <li>Otherwise, the member with the smallest lower confidence bound is used. The bound is the mean observed path
 * score minus {@code sqrt(2) * sqrt(ln(calls) / timesChosen)}, so path scores are treated as costs to be minimized.
 * Members that have never been credited with an observation are tried first.</li>
 * </ul>
 * The default probability is {@code 1.0}, under which every selection is random.
 *
 * <p>Every path update is forwarded to all member policies. The summed path score is additionally recorded as an
 * observation for the member that made the most recent selection.
 *
 * <p>This class performs no synchronization.
 *
 * @param <N> node type
 * @param <A> action type
 */
public class EnsembleTreePolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, IGraphDependentPolicy<N, A> {

    /**
     * Default random-selection probability. Since {@link Random#nextDouble()} is always below 1.0, this makes every
     * selection random.
     */
    public static final double DEFAULT_RANDOM_SELECTION_PROBABILITY = 1.0;

    /** Weight of the confidence term in the lower confidence bound. */
    private static final double EXPLORATION_CONSTANT = Math.sqrt(2.0);

    private final List<IPathUpdatablePolicy<N, A, Double>> treePolicies;
    private final double randomSelectionProbability;
    private final Random rand = new Random(0);

    /** Running mean of the summed path scores credited to each member policy. */
    private final Map<IPolicy<N, A>, Double> meansOfObservations = new HashMap<>();

    /** Number of observations credited to each member policy. */
    private final Map<IPolicy<N, A>, Integer> numberOfTimesChosen = new HashMap<>();

    /** Member policy used by the most recent call to getAction; null before the first call. */
    private IPolicy<N, A> lastPolicy;

    /** Number of calls to getAction so far. */
    private int calls;

    /**
     * Creates an ensemble that always selects a member policy uniformly at random.
     *
     * @param treePolicies the member policies; must be non-null and non-empty
     * @throws NullPointerException if {@code treePolicies} is null
     * @throws IllegalArgumentException if {@code treePolicies} is empty
     */
    public EnsembleTreePolicy(final Collection<? extends IPathUpdatablePolicy<N, A, Double>> treePolicies) {
        this(treePolicies, DEFAULT_RANDOM_SELECTION_PROBABILITY);
    }

    /**
     * Creates an ensemble that selects a member at random with the given probability and otherwise by lower
     * confidence bound.
     *
     * @param treePolicies the member policies; must be non-null and non-empty
     * @param randomSelectionProbability probability of a uniformly random member choice. Values {@code >= 1} always
     *            choose at random; values {@code <= 0} always use the confidence bound.
     * @throws NullPointerException if {@code treePolicies} is null
     * @throws IllegalArgumentException if {@code treePolicies} is empty or the probability is NaN
     */
    public EnsembleTreePolicy(final Collection<? extends IPathUpdatablePolicy<N, A, Double>> treePolicies,
            final double randomSelectionProbability) {
        Objects.requireNonNull(treePolicies, "treePolicies");
        if (treePolicies.isEmpty()) {
            throw new IllegalArgumentException("An ensemble tree policy requires at least one member policy.");
        }
        if (Double.isNaN(randomSelectionProbability)) {
            throw new IllegalArgumentException("The random selection probability must not be NaN.");
        }
        this.treePolicies = new ArrayList<>(treePolicies);
        this.randomSelectionProbability = randomSelectionProbability;
    }

    /**
     * Chooses a member policy as described in the class documentation and delegates the action choice to it.
     *
     * @param node the node for which an action is requested
     * @param actions the available actions
     * @return the action chosen by the selected member policy
     * @throws ActionPredictionFailedException if the selected member policy fails
     * @throws InterruptedException if the selected member policy is interrupted
     */
    @Override
    public A getAction(final N node, final Collection<A> actions) throws ActionPredictionFailedException, InterruptedException {
        this.calls++;
        // The draw happens on every call, so the seeded random sequence is the same as with a hard-coded choice.
        if (this.rand.nextDouble() < this.randomSelectionProbability) {
            this.lastPolicy = this.treePolicies.get(this.rand.nextInt(this.treePolicies.size()));
        } else {
            this.lastPolicy = this.selectByLowerConfidenceBound();
        }
        return this.lastPolicy.getAction(node, actions);
    }

    /**
     * Returns the member with the smallest lower confidence bound. The first member with no observations wins
     * immediately, and ties go to the earlier member.
     */
    private IPathUpdatablePolicy<N, A, Double> selectByLowerConfidenceBound() {
        IPathUpdatablePolicy<N, A, Double> best = null;
        double bestScore = Double.POSITIVE_INFINITY;
        for (IPathUpdatablePolicy<N, A, Double> policy : this.treePolicies) {
            final Integer timesChosen = this.numberOfTimesChosen.get(policy);
            final double score;
            if (timesChosen == null) {
                // Scores are minimized, so an untried member gets the best possible bound and is explored first.
                score = Double.NEGATIVE_INFINITY;
            } else {
                score = this.meansOfObservations.get(policy)
                        - EXPLORATION_CONSTANT * Math.sqrt(Math.log(this.calls) / timesChosen);
            }
            // The null check guarantees a result even if every score is NaN.
            if (best == null || score < bestScore) {
                best = policy;
                bestScore = score;
            }
        }
        return best;
    }

    /**
     * Forwards the path update to every member policy. Then credits the sum of {@code scores} as one observation
     * to the member that made the most recent selection.
     *
     * <p>An empty score list counts as an observation of {@code 0.0}. If no selection has been made yet, only the
     * forwarding happens.
     *
     * @param path the path that was evaluated
     * @param scores the scores observed along the path
     */
    @Override
    public void updatePath(final ILabeledPath<N, A> path, final List<Double> scores) {
        for (IPathUpdatablePolicy<N, A, Double> policy : this.treePolicies) {
            policy.updatePath(path, scores);
        }
        if (this.lastPolicy == null) {
            return;
        }
        double observation = 0.0;
        for (Double score : scores) {
            observation += score;
        }
        final int previousCount = this.numberOfTimesChosen.getOrDefault(this.lastPolicy, 0);
        final double previousMean = this.meansOfObservations.getOrDefault(this.lastPolicy, 0.0);
        this.numberOfTimesChosen.put(this.lastPolicy, previousCount + 1);
        // Incremental mean: rebuild the previous sum, add the new observation, divide by the new count.
        this.meansOfObservations.put(this.lastPolicy, (previousMean * previousCount + observation) / (previousCount + 1));
    }

    /**
     * Passes the graph to every member policy that is itself graph-dependent.
     *
     * @param graph the search graph
     */
    @Override
    @SuppressWarnings("unchecked")
    public void setGraph(final LabeledGraph<N, A> graph) {
        for (IPathUpdatablePolicy<N, A, Double> policy : this.treePolicies) {
            if (policy instanceof IGraphDependentPolicy) {
                // The type arguments cannot be checked at runtime. Graph-dependent members are assumed to use the
                // same N and A as this ensemble.
                ((IGraphDependentPolicy<N, A>) policy).setGraph(graph);
            }
        }
    }
}
