package meka.classifiers.multilabel;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.locks.ReentrantLock;
import java.util.stream.IntStream;
import java.util.stream.Stream;
import meka.core.F;
import meka.core.MLUtils;
import meka.core.MultiLabelDrawable;
import weka.classifiers.AbstractClassifier;
import weka.classifiers.Classifier;
import weka.core.Drawable;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.RevisionUtils;
import weka.filters.Filter;
import weka.filters.unsupervised.attribute.NominalToBinary;

/* loaded from: input_file:meka/classifiers/multilabel/BR.class */
public class BR extends ProblemTransformationMethod implements MultiLabelDrawable {
    private static final long serialVersionUID = -5390512540469007904L;
    protected Classifier[] m_MultiClassifiers = null;
    protected Instances[] m_preFilterInstancesTemplates = null;
    protected Instances[] m_InstancesTemplates = null;
    protected NominalToBinary[] m_NominalToBinary = null;
    private int numThreads = 1;

    @Override // meka.classifiers.multilabel.ProblemTransformationMethod
    public String globalInfo() {
        return "The Binary Relevance Method.\nSee also MULAN framework:\nhttp://mulan.sourceforge.net";
    }

    public void setNumThreads(int i) {
        this.numThreads = i;
    }

    @Override // meka.classifiers.multilabel.ProblemTransformationMethod
    public void buildClassifier(Instances instances) throws Exception {
        testCapabilities(instances);
        int classIndex = instances.classIndex();
        if (getDebug()) {
            System.out.print("Creating " + classIndex + " models (" + this.m_Classifier.getClass().getName() + "): ");
        }
        this.m_MultiClassifiers = AbstractClassifier.makeCopies(this.m_Classifier, classIndex);
        this.m_InstancesTemplates = new Instances[classIndex];
        this.m_preFilterInstancesTemplates = new Instances[classIndex];
        this.m_NominalToBinary = new NominalToBinary[classIndex];
        ReentrantLock reentrantLock = new ReentrantLock();
        IntStream.range(0, classIndex).forEach(i -> {
            this.m_NominalToBinary[i] = new NominalToBinary();
        });
        Semaphore semaphore = new Semaphore(0);
        AtomicBoolean atomicBoolean = new AtomicBoolean(false);
        List synchronizedList = Collections.synchronizedList(new LinkedList());
        LinkedList linkedList = new LinkedList();
        IntStream.range(0, classIndex).forEach(i2 -> {
            linkedList.add(new Runnable() { // from class: meka.classifiers.multilabel.BR.1
                @Override // java.lang.Runnable
                public void run() {
                    try {
                        Instances keepLabels = F.keepLabels(new Instances(instances), classIndex, new int[]{i2});
                        keepLabels.setClassIndex(0);
                        reentrantLock.lock();
                        try {
                            BR.this.m_preFilterInstancesTemplates[i2] = new Instances(keepLabels, 0);
                            reentrantLock.unlock();
                            BR.this.m_NominalToBinary[i2].setInputFormat(keepLabels);
                            Instances useFilter = Filter.useFilter(keepLabels, BR.this.m_NominalToBinary[i2]);
                            reentrantLock.lock();
                            try {
                                BR.this.m_InstancesTemplates[i2] = new Instances(useFilter, 0);
                                reentrantLock.unlock();
                                BR.this.m_MultiClassifiers[i2].buildClassifier(useFilter);
                                if (BR.this.getDebug()) {
                                    System.out.println(Thread.currentThread().getName() + ": " + useFilter.classAttribute().name());
                                }
                                semaphore.release();
                            } finally {
                            }
                        } finally {
                        }
                    } catch (Throwable th) {
                        synchronizedList.add(th);
                        atomicBoolean.set(true);
                        semaphore.release(classIndex);
                    }
                }
            });
        });
        if (this.numThreads > 1) {
            ThreadPoolExecutor threadPoolExecutor = (ThreadPoolExecutor) Executors.newFixedThreadPool(this.numThreads);
            Stream stream = linkedList.stream();
            Objects.requireNonNull(threadPoolExecutor);
            stream.forEach(threadPoolExecutor::submit);
            semaphore.acquire(classIndex);
            if (atomicBoolean.get()) {
                threadPoolExecutor.shutdownNow();
                throw new Exception((Throwable) synchronizedList.get(0));
            }
            threadPoolExecutor.shutdown();
            threadPoolExecutor.awaitTermination(24L, TimeUnit.HOURS);
        } else {
            Iterator it = linkedList.iterator();
            while (it.hasNext()) {
                ((Runnable) it.next()).run();
                if (atomicBoolean.get()) {
                    throw new Exception((Throwable) synchronizedList.get(0));
                }
            }
        }
        for (Instances instances2 : this.m_InstancesTemplates) {
            if (instances2 == null) {
                throw new Exception("Not all instances templates are filled.");
            }
        }
    }

    @Override // meka.classifiers.multilabel.ProblemTransformationMethod
    public double[] distributionForInstance(Instance instance) throws Exception {
        int classIndex = instance.classIndex();
        double[] dArr = new double[classIndex];
        for (int i = 0; i < classIndex; i++) {
            if (Thread.currentThread().isInterrupted()) {
                throw new InterruptedException("Thread has been interrupted.");
            }
            Instance instance2 = (Instance) instance.copy();
            instance2.setDataset((Instances) null);
            Instance keepAttributesAt = MLUtils.keepAttributesAt(instance2, new int[]{i}, classIndex);
            Instances instances = new Instances(this.m_preFilterInstancesTemplates[i], 0);
            instances.add(keepAttributesAt);
            Instance instance3 = Filter.useFilter(instances, this.m_NominalToBinary[i]).get(0);
            instance3.setDataset(this.m_InstancesTemplates[i]);
            dArr[i] = this.m_MultiClassifiers[i].distributionForInstance(instance3)[1];
        }
        return dArr;
    }

    @Override // meka.core.MultiLabelDrawable
    public Map<Integer, Integer> graphType() {
        HashMap hashMap = new HashMap();
        if (this.m_MultiClassifiers != null) {
            for (int i = 0; i < this.m_MultiClassifiers.length; i++) {
                if (this.m_MultiClassifiers[i] instanceof Drawable) {
                    hashMap.put(Integer.valueOf(i), Integer.valueOf(this.m_MultiClassifiers[i].graphType()));
                }
            }
        }
        return hashMap;
    }

    @Override // meka.core.MultiLabelDrawable
    public Map<Integer, String> graph() throws Exception {
        HashMap hashMap = new HashMap();
        if (this.m_MultiClassifiers != null) {
            for (int i = 0; i < this.m_MultiClassifiers.length; i++) {
                if (this.m_MultiClassifiers[i] instanceof Drawable) {
                    hashMap.put(Integer.valueOf(i), this.m_MultiClassifiers[i].graph());
                }
            }
        }
        return hashMap;
    }

    @Override // meka.classifiers.multilabel.ProblemTransformationMethod
    public String getRevision() {
        return RevisionUtils.extract("$Revision: 9117 $");
    }

    public static void main(String[] strArr) {
        BR br = new BR();
        br.setNumThreads(4);
        ProblemTransformationMethod.evaluation(br, strArr);
    }
}
