package ai.libs.jaicore.ml.dyadranking.activelearning;

import ai.libs.jaicore.math.linearalgebra.Vector;
import ai.libs.jaicore.ml.core.exception.TrainingException;
import ai.libs.jaicore.ml.dyadranking.Dyad;
import ai.libs.jaicore.ml.dyadranking.algorithm.PLNetDyadRanker;
import ai.libs.jaicore.ml.dyadranking.dataset.SparseDyadRankingInstance;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Random;
import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.clusterers.Clusterer;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/jaicore/ml/dyadranking/activelearning/ConfidenceIntervalClusteringBasedActiveDyadRanker.class */
/**
 * Active dyad ranker that selects its queries by clustering confidence intervals.
 *
 * <p>In every training step the dyads of each instance are described by the interval
 * {@code [skill - stdDev, skill + stdDev]}, where {@code skill} is the ranker's current skill
 * for the dyad and {@code stdDev} is the standard deviation of the skills recorded so far in
 * the dyad's {@link SummaryStatistics}. These intervals are clustered with the configured
 * {@link Clusterer}; from the largest clusters the pair of dyads with the largest interval
 * overlap is sent to the pool provider as a query, and the answers are used to update the
 * ranker.
 */
public class ConfidenceIntervalClusteringBasedActiveDyadRanker extends ARandomlyInitializingDyadRanker {
    private static final Logger log = LoggerFactory.getLogger(ConfidenceIntervalClusteringBasedActiveDyadRanker.class);

    /** Clusterer that is rebuilt on the confidence intervals of every instance. */
    private final Clusterer clusterer;

    /**
     * Orders lists of dyads by descending size, so that a {@link PriorityQueue} using this
     * comparator hands out the largest list first.
     */
    private static final class ListComparator implements Comparator<List<Dyad>> {
        @Override
        public int compare(List<Dyad> list, List<Dyad> list2) {
            // Arguments are swapped on purpose: the larger list has to compare as "smaller".
            return Integer.compare(list2.size(), list.size());
        }
    }

    /**
     * Creates the ranker.
     *
     * @param pLNetDyadRanker ranker to be trained; forwarded to the superclass
     * @param iDyadRankingPoolProvider pool provider answering the queries; forwarded to the superclass
     * @param i forwarded unchanged to the superclass constructor (meaning is defined there)
     * @param i2 forwarded unchanged to the superclass constructor (meaning is defined there)
     * @param i3 forwarded unchanged to the superclass constructor (meaning is defined there)
     * @param clusterer clusterer used to group the confidence intervals of the dyads
     */
    public ConfidenceIntervalClusteringBasedActiveDyadRanker(PLNetDyadRanker pLNetDyadRanker, IDyadRankingPoolProvider iDyadRankingPoolProvider, int i, int i2, int i3, Clusterer clusterer) {
        super(pLNetDyadRanker, iDyadRankingPoolProvider, i, i2, i3);
        this.clusterer = clusterer;
    }

    /**
     * Performs one active training step: clusters the dyads of every instance, queries the most
     * overlapping dyad pair of up to {@code getMinibatchSize()} of the largest clusters and
     * updates the ranker with the answers.
     *
     * <p>Failures while clustering one instance and {@link TrainingException}s raised by the
     * update are logged and not propagated.
     */
    @Override
    public void activelyTrainWithOneInstance() {
        PriorityQueue<List<Dyad>> clustersBySize = new PriorityQueue<>(new ListComparator());
        // NOTE(review): the element type of the minibatch is dictated by poolProvider.query and
        // updateRanker, both declared outside this file, so the set is left raw - TODO parameterize.
        HashSet minibatch = new HashSet();
        Map<Dyad, SummaryStatistics> dyadStats = getDyadStats();

        for (Vector instanceFeatures : getInstanceFeatures()) {
            Instances intervals = buildIntervalInstances(instanceFeatures, dyadStats);
            try {
                clustersBySize.addAll(clusterDyads(instanceFeatures, intervals, dyadStats));
            } catch (Exception e) {
                // Pass the exception itself so that the stack trace is not lost.
                log.error("Could not cluster the confidence intervals of an instance.", e);
            }
        }

        Random random = getRandom();
        for (int queried = 0; queried < getMinibatchSize(); queried++) {
            List<Dyad> cluster = clustersBySize.poll();
            // poll() yields null once the queue is exhausted. As clusters arrive in descending
            // size, a cluster with fewer than two dyads means no later one can form a pair either.
            if (cluster == null || cluster.size() < 2) {
                break;
            }
            int[] pair = selectQueryPair(cluster, random);
            Dyad first = cluster.get(pair[0]);
            Dyad second = cluster.get(pair[1]);
            List<Vector> alternatives = new LinkedList<>();
            alternatives.add(first.getAlternative());
            alternatives.add(second.getAlternative());
            minibatch.add(this.poolProvider.query(new SparseDyadRankingInstance(first.getInstance(), alternatives)));
        }

        try {
            updateRanker(minibatch);
        } catch (TrainingException e) {
            log.error("Could not update the ranker with the queried minibatch.", e);
        }
    }

