package edu.emory.clir.clearnlp.classification.trainer;

import edu.emory.clir.clearnlp.classification.model.SparseModel;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.vector.SparseFeatureVector;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.MathUtils;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:edu/emory/clir/clearnlp/classification/trainer/AbstractOneVsAllTrainer.class */
public abstract class AbstractOneVsAllTrainer extends AbstractTrainer {
    /** Requested number of worker threads used when training more than two labels. */
    protected int n_threads;

    /**
     * Task that trains a single label by delegating to the enclosing trainer's
     * {@link AbstractOneVsAllTrainer#update(int)}.
     */
    public class TrainTask implements Runnable {
        int curr_label;

        /**
         * @param i the label index this task will pass to {@code update(int)}
         */
        public TrainTask(int i) {
            this.curr_label = i;
        }

        @Override // java.lang.Runnable
        public void run() {
            AbstractOneVsAllTrainer.this.update(this.curr_label);
        }
    }

    /**
     * Creates a one-vs-all trainer over a sparse model.
     *
     * @param sparseModel the model passed to {@code AbstractTrainer}
     * @param i the number of worker threads (see {@link #setNumberOfThreads(int)})
     */
    public AbstractOneVsAllTrainer(SparseModel sparseModel, int i) {
        super(TrainerType.ONE_VS_ALL, sparseModel);
        setNumberOfThreads(i);
    }

    /**
     * Creates a one-vs-all trainer over a string model.
     *
     * @param stringModel the model passed to {@code AbstractTrainer}
     * @param i forwarded unchanged to the {@code AbstractTrainer} constructor
     *     (NOTE(review): presumably a cutoff; confirm against AbstractTrainer)
     * @param i2 forwarded unchanged to the {@code AbstractTrainer} constructor
     *     (NOTE(review): presumably a cutoff; confirm against AbstractTrainer)
     * @param i3 the number of worker threads (see {@link #setNumberOfThreads(int)})
     */
    public AbstractOneVsAllTrainer(StringModel stringModel, int i, int i2, int i3) {
        super(TrainerType.ONE_VS_ALL, stringModel, i, i2);
        setNumberOfThreads(i3);
    }

    /**
     * Sets the number of worker threads used by multi-label training. Values below 1 are
     * treated as 1 when the thread pool is created.
     *
     * @param i the requested number of threads
     */
    public void setNumberOfThreads(int i) {
        this.n_threads = i;
    }

    /**
     * Trains the model. A binary label set trains only label {@code 0}; otherwise every
     * label is trained, possibly in parallel.
     */
    @Override // edu.emory.clir.clearnlp.classification.trainer.AbstractTrainer
    public void train() {
        if (this.w_vector.isBinaryLabel()) {
            trainBinary();
        } else {
            trainMulti();
        }
    }

    /** Trains the single weight vector used for a binary label set (label index 0). */
    private void trainBinary() {
        update(0);
    }

    /**
     * Trains every label on a fixed-size thread pool and blocks until all tasks finish.
     * Because tasks run concurrently, {@link #update(int)} must be safe to call in parallel
     * for distinct labels.
     */
    private void trainMulti() {
        // Query the label count before creating the pool so a failure here cannot leak threads.
        int labelSize = this.w_vector.getLabelSize();
        BinUtils.LOG.info("One vs. All\n");

        // Executors.newFixedThreadPool throws IllegalArgumentException for sizes < 1.
        ExecutorService pool = Executors.newFixedThreadPool(Math.max(1, this.n_threads));
        try {
            for (int label = 0; label < labelSize; label++) {
                pool.execute(new TrainTask(label));
            }
        } finally {
            // Stop accepting new tasks; already-submitted tasks still run to completion.
            pool.shutdown();
        }

        try {
            pool.awaitTermination(Long.MAX_VALUE, TimeUnit.NANOSECONDS);
        } catch (InterruptedException e) {
            // Cancel outstanding work and restore the interrupt status so callers can react,
            // rather than silently swallowing the interruption.
            pool.shutdownNow();
            Thread.currentThread().interrupt();
        }
    }

    /**
     * Trains the weights associated with one label.
     *
     * @param i the label index; {@code 0} for binary problems
     */
    protected abstract void update(int i);

    /**
     * Builds a +1/-1 target array for the given label over all training instances.
     *
     * @param i the label index treated as the positive class
     * @return an array with {@code 1} where the instance has label {@code i}, {@code -1} otherwise
     */
    public byte[] getBinaryLabels(int i) {
        int instanceSize = getInstanceSize();
        byte[] labels = new byte[instanceSize];
        for (int index = 0; index < instanceSize; index++) {
            labels[index] = getInstance(index).isLabel(i) ? (byte) 1 : (byte) -1;
        }
        return labels;
    }

    /**
     * Computes the linear score of a feature vector: {@code fArr[0] * d} plus the sum of
     * {@code fArr[index] * weight} over the vector's (index, weight) pairs.
     *
     * @param fArr the weight array; slot 0 is scaled by {@code d} rather than by a feature
     *     (NOTE(review): presumably a bias slot, so feature indices should never be 0 — confirm)
     * @param sparseFeatureVector the features to score
     * @param d the multiplier applied to {@code fArr[0]}
     * @return the score
     */
    public double getScore(float[] fArr, SparseFeatureVector sparseFeatureVector, double d) {
        double score = fArr[0] * d;
        int size = sparseFeatureVector.size();
        for (int i = 0; i < size; i++) {
            score += fArr[sparseFeatureVector.getIndex(i)] * sparseFeatureVector.getWeight(i);
        }
        return score;
    }

    /**
     * Adds {@code d2} times the feature vector to the weights in place. Slot 0 is incremented
     * by {@code d2 * d}; each feature slot by {@code d2 * weight}.
     *
     * @param fArr the weight array to modify
     * @param sparseFeatureVector the features whose weights are added
     * @param d the value paired with slot 0 (the same role as in {@code getScore})
     * @param d2 the scale applied to every increment
     */
    public void update(float[] fArr, SparseFeatureVector sparseFeatureVector, double d, double d2) {
        fArr[0] = (float) (fArr[0] + (d2 * d));
        int size = sparseFeatureVector.size();
        for (int i = 0; i < size; i++) {
            // The decompiled original read fArr[r1] (undeclared); hoist the index and reuse it.
            int index = sparseFeatureVector.getIndex(i);
            fArr[index] = (float) (fArr[index] + (d2 * sparseFeatureVector.getWeight(i)));
        }
    }

    /**
     * Computes, for each training instance, {@code d + d2^2 + sumOfSquares(features)}.
     *
     * @param d a constant added to every entry
     * @param d2 a value whose square is added to every entry
     *     (NOTE(review): presumably the bias feature value; confirm with callers)
     * @return one value per training instance
     */
    public double[] getSumOfSquares(double d, double d2) {
        int instanceSize = getInstanceSize();
        double[] sums = new double[instanceSize];
        double constant = d + MathUtils.sq(d2);
        for (int i = 0; i < instanceSize; i++) {
            sums[i] = constant + getInstance(i).getFeatureVector().sumOfSquares();
        }
        return sums;
    }
}
