package ai.libs.jaicore.search.algorithms.mdp.mcts.uuct;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.api4.java.datastructure.graph.ILabeledPath;

/* loaded from: input_file:ai/libs/jaicore/search/algorithms/mdp/mcts/uuct/UUCBPolicy.class */
/**
 * Action-selection policy that keeps, for every (node, action) pair, the list of accumulated
 * path scores observed so far and picks the action maximising
 * {@code utility(observations) + phiInverse(ALPHA * ln(t) / numberOfObservations)}.
 *
 * <p>The utility of an observation list and the constants {@code a}, {@code b} and {@code q}
 * used by {@link #phiInverse(double)} are all supplied by the {@link IUCBUtilityFunction}
 * passed to the constructor.
 *
 * <p>Observation lists are kept in ascending order (as defined by
 * {@link Double#compare(double, double)}); presumably the utility function relies on that
 * ordering -- confirm against the {@code IUCBUtilityFunction} implementations.
 *
 * @param <N> node type
 * @param <A> action type
 */
public class UUCBPolicy<N, A> implements IPathUpdatablePolicy<N, A, Double> {
    /** Constant factor applied to {@code ln(t)} in the bonus term added to the utility. */
    private static final double ALPHA = 3.0d;

    private final IUCBUtilityFunction utilityFunction;
    private final double a;
    private final double b;
    private final double q;

    /** Ascending list of accumulated path scores for every node and every action taken in it. */
    private final Map<N, Map<A, DoubleList>> observations = new HashMap<>();

    /** Source of seeds for the random fallback choice; created once instead of per call. */
    private final Random random = new Random();

    /** Number of completed {@link #updatePath} calls. */
    private int t = 0;

    /**
     * @param iUCBUtilityFunction utility function; its {@code getA()}, {@code getB()} and
     *        {@code getQ()} values are read once here and used by the bonus term
     */
    public UUCBPolicy(IUCBUtilityFunction iUCBUtilityFunction) {
        this.utilityFunction = iUCBUtilityFunction;
        this.a = iUCBUtilityFunction.getA();
        this.b = iUCBUtilityFunction.getB();
        this.q = iUCBUtilityFunction.getQ();
    }

    /**
     * Chooses an action for the given node.
     *
     * <p>If nothing has been observed for the node, or none of the offered actions has any
     * observation, a random element of {@code collection} is returned. Offered actions without
     * observations are otherwise skipped and never scored.
     *
     * @param n node for which an action is requested
     * @param collection actions to choose from
     * @return the offered action with the highest score, or a random one as described above
     * @throws ActionPredictionFailedException declared by the policy interface; not thrown here
     */
    @Override
    public A getAction(N n, Collection<A> collection) throws ActionPredictionFailedException {
        Map<A, DoubleList> observationsOfNode = this.observations.get(n);
        if (observationsOfNode == null) {
            return pickRandomly(collection);
        }

        // Identical for every action, so compute it once outside the loop.
        double scaledLogT = ALPHA * Math.log(this.t);

        double bestScore = -Double.MAX_VALUE;
        A bestAction = null;
        for (A action : collection) {
            DoubleList scoresOfAction = observationsOfNode.get(action);
            if (scoresOfAction == null) {
                continue;
            }
            double score = this.utilityFunction.getUtility(scoresOfAction)
                    + phiInverse(scaledLogT / scoresOfAction.size());
            if (score > bestScore) {
                bestScore = score;
                bestAction = action;
            }
        }
        return bestAction != null ? bestAction : pickRandomly(collection);
    }

    /** Returns a random element of {@code actions}, drawn with a freshly generated seed. */
    @SuppressWarnings("unchecked") // cast kept in case getRandomElement is not generic in its return type
    private A pickRandomly(Collection<A> actions) {
        return (A) SetUtil.getRandomElement(actions, this.random.nextLong());
    }

    /**
     * Computes {@code max(2b * sqrt(x / a), 2b * (x / a)^(q / 2))}.
     *
     * @param d the argument {@code x}
     * @return the larger of the two terms
     */
    private double phiInverse(double d) {
        double ratio = d / this.a;
        double scale = 2.0d * this.b;
        return Math.max(scale * Math.sqrt(ratio), scale * Math.pow(ratio, this.q / 2.0d));
    }

    /**
     * Records the sum of {@code list} as one observation for every node of the path except its
     * head, under the action that leaves that node on the path, and then increments {@code t}.
     *
     * @param iLabeledPath path whose nodes (all but the head) are updated
     * @param list scores obtained along the path; their sum is the recorded observation
     * @throws java.util.NoSuchElementException if {@code list} is empty
     */
    @Override
    public void updatePath(ILabeledPath<N, A> iLabeledPath, List<Double> list) {
        double accumulatedScore = list.stream().reduce(Double::sum).get();
        iLabeledPath.getPathToParentOfHead().getNodes().forEach(node -> {
            DoubleList scoresOfAction = this.observations
                    .computeIfAbsent(node, unusedNode -> new HashMap<>())
                    .computeIfAbsent(iLabeledPath.getOutArc(node), unusedArc -> new DoubleArrayList());
            insertSorted(scoresOfAction, accumulatedScore);
        });
        this.t++;
    }

    /**
     * Inserts {@code score} into an ascending list in front of the first element that is not
     * smaller than it, or at the end if every element is smaller.
     *
     * <p>Falling back to the end is essential: a score larger than every stored one must still
     * be recorded. Ordering uses {@link Double#compare(double, double)} so that a NaN score is
     * placed last instead of corrupting the order.
     */
    private static void insertSorted(DoubleList sortedScores, double score) {
        int size = sortedScores.size();
        int index = 0;
        while (index < size && Double.compare(sortedScores.getDouble(index), score) < 0) {
            index++;
        }
        sortedScores.add(index, score);
    }
}