    /**
     * Records the current skill of every dyad of the given instance in its statistics and
     * builds the data set of confidence intervals that is handed to the clusterer.
     *
     * @param instanceFeatures instance whose dyads are described
     * @param dyadStats skill statistics per dyad; updated with the current skill of each dyad
     * @return data set with one (upper bound, lower bound) row per dyad
     */
    private Instances buildIntervalInstances(Vector instanceFeatures, Map<Dyad, SummaryStatistics> dyadStats) {
        ArrayList<Attribute> attributes = new ArrayList<>();
        attributes.add(new Attribute("upper_bound"));
        attributes.add(new Attribute("lower_bound"));
        Instances intervals = new Instances("confidence_intervalls", attributes, this.poolProvider.getDyadsByInstance(instanceFeatures).size());
        for (Dyad dyad : this.poolProvider.getDyadsByInstance(instanceFeatures)) {
            double skill = this.ranker.getSkillForDyad(dyad);
            SummaryStatistics stats = dyadStats.get(dyad);
            // The skill is recorded first, so the deviation already includes the current value.
            stats.addValue(skill);
            intervals.add(toIntervalInstance(skill, stats.getStandardDeviation()));
        }
        return intervals;
    }

    /**
     * Builds the clusterer on the given intervals and assigns every dyad of the instance to the
     * cluster its current confidence interval falls into.
     *
     * @param instanceFeatures instance whose dyads are clustered
     * @param intervals confidence intervals the clusterer is built on
     * @param dyadStats skill statistics per dyad
     * @return one list of dyads per cluster, indexed by cluster number; lists may be empty
     * @throws Exception if the clusterer fails; the type is dictated by the Weka clusterer API
     */
    private List<List<Dyad>> clusterDyads(Vector instanceFeatures, Instances intervals, Map<Dyad, SummaryStatistics> dyadStats) throws Exception {
        this.clusterer.buildClusterer(intervals);
        int numberOfClusters = this.clusterer.numberOfClusters();
        List<List<Dyad>> clusters = new ArrayList<>();
        for (int c = 0; c < numberOfClusters; c++) {
            clusters.add(new ArrayList<>());
        }
        for (Dyad dyad : this.poolProvider.getDyadsByInstance(instanceFeatures)) {
            double skill = this.ranker.getSkillForDyad(dyad);
            DenseInstance interval = toIntervalInstance(skill, dyadStats.get(dyad).getStandardDeviation());
            clusters.get(this.clusterer.clusterInstance(interval)).add(dyad);
        }
        return clusters;
    }

    /**
     * Creates the clustering row for one dyad.
     *
     * @param skill current skill of the dyad
     * @param standardDeviation standard deviation of the skills recorded for the dyad
     * @return instance of weight 1 holding the upper bound followed by the lower bound
     */
    private static DenseInstance toIntervalInstance(double skill, double standardDeviation) {
        return new DenseInstance(1.0d, new double[] {skill + standardDeviation, skill - standardDeviation});
    }

    /**
     * Picks the two dyads of a cluster whose confidence intervals overlap most. If no pair
     * yields a comparable overlap (e.g. every overlap is NaN), two distinct positions are drawn
     * at random instead.
     *
     * @param cluster cluster with at least two dyads
     * @param random source of randomness for the fallback
     * @return the two distinct positions within {@code cluster} to query
     */
    private int[] selectQueryPair(List<Dyad> cluster, Random random) {
        int[] pair = {0, 1};
        double largestOverlap = -1.0d;
        boolean found = false;
        for (int a = 1; a < cluster.size(); a++) {
            for (int b = 0; b < a; b++) {
                double overlap = getConfidenceIntervalOverlapForDyads(cluster.get(a), cluster.get(b));
                if (overlap > largestOverlap) {
                    pair[0] = a;
                    pair[1] = b;
                    largestOverlap = overlap;
                    found = true;
                }
            }
        }
        if (!found) {
            pair[0] = random.nextInt(cluster.size());
            pair[1] = random.nextInt(cluster.size());
            // Terminates because the cluster holds at least two dyads.
            while (pair[0] == pair[1]) {
                pair[1] = random.nextInt(cluster.size());
            }
        }
        return pair;
    }

    /**
     * Computes the length of the overlap of the confidence intervals
     * {@code [skill - stdDev, skill + stdDev]} of two dyads.
     *
     * @param dyad first dyad
     * @param dyad2 second dyad
     * @return length of the overlap, or 0 if the intervals are disjoint
     */
    private double getConfidenceIntervalOverlapForDyads(Dyad dyad, Dyad dyad2) {
        double skill = this.ranker.getSkillForDyad(dyad);
        double skill2 = this.ranker.getSkillForDyad(dyad2);
        Map<Dyad, SummaryStatistics> dyadStats = getDyadStats();
        double deviation = dyadStats.get(dyad).getStandardDeviation();
        double deviation2 = dyadStats.get(dyad2).getStandardDeviation();
        double lower = skill - deviation;
        double upper = skill + deviation;
        double lower2 = skill2 - deviation2;
        double upper2 = skill2 + deviation2;
        if (lower > upper2 || upper < lower2) {
            return 0.0d;
        }
        return Math.abs(Math.min(upper, upper2) - Math.max(lower, lower2));
    }
}
