package edu.emory.clir.clearnlp.classification.model;

import edu.emory.clir.clearnlp.classification.instance.AbstractInstance;
import edu.emory.clir.clearnlp.classification.instance.AbstractInstanceCollector;
import edu.emory.clir.clearnlp.classification.instance.IntInstance;
import edu.emory.clir.clearnlp.classification.map.LabelMap;
import edu.emory.clir.clearnlp.classification.prediction.StringPrediction;
import edu.emory.clir.clearnlp.classification.vector.AbstractFeatureVector;
import edu.emory.clir.clearnlp.classification.vector.AbstractWeightVector;
import edu.emory.clir.clearnlp.classification.vector.BinaryWeightVector;
import edu.emory.clir.clearnlp.classification.vector.MultiWeightVector;
import edu.emory.clir.clearnlp.collection.pair.DoubleIntPair;
import edu.emory.clir.clearnlp.collection.pair.Pair;
import edu.emory.clir.clearnlp.util.BinUtils;
import edu.emory.clir.clearnlp.util.DSUtils;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Deque;
import java.util.Iterator;
import java.util.List;
import org.tukaani.xz.LZMA2Options;
import org.tukaani.xz.XZInputStream;
import org.tukaani.xz.XZOutputStream;

/* loaded from: input_file:edu/emory/clir/clearnlp/classification/model/AbstractModel.class */
public abstract class AbstractModel<I extends AbstractInstance<F>, F extends AbstractFeatureVector> implements Serializable {
    private static final long serialVersionUID = 6096015874433178106L;

    /** Number of processed instances between progress dots in {@link #toIntInstanceList(Deque)}. */
    private static final int PROGRESS_INTERVAL = 100000;

    protected AbstractInstanceCollector<I, F> i_collector;
    protected AbstractWeightVector w_vector;
    protected LabelMap m_labels;

    /**
     * Creates an empty model.
     *
     * @param binary if {@code true} the model uses a {@link BinaryWeightVector},
     *               otherwise a {@link MultiWeightVector}
     */
    public AbstractModel(boolean binary) {
        this.w_vector = binary ? new BinaryWeightVector() : new MultiWeightVector();
        this.m_labels = new LabelMap();
    }

    /**
     * Creates a model by delegating to {@link #load(ObjectInputStream)}.
     *
     * <p>NOTE(review): {@code load} is abstract and is invoked from this constructor, so a
     * subclass implementation runs before the subclass's own field initializers. Load failures
     * are only printed, not rethrown, so the model may be left partially initialized. This
     * behavior is kept because callers may depend on the constructor not throwing.
     *
     * @param in stream positioned at a serialized model
     */
    public AbstractModel(ObjectInputStream in) {
        try {
            load(in);
        } catch (IOException | ClassNotFoundException e) {
            e.printStackTrace();
        }
    }

    /** Java serialization hook: restores state through {@link #load(ObjectInputStream)}. */
    private void readObject(ObjectInputStream in) throws IOException, ClassNotFoundException {
        load(in);
    }

    /** Java serialization hook: writes state through {@link #save(ObjectOutputStream)}. */
    private void writeObject(ObjectOutputStream out) throws IOException {
        save(out);
    }

    /**
     * Reads this model's state from {@code in}.
     *
     * @param in source stream
     * @throws IOException on read failure
     * @throws ClassNotFoundException if a serialized class cannot be resolved
     */
    public abstract void load(ObjectInputStream in) throws IOException, ClassNotFoundException;

    /**
     * Writes this model's state to {@code out}.
     *
     * @param out destination stream
     * @throws IOException on write failure
     */
    public abstract void save(ObjectOutputStream out) throws IOException;

    /**
     * Adds a single training instance.
     *
     * @param instance the instance to add
     */
    public abstract void addInstance(I instance);

    /**
     * Adds every instance in {@code instances}, in iteration order, via {@link #addInstance}.
     *
     * @param instances instances to add
     */
    public void addInstances(Collection<I> instances) {
        for (I instance : instances) {
            addInstance(instance);
        }
    }

    /**
     * @param label label string
     * @return the index that {@link LabelMap#getLabelIndex(String)} assigns to {@code label}
     */
    public int getLabelIndex(String label) {
        return this.m_labels.getLabelIndex(label);
    }

    /** @return the number of labels reported by the weight vector */
    public int getLabelSize() {
        return this.w_vector.getLabelSize();
    }

    /** @return the number of features reported by the weight vector */
    public int getFeatureSize() {
        return this.w_vector.getFeatureSize();
    }

    /** @return the labels held by the label map */
    public String[] getLabels() {
        return this.m_labels.getLabels();
    }

    /** @return the weight vector (not a copy) */
    public AbstractWeightVector getWeightVector() {
        return this.w_vector;
    }

