package ai.libs.jaicore.search.algorithms.mdp.mcts.comparison;

import ai.libs.jaicore.basic.IOwnerBasedAlgorithmConfig;
import ai.libs.jaicore.graph.LabeledGraph;
import ai.libs.jaicore.math.probability.pl.PLInferenceProblem;
import ai.libs.jaicore.math.probability.pl.PLInferenceProblemEncoder;
import ai.libs.jaicore.math.probability.pl.PLMMAlgorithm;
import ai.libs.jaicore.search.algorithms.mdp.mcts.ActionPredictionFailedException;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IGraphDependentPolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.IPathUpdatablePolicy;
import ai.libs.jaicore.search.algorithms.mdp.mcts.UniformRandomPolicy;
import com.google.common.eventbus.EventBus;
import it.unimi.dsi.fastutil.doubles.DoubleList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Random;
import org.aeonbits.owner.ConfigFactory;
import org.api4.java.algorithm.exceptions.AlgorithmException;
import org.api4.java.algorithm.exceptions.AlgorithmExecutionCanceledException;
import org.api4.java.algorithm.exceptions.AlgorithmTimeoutedException;
import org.api4.java.common.control.ILoggingCustomizable;
import org.api4.java.common.event.IRelaxedEventEmitter;
import org.api4.java.datastructure.graph.ILabeledPath;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Policy that picks actions by sampling from probabilities derived from a skill vector, which is
 * computed by {@link PLMMAlgorithm} from the rankings supplied by an {@link IPreferenceKernel}.
 *
 * <p>For every node the policy keeps the number of visits and the deepest relative position at
 * which the node occurred in an updated path (both maintained by {@link #updatePath}). These
 * statistics, together with the probability of reaching the node, are turned into an exponent
 * ("gamma") that is applied to the skills before they are normalized into sampling probabilities.
 *
 * <p>As long as the preference kernel reports that it cannot produce reliable rankings for a node,
 * actions are chosen by a {@link UniformRandomPolicy}.
 *
 * @param <N> node type of the exploration graph
 * @param <A> action type (edge label of the exploration graph)
 */
public class PlackettLucePolicy<N, A> implements IPathUpdatablePolicy<N, A, Double>, ILoggingCustomizable, IGraphDependentPolicy<N, A>, IRelaxedEventEmitter {

    /**
     * Scale of the chance to pick an action uniformly at random instead of sampling from the
     * skills (see {@link #getAction}). With the value 0 that branch is never taken.
     */
    private static final double EPSILON = 0.0d;

    /* Parameters of the per-node CosLinGammaFunction built in getAction; the "per child" values are multiplied by the number of successors. */
    private static final int GAMMA_LONG_MAX = 5;
    private static final int GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_INIT = 5;
    private static final int GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_ABS = 2;
    private static final int GAMMA_LONG_OBSERVATIONS_PER_CHILD_FOR_ONE = 10;

    private final IPreferenceKernel<N, A> preferenceKernel;
    private final Random random;
    private final UniformRandomPolicy<N, A, Double> randomPolicy;
    private LabeledGraph<N, A> graph;

    private final EventBus eventBus = new EventBus();

    /* not static/final: replaced in setLoggerName */
    private Logger logger = LoggerFactory.getLogger(PlackettLucePolicy.class);

    /** Last skill vector obtained for a node; handed to the next PLMMAlgorithm run for that node. */
    private final Map<N, DoubleList> skillVectorsForNodes = new HashMap<>();

    /** Number of updated paths that contained the node. */
    private final Map<N, Integer> numVisits = new HashMap<>();

    /** Largest value of (index in path / number of nodes in path) observed for the node. */
    private final Map<N, Double> deepestRelativeNodeDepthsOfNodes = new HashMap<>();

    /**
     * Probability assigned to an action the last time it was a candidate in {@link #getAction}.
     *
     * NOTE(review): this map is written with action keys in getAction but queried with node keys
     * in getProbabilityOfNode. Unless nodes and actions are equal objects, those lookups always
     * miss. Confirm the intended key type.
     */
    private final Map<A, Double> lastLocalProbabilityOfNode = new HashMap<>();

    /* NOTE(review): built from GAMMA_LONG_* constants (value 2); may be a constant substitution for plain literals - confirm. */
    private final IGammaFunction gammaShort = new CosLinGammaFunction(3.0d, 4, GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_ABS, GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_ABS);

    /** Configuration passed to every PLMMAlgorithm instance created by this policy. */
    private IOwnerBasedAlgorithmConfig config = ConfigFactory.create(IOwnerBasedAlgorithmConfig.class);

    /**
     * Creates the policy.
     *
     * @param iPreferenceKernel kernel that provides the rankings and receives the observed scores
     * @param random source of randomness for sampling; it also seeds the random generator of the
     *        internal uniform fallback policy
     */
    public PlackettLucePolicy(IPreferenceKernel<N, A> iPreferenceKernel, Random random) {
        this.preferenceKernel = iPreferenceKernel;
        this.random = random;
        this.randomPolicy = new UniformRandomPolicy<>(new Random(random.nextLong()));
    }

    /**
     * Chooses one of the given actions for the given node.
     *
     * <ul>
     * <li>If the preference kernel cannot produce reliable rankings for the node, the choice is
     * delegated to the uniform random policy.</li>
     * <li>If the node has exactly one successor in the graph, the single given action is
     * returned.</li>
     * <li>Otherwise the rankings of the kernel are encoded into a PL inference problem, skills are
     * computed (or the default skill vector is used if gamma is 0), raised to the power of gamma,
     * normalized, and one of the given actions is sampled proportionally to the result.</li>
     * </ul>
     *
     * @param n node for which an action is requested; must have occurred in a path passed to
     *        {@link #updatePath} if the kernel reports reliable rankings
     * @param collection actions to choose from
     * @return the chosen action, or {@code null} if the thread was interrupted during the
     *         computation (the interrupt flag is restored in that case)
     * @throws ActionPredictionFailedException if the skill computation fails, times out, or is
     *         canceled
     * @throws IllegalStateException if no statistics are known for the node or the computed skills
     *         are inconsistent with the problem
     * @throws UnsupportedOperationException if the node has no successors in the graph
     */
    @Override
    public A getAction(N n, Collection<A> collection) throws ActionPredictionFailedException {
        if (!this.preferenceKernel.canProduceReliableRankings(n)) {
            this.logger.info("Preference kernel tells us that it cannot produce reliable information yet. Choosing one action at random.");
            return this.randomPolicy.getAction(n, collection);
        }
        try {
            Integer recordedVisits = this.numVisits.get(n);
            Double recordedDepth = this.deepestRelativeNodeDepthsOfNodes.get(n);
            if (recordedVisits == null || recordedDepth == null) {
                throw new IllegalStateException("No visit statistics known for node " + n + ". The node must be part of an updated path before an action can be computed for it.");
            }
            int visits = recordedVisits;
            double relativeDepth = recordedDepth;

            int numSuccessors = this.graph.getSuccessors(n).size();
            this.logger.info("Computing action for node {} with {} successors.", n, numSuccessors);

            /* determine the exponent that is applied to the skills of this node */
            CombinedGammaFunction gammaFunction = new CombinedGammaFunction(this.gammaShort, new CosLinGammaFunction((double) GAMMA_LONG_MAX, numSuccessors * GAMMA_LONG_OBSERVATIONS_PER_CHILD_FOR_ONE, numSuccessors * GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_INIT, numSuccessors * GAMMA_LONG_MIN_OBSERVATIONS_PER_CHILD_FOR_SUPPORT_ABS));
            double nodeProbability = this.getProbabilityOfNode(n);
            double gamma = gammaFunction.getNodeGamma(visits, nodeProbability, relativeDepth);

            /* trivial cases: no or exactly one successor */
            if (numSuccessors < 1) {
                throw new UnsupportedOperationException("Cannot compute action for nodes without successors.");
            }
            if (numSuccessors == 1) {
                if (collection.size() != 1) {
                    throw new IllegalStateException("Node has exactly one successor but " + collection.size() + " actions were offered.");
                }
                return collection.iterator().next();
            }

            /* epsilon-greedy exploration; the random number is drawn even if EPSILON is 0 */
            if (this.random.nextDouble() < EPSILON * (1.0d - relativeDepth)) {
                return this.randomPolicy.getAction(n, collection);
            }

            /* obtain the skill vector */
            this.logger.debug("Computing PL-Problem instance");
            PLInferenceProblemEncoder encoder = new PLInferenceProblemEncoder();
            PLInferenceProblem problem = encoder.encode(this.preferenceKernel.getRankingsForActions(n));
            DoubleList skills;
            if (gamma != 0.0d) {
                this.logger.debug("Start computation of skills for {}. Using {} rankings based on {} visits. Gamma value is {} based on node probability {} and depth {}", n, problem.getRankings().size(), visits, gamma, nodeProbability, relativeDepth);
                skills = new PLMMAlgorithm(problem, this.skillVectorsForNodes.get(n), this.config).call();
                /*
                 * NOTE(review): the stored vector is the same object that is overwritten in place
                 * below, so the cache ends up holding the normalized, gamma-adjusted values rather
                 * than the skills returned by the algorithm. Confirm that this is intended.
                 */
                this.skillVectorsForNodes.put(n, skills);
            } else {
                skills = PLMMAlgorithm.getDefaultSkillVector(problem.getNumObjects());
            }
            if (skills.size() != problem.getNumObjects()) {
                throw new IllegalStateException("Have " + skills.size() + " skills (" + skills + ") for " + problem.getNumObjects() + " objects.");
            }

            /* raise skills to the power of gamma ... */
            int numSkills = skills.size();
            double skillSum = 0.0d;
            for (int i = 0; i < numSkills; i++) {
                double adjustedSkill = Math.pow(skills.getDouble(i), gamma);
                skills.set(i, adjustedSkill);
                skillSum += adjustedSkill;
            }
            if (skillSum == 0.0d) {
                throw new IllegalStateException("Sum of gamma-adjusted skills is 0; cannot normalize them.");
            }

            /* ... and normalize them so that they sum up to 1 */
            for (int i = 0; i < numSkills; i++) {
                double normalizedSkill = skills.getDouble(i) / skillSum;
                if (this.logger.isDebugEnabled()) {
                    /* NOTE(review): numVisits is keyed by nodes but queried with the encoder's object here - confirm. */
                    this.logger.debug("Estimating skill of successor {} with action {} and {} visits by {} -> {}", i, encoder.getObjectAtIndex(i), this.numVisits.get(encoder.getObjectAtIndex(i)), skills.getDouble(i), normalizedSkill);
                }
                skills.set(i, normalizedSkill);
            }

            /* if only a subset of the encoded objects is offered, sample within the mass of that subset */
            double availableMass = 1.0d;
            if (collection.size() != skills.size()) {
                availableMass = 0.0d;
                for (A action : collection) {
                    availableMass += skills.getDouble(encoder.getIndexOfObject(action));
                }
                if (availableMass == 0.0d) {
                    this.logger.info("Choosing option with prob 0");
                    return collection.iterator().next();
                }
            }

            /* roulette-wheel selection; the loop is not left early because the probabilities of all candidates are recorded */
            double threshold = this.random.nextDouble() * availableMass;
            double cumulatedMass = 0.0d;
            A chosenAction = null;
            for (A action : collection) {
                double probability = skills.getDouble(encoder.getIndexOfObject(action));
                if (Double.isNaN(probability)) {
                    this.logger.error("Probability of successor is NaN! Skill vector: {}", skills);
                }
                this.lastLocalProbabilityOfNode.put(action, probability);
                cumulatedMass += probability;
                if (chosenAction == null && cumulatedMass >= threshold) {
                    chosenAction = action;
                    this.logger.debug("Chose successor {} with skill {}", chosenAction, probability);
                }
            }
            if (chosenAction == null) {
                throw new IllegalStateException("Could not find child among " + collection.size() + " successors. Mass of remaining options is " + availableMass + ". Drawn random number is " + threshold + ". Sum of skimmed probs is " + cumulatedMass);
            }
            return chosenAction;
        } catch (AlgorithmTimeoutedException | AlgorithmExecutionCanceledException | AlgorithmException e) {
            throw new ActionPredictionFailedException(e);
        } catch (InterruptedException e) {
            this.logger.info("Policy thread has been interrupted. Re-interrupting thread, because no InterruptedException can be thrown here.");
            Thread.currentThread().interrupt();
            return null;
        }
    }

    /**
     * Multiplies the recorded local probabilities along the chain of first predecessors from the
     * given node up to (but excluding) a node without predecessors. Nodes without a recorded
     * probability are logged and contribute the factor 1.
     *
     * @param n node whose probability is requested
     * @return product of the recorded probabilities along the chain; 1 if none is recorded
     */
    private double getProbabilityOfNode(N n) {
        double probability = 1.0d;
        N current = n;
        while (!this.graph.getPredecessors(current).isEmpty()) {
            Double localProbability = this.lastLocalProbabilityOfNode.get(current);
            if (localProbability != null) {
                probability *= localProbability;
            } else {
                this.logger.warn("No probability known for node {}", current);
            }
            current = this.graph.getPredecessors(current).iterator().next();
        }
        return probability;
    }

    /**
     * Registers the given listener on the internal event bus of this policy.
     *
     * @param obj listener to register
     */
    public void registerListener(Object obj) {
        this.eventBus.register(obj);
    }

    /**
     * Sets the exploration graph used by this policy and forwards it to the preference kernel.
     *
     * @param labeledGraph graph from which successors and predecessors of nodes are read
     */
    @Override
    public void setGraph(LabeledGraph<N, A> labeledGraph) {
        this.graph = labeledGraph;
        this.preferenceKernel.setExplorationGraph(labeledGraph);
    }

    /**
     * @return name of the logger currently used by this policy
     */
    public String getLoggerName() {
        return this.logger.getName();
    }

    /**
     * Switches to the logger with the given name. If the preference kernel is
     * {@link ILoggingCustomizable}, its logger name is set to the given name with the suffix
     * {@code ".kernel"}.
     *
     * @param str name of the logger to use
     */
    public void setLoggerName(String str) {
        this.logger = LoggerFactory.getLogger(str);
        if (this.preferenceKernel instanceof ILoggingCustomizable) {
            ((ILoggingCustomizable) this.preferenceKernel).setLoggerName(str + ".kernel");
        }
    }

    /**
     * Reports the sum of the given scores for the given path to the preference kernel and updates
     * the visit count and the deepest relative depth of every node on the path.
     *
     * @param iLabeledPath path that has been evaluated
     * @param list scores observed along the path; must not be empty
     * @throws java.util.NoSuchElementException if the list of scores is empty
     */
    @Override
    public void updatePath(ILabeledPath<N, A> iLabeledPath, List<Double> list) {
        double totalScore = list.stream().reduce(Double::sum).orElseThrow(() -> new java.util.NoSuchElementException("Cannot update a path with an empty list of scores."));
        this.preferenceKernel.signalNewScore(iLabeledPath, totalScore);

        int index = 0;
        for (N node : iLabeledPath.getNodes()) {
            /* relative position of the node within this path: 0 for the first node, below 1 for the last one */
            double relativeDepth = (index * 1.0d) / iLabeledPath.getNumberOfNodes();
            this.numVisits.merge(node, 1, Integer::sum);
            Double deepestSoFar = this.deepestRelativeNodeDepthsOfNodes.get(node);
            if (deepestSoFar == null || deepestSoFar < relativeDepth) {
                this.deepestRelativeNodeDepthsOfNodes.put(node, relativeDepth);
            }
            index++;
        }
    }
}
