package ai.libs.reduction.single.confusion;

import ai.libs.jaicore.basic.sets.SetUtil;
import ai.libs.jaicore.ml.WekaUtil;
import ai.libs.jaicore.ml.classification.multiclass.reduction.MCTreeNodeReD;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.core.Instances;

/* loaded from: input_file:ai/libs/reduction/single/confusion/ConfusionBasedGreedyOptimizingAlgorithm.class */
/**
 * Builds a single binary split of the class set of a dataset (one {@link MCTreeNodeReD}) by combining
 * confusion-matrix information of several base classifiers with a greedy class assignment.
 *
 * <p>Procedure, as implemented in {@link #buildClassifier(Instances, Collection)}:
 * <ol>
 * <li>every candidate classifier is trained on the first part of a stratified split of the data and its
 * confusion matrix on the second part is recorded;</li>
 * <li>every confusion matrix is decomposed into class clusters by {@code getZeroConflictSets};</li>
 * <li>for every classifier pair produced by {@code SetUtil.cartesianProduct}, one cluster of each classifier is
 * chosen so that their union is largest, and every class not yet covered is added to the side whose extended
 * cluster has the lower penalty (ties go to the right side);</li>
 * <li>the resulting left/right relabelling is evaluated with every candidate classifier, scored as
 * {@code penalty(left) + penalty(right) + misclassified validation instances};</li>
 * <li>the configuration with the lowest score is built on the full data and returned.</li>
 * </ol>
 */
public class ConfusionBasedGreedyOptimizingAlgorithm extends AConfusionBasedAlgorithm {
    private static final Logger logger = LoggerFactory.getLogger(ConfusionBasedGreedyOptimizingAlgorithm.class);

    /**
     * Portion handed to every stratified split. Declared with a float literal on purpose: widened to double it is
     * exactly 0.699999988079071, the value this algorithm has always used, so the splits stay unchanged.
     */
    private static final double TRAIN_PORTION = .7f;

    /** Seed handed to every stratified split. */
    private static final int SPLIT_SEED = 0;