    /**
     * Replaces the weight vector.
     *
     * @param weightVector the new weight vector
     */
    public void setWeightVector(AbstractWeightVector weightVector) {
        this.w_vector = weightVector;
    }

    /** @return whether the weight vector reports a binary label set */
    public boolean isBinaryLabel() {
        return this.w_vector.isBinaryLabel();
    }

    /**
     * Replaces the weight vector with one deserialized from an XZ-compressed byte array
     * produced by {@link #saveWeightVectorToByteArray()}.
     *
     * <p>NOTE(review): this uses native Java deserialization. Only pass bytes from a trusted
     * source.
     *
     * @param bytes XZ-compressed serialized {@link AbstractWeightVector}
     * @throws Exception on decompression, I/O, class-resolution, or cast failure
     */
    public void loadWeightVectorFromByteArray(byte[] bytes) throws Exception {
        // try-with-resources closes the stream chain even if readObject() or the cast fails
        try (ObjectInputStream in = new ObjectInputStream(
                new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes))))) {
            setWeightVector((AbstractWeightVector) in.readObject());
        }
    }

    /**
     * Serializes the current weight vector and compresses it with XZ (default LZMA2 options).
     *
     * @return the compressed bytes
     * @throws Exception on serialization or compression failure
     */
    public byte[] saveWeightVectorToByteArray() throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(
                new XZOutputStream(new BufferedOutputStream(bytes), new LZMA2Options()))) {
            out.writeObject(this.w_vector);
        }
        // Read the buffer only after the whole chain is closed; buffered and compressed
        // output is not complete until then.
        return bytes.toByteArray();
    }

    /**
     * Converts an instance to its integer-feature form.
     *
     * @param instance the instance to convert
     * @return the converted instance, or {@code null} to drop it from {@link #toIntInstanceList}
     */
    public abstract IntInstance toIntInstance(I instance);

    /**
     * Drains {@code deque}, converting each polled instance with {@link #toIntInstance}.
     * Instances that convert to {@code null} are skipped. A "." is logged every
     * {@value #PROGRESS_INTERVAL} processed instances.
     *
     * @param deque instances to convert; it is empty when this method returns
     * @return the converted, non-null instances in poll order
     */
    public List<IntInstance> toIntInstanceList(Deque<I> deque) {
        BinUtils.LOG.info("Vectorizing: " + deque.size() + "\n");
        ArrayList<IntInstance> converted = new ArrayList<>(deque.size());
        int processed = 0;
        while (!deque.isEmpty()) {
            IntInstance intInstance = toIntInstance(deque.poll());
            if (intInstance != null) {
                converted.add(intInstance);
            }
            processed++;
            if (processed % PROGRESS_INTERVAL == 0) {
                BinUtils.LOG.info(".");
            }
        }
        // Dots depend on the processed count, not the retained count. When at least one dot
        // was printed, log an extra newline to terminate the dot line.
        if (processed >= PROGRESS_INTERVAL) {
            BinUtils.LOG.info("\n\n");
        } else {
            BinUtils.LOG.info("\n");
        }
        converted.trimToSize();
        return converted;
    }

    /**
     * @param features feature vector
     * @return one score per label, indexed by label index
     */
    public abstract double[] getScores(F features);

    /**
     * @param features feature vector
     * @param labels   candidate label indices
     * @return a score array indexed by label index (callers read {@code scores[labels[k]]})
     */
    public abstract double[] getScores(F features, int[] labels);

    /**
     * @param label label index
     * @param score score for that label
     * @return a prediction pairing the label string from the label map with {@code score}
     */
    public StringPrediction getPrediction(int label, double score) {
        return new StringPrediction(this.m_labels.getLabel(label), score);
    }

    /**
     * @param features feature vector
     * @return the highest-scoring prediction
     */
    public StringPrediction predictBest(F features) {
        return isBinaryLabel() ? predictBestBinary(features) : predictBestMulti(features);
    }

    /** Picks label 0 if its score is strictly positive, otherwise label 1 with its own score. */
    private StringPrediction predictBestBinary(F features) {
        double[] scores = getScores(features);
        return scores[0] > 0.0d ? getPrediction(0, scores[0]) : getPrediction(1, scores[1]);
    }

    /** Argmax over all scores; the earliest index wins ties. */
    private StringPrediction predictBestMulti(F features) {
        double[] scores = getScores(features);
        int bestLabel = 0;
        double bestScore = scores[0];
        for (int label = 1; label < scores.length; label++) {
            if (bestScore < scores[label]) {
                bestLabel = label;
                bestScore = scores[label];
            }
        }
        return getPrediction(bestLabel, bestScore);
    }

    /**
     * @param features feature vector
     * @return the two best predictions, best first
     */
    public StringPrediction[] predictTop2(F features) {
        return isBinaryLabel() ? predictTop2Binary(features) : predictTop2Multi(features);
    }

    /** Returns both binary predictions. Label 0 comes first iff its score is strictly positive. */
    private StringPrediction[] predictTop2Binary(F features) {
        double[] scores = getScores(features);
        StringPrediction first = getPrediction(0, scores[0]);
        StringPrediction second = getPrediction(1, scores[1]);
        return scores[0] > 0.0d ? new StringPrediction[]{first, second} : new StringPrediction[]{second, first};
    }

    /** Top two predictions as selected by {@link DSUtils#top2(double[])}. */
    private StringPrediction[] predictTop2Multi(F features) {
        Pair<DoubleIntPair, DoubleIntPair> top = DSUtils.top2(getScores(features));
        DoubleIntPair best = top.o1;
        DoubleIntPair runnerUp = top.o2;
        return new StringPrediction[]{getPrediction(best.i, best.d), getPrediction(runnerUp.i, runnerUp.d)};
    }

    /**
     * @param features feature vector
     * @return predictions for every label. The binary case returns the same two-element array
     *     as {@link #predictTop2(AbstractFeatureVector)}.
     */
    public StringPrediction[] predictAll(F features) {
        return isBinaryLabel() ? predictTop2Binary(features) : predictAllMulti(features);
    }

    /** One prediction per label index in {@code [0, getLabelSize())}, ordered by DSUtils.sortReverseOrder. */
    private StringPrediction[] predictAllMulti(F features) {
        double[] scores = getScores(features);
        int labelSize = getLabelSize();
        StringPrediction[] predictions = new StringPrediction[labelSize];
        for (int label = 0; label < labelSize; label++) {
            predictions[label] = getPrediction(label, scores[label]);
        }
        DSUtils.sortReverseOrder(predictions);
        return predictions;
    }

    /**
     * Like {@link #predictBest(AbstractFeatureVector)}, but restricted to {@code labels}.
     * In the binary case {@code labels} is ignored.
     *
     * @param features feature vector
     * @param labels   candidate label indices (must be non-empty in the multi-label case)
     * @return the highest-scoring candidate prediction
     */
    public StringPrediction predictBest(F features, int[] labels) {
        return isBinaryLabel() ? predictBestBinary(features) : predictBestMulti(features, labels);
    }

    /** Argmax over the candidate labels; the earliest candidate wins ties. */
    private StringPrediction predictBestMulti(F features, int[] labels) {
        double[] scores = getScores(features, labels);
        int bestLabel = labels[0];
        double bestScore = scores[bestLabel];
        for (int k = 1; k < labels.length; k++) {
            int label = labels[k];
            if (bestScore < scores[label]) {
                bestLabel = label;
                bestScore = scores[label];
            }
        }
        return getPrediction(bestLabel, bestScore);
    }

    /**
     * Like {@link #predictTop2(AbstractFeatureVector)}, but restricted to {@code labels}.
     * In the binary case {@code labels} is ignored.
     *
     * @param features feature vector
     * @param labels   candidate label indices
     * @return the two best candidate predictions, best first
     */
    public StringPrediction[] predictTop2(F features, int[] labels) {
        return isBinaryLabel() ? predictTop2Binary(features) : predictTop2Multi(features, labels);
    }

    /** Top two candidates as selected by {@link DSUtils#top2(double[], int[])}. */
    private StringPrediction[] predictTop2Multi(F features, int[] labels) {
        Pair<DoubleIntPair, DoubleIntPair> top = DSUtils.top2(getScores(features, labels), labels);
        DoubleIntPair best = top.o1;
        DoubleIntPair runnerUp = top.o2;
        return new StringPrediction[]{getPrediction(best.i, best.d), getPrediction(runnerUp.i, runnerUp.d)};
    }

    /**
     * Like {@link #predictAll(AbstractFeatureVector)}, but restricted to {@code labels}.
     * In the binary case {@code labels} is ignored.
     *
     * @param features feature vector
     * @param labels   candidate label indices
     * @return one prediction per candidate label
     */
    public StringPrediction[] predictAll(F features, int[] labels) {
        return isBinaryLabel() ? predictTop2Binary(features) : predictAllMulti(features, labels);
    }

    /** One prediction per candidate label, ordered by DSUtils.sortReverseOrder. */
    private StringPrediction[] predictAllMulti(F features, int[] labels) {
        double[] scores = getScores(features, labels);
        StringPrediction[] predictions = new StringPrediction[labels.length];
        for (int k = 0; k < labels.length; k++) {
            int label = labels[k];
            predictions[k] = getPrediction(label, scores[label]);
        }
        DSUtils.sortReverseOrder(predictions);
        return predictions;
    }
}