    /**
     * Searches for the best left/right class split and returns the corresponding node, trained on {@code instances}.
     *
     * @param instances
     *            the dataset whose class attribute is split
     * @param collection
     *            Weka classifier names (as accepted by {@code AbstractClassifier.forName}); each is used both as a base
     *            learner for the confusion matrices and as a candidate classifier for the node itself
     * @return the best configuration found, already built on {@code instances}
     * @throws IllegalStateException
     *             if no configuration could be evaluated successfully
     * @throws Exception
     *             if splitting the data or building the final node fails
     */
    public MCTreeNodeReD buildClassifier(final Instances instances, final Collection<String> collection) throws Exception {
        if (logger.isInfoEnabled()) {
            logger.info("START: {}", instances.relationName());
        }
        List<Instances> split = WekaUtil.getStratifiedSplit(instances, SPLIT_SEED, TRAIN_PORTION);
        int numClasses = instances.numClasses();

        Map<String, double[][]> confusionMatrices = this.computeConfusionMatrices(split.get(0), split.get(1), collection);
        Map<String, Collection<Collection<Integer>>> zeroConflictSets = new HashMap<>();
        for (Map.Entry<String, double[][]> entry : confusionMatrices.entrySet()) {
            zeroConflictSets.put(entry.getKey(), this.getZeroConflictSets(entry.getValue()));
        }

        Collection<List<String>> classifierPairs = SetUtil.cartesianProduct(confusionMatrices.keySet(), 2);
        int numPairs = classifierPairs.size();
        int bestScore = Integer.MAX_VALUE;
        String bestLeftClassifier = null;
        String bestRightClassifier = null;
        String bestInnerClassifier = null;
        ClusterPair bestClusters = null;
        int pairIndex = 0;
        for (List<String> pair : classifierPairs) {
            pairIndex++;
            String leftClassifier = pair.get(0);
            String rightClassifier = pair.get(1);
            logger.info("\tConsidering {}/{} ({}/{})", leftClassifier, rightClassifier, pairIndex, numPairs);
            double[][] leftMatrix = confusionMatrices.get(leftClassifier);
            double[][] rightMatrix = confusionMatrices.get(rightClassifier);

            ClusterPair seeds = selectLargestCoveringPair(zeroConflictSets.get(leftClassifier), zeroConflictSets.get(rightClassifier));
            if (seeds == null) {
                // Without a seed cluster on each side there is nothing to extend; previously this caused an NPE
                // that aborted the whole search.
                logger.warn("Skipping {}/{}: no zero-conflict class sets available", leftClassifier, rightClassifier);
                continue;
            }
            ClusterPair clusters = this.assignRemainingClasses(seeds, numClasses, leftMatrix, rightMatrix);
            int leftPenalty = this.getPenaltyOfCluster(clusters.left, leftMatrix);
            int rightPenalty = this.getPenaltyOfCluster(clusters.right, rightMatrix);

            // Relabel every class as "l" or "r". NOTE(review): the two seed clusters come from different classifiers
            // and are not forced to be disjoint; a class present in both ends up labelled "r" here but is still
            // listed on both sides of the returned node - confirm this is intended.
            Map<String, String> sideOfClass = new HashMap<>();
            for (Integer classIndex : clusters.left) {
                sideOfClass.put(instances.classAttribute().value(classIndex), "l");
            }
            for (Integer classIndex : clusters.right) {
                sideOfClass.put(instances.classAttribute().value(classIndex), "r");
            }
            Instances refactoredInstances = WekaUtil.getRefactoredInstances(instances, sideOfClass);
            List<Instances> refactoredSplit = WekaUtil.getStratifiedSplit(refactoredInstances, SPLIT_SEED, TRAIN_PORTION);

            for (String innerClassifier : collection) {
                try {
                    logger.info("\t\tConsidering {}/{}/{}", leftClassifier, rightClassifier, innerClassifier);
                    int score = leftPenalty + rightPenalty + countValidationErrors(innerClassifier, refactoredInstances, refactoredSplit);
                    if (score < bestScore) {
                        bestScore = score;
                        logger.info("New best system: {}/{}/{} with {}", leftClassifier, rightClassifier, innerClassifier, bestScore);
                        bestClusters = clusters;
                        bestLeftClassifier = leftClassifier;
                        bestRightClassifier = rightClassifier;
                        bestInnerClassifier = innerClassifier;
                    }
                } catch (Exception e) {
                    // Pass the exception as trailing argument so SLF4J logs its stack trace.
                    logger.error("Encountered error while evaluating {}/{}/{}", leftClassifier, rightClassifier, innerClassifier, e);
                }
            }
        }
        if (bestClusters == null) {
            throw new IllegalStateException("Best left classes must not be null");
        }
        MCTreeNodeReD node = new MCTreeNodeReD(bestInnerClassifier, toClassNames(instances, bestClusters.left), bestLeftClassifier,
                toClassNames(instances, bestClusters.right), bestRightClassifier);
        node.buildClassifier(instances);
        return node;
    }

    /**
     * Trains every named classifier on {@code train} and records its confusion matrix on {@code validation}.
     * Classifiers that fail are logged and left out of the result.
     */
    private Map<String, double[][]> computeConfusionMatrices(final Instances train, final Instances validation, final Collection<String> classifierNames) {
        logger.info("Computing confusion matrices ...");
        Map<String, double[][]> confusionMatrices = new HashMap<>();
        for (String classifierName : classifierNames) {
            logger.info("\t{} ...", classifierName);
            try {
                Classifier classifier = AbstractClassifier.forName(classifierName, null);
                classifier.buildClassifier(train);
                Evaluation evaluation = new Evaluation(train);
                evaluation.evaluateModel(classifier, validation);
                confusionMatrices.put(classifierName, evaluation.confusionMatrix());
            } catch (Exception e) {
                // Best effort: a classifier that cannot be trained is simply not considered further.
                logger.error("Could not train classifier {}", classifierName, e);
            }
        }
        logger.info("done");
        return confusionMatrices;
    }

    /**
     * Picks one cluster from each collection such that their union is largest (first maximum wins).
     *
     * @return the chosen pair, or {@code null} if no pair yields a non-empty union (e.g. one collection is empty)
     */
    private static ClusterPair selectLargestCoveringPair(final Collection<Collection<Integer>> leftCandidates, final Collection<Collection<Integer>> rightCandidates) {
        int bestCoverage = 0;
        ClusterPair best = null;
        for (Collection<Integer> left : leftCandidates) {
            for (Collection<Integer> right : rightCandidates) {
                int coverage = SetUtil.union(left, right).size();
                if (coverage > bestCoverage) {
                    bestCoverage = coverage;
                    best = new ClusterPair(left, right);
                }
            }
        }
        return best;
    }

    /**
     * Adds every class in {@code [0, numClasses)} not yet in either seed cluster to the side whose extended cluster has
     * the strictly lower penalty; ties go to the right side. The seed collections themselves are never modified.
     */
    private ClusterPair assignRemainingClasses(final ClusterPair seeds, final int numClasses, final double[][] leftMatrix, final double[][] rightMatrix) {
        Collection<Integer> left = seeds.left;
        Collection<Integer> right = seeds.right;
        for (int classIndex = 0; classIndex < numClasses; classIndex++) {
            if (left.contains(classIndex) || right.contains(classIndex)) {
                continue;
            }
            List<Integer> extendedLeft = new ArrayList<>(left);
            extendedLeft.add(classIndex);
            int leftPenalty = this.getPenaltyOfCluster(extendedLeft, leftMatrix);
            List<Integer> extendedRight = new ArrayList<>(right);
            extendedRight.add(classIndex);
            if (leftPenalty < this.getPenaltyOfCluster(extendedRight, rightMatrix)) {
                left = extendedLeft;
            } else {
                right = extendedRight;
            }
        }
        return new ClusterPair(left, right);
    }

    /**
     * Trains {@code classifierName} on the first part of {@code split} and returns the number of misclassified
     * instances of the second part (as reported by {@code Evaluation.incorrect()}, truncated to int).
     */
    private static int countValidationErrors(final String classifierName, final Instances refactoredInstances, final List<Instances> split) throws Exception {
        Classifier classifier = AbstractClassifier.forName(classifierName, null);
        classifier.buildClassifier(split.get(0));
        Evaluation evaluation = new Evaluation(refactoredInstances);
        evaluation.evaluateModel(classifier, split.get(1));
        return (int) evaluation.incorrect();
    }

    /** Maps class indices to the corresponding values of the class attribute of {@code instances}, keeping order. */
    private static List<String> toClassNames(final Instances instances, final Collection<Integer> classIndices) {
        return classIndices.stream().map(classIndex -> instances.classAttribute().value(classIndex)).collect(Collectors.toList());
    }

    /**
     * Decomposes a confusion matrix into disjoint class clusters.
     *
     * <p>Repeatedly seeds a cluster with {@code getLeastConflictingClass} (ignoring classes already assigned) and grows
     * it with {@code incrementCluster} until {@code getPenaltyOfCluster} becomes non-zero. Stops once no seed class is
     * returned (negative index).
     *
     * <p>NOTE(review): the kept cluster is the first one whose penalty is non-zero, not the last conflict-free one -
     * confirm this against the contract of {@code incrementCluster}.
     *
     * @throws IllegalStateException
     *             if {@code incrementCluster} yields a cluster containing {@code -1}
     */
    private Collection<Collection<Integer>> getZeroConflictSets(final double[][] confusionMatrix) {
        List<Integer> assignedClasses = new ArrayList<>();
        Collection<Collection<Integer>> clusters = new ArrayList<>();
        int seedClass;
        do {
            seedClass = this.getLeastConflictingClass(confusionMatrix, assignedClasses);
            if (seedClass >= 0) {
                Collection<Integer> cluster = new ArrayList<>();
                cluster.add(seedClass);
                do {
                    cluster = this.incrementCluster(cluster, confusionMatrix, assignedClasses);
                    if (cluster.contains(-1)) {
                        throw new IllegalStateException("Computed illegal cluster: " + cluster);
                    }
                } while (this.getPenaltyOfCluster(cluster, confusionMatrix) == 0);
                assignedClasses.addAll(cluster);
                clusters.add(cluster);
            }
        } while (seedClass >= 0);
        return clusters;
    }

    /** Holder for the class indices assigned to the left and the right child; the collections are not copied. */
    private static final class ClusterPair {
        private final Collection<Integer> left;
        private final Collection<Integer> right;

        private ClusterPair(final Collection<Integer> left, final Collection<Integer> right) {
            this.left = left;
            this.right = right;
        }
    }
}
