package edu.emory.clir.clearnlp.component;

import edu.emory.clir.clearnlp.classification.instance.StringInstance;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.trainer.AdaGradSVM;
import edu.emory.clir.clearnlp.classification.vector.StringFeatureVector;
import edu.emory.clir.clearnlp.component.configuration.AbstractConfiguration;
import edu.emory.clir.clearnlp.component.evaluation.AbstractEval;
import edu.emory.clir.clearnlp.component.mode.dep.state.AbstractDEPState;
import edu.emory.clir.clearnlp.component.state.AbstractState;
import edu.emory.clir.clearnlp.component.utils.CFlag;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.feature.AbstractFeatureExtractor;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.tukaani.xz.LZMA2Options;
import org.tukaani.xz.XZInputStream;
import org.tukaani.xz.XZOutputStream;

/* loaded from: input_file:edu/emory/clir/clearnlp/component/AbstractStatisticalComponent.class */
public abstract class AbstractStatisticalComponent<LabelType, StateType extends AbstractState<?, LabelType>, EvalType extends AbstractEval<?>, FeatureType extends AbstractFeatureExtractor<?, ?, ?>> extends AbstractComponent {
    /** Configuration this component was created with. */
    protected AbstractConfiguration t_configuration;
    /** Feature extractors used to build feature vectors from a state. */
    protected FeatureType[] f_extractors;
    /** Statistical models; {@link #onlineTrainSingleAdaGrad(List)} only trains {@code s_models[0]}. */
    protected StringModel[] s_models;
    /** Evaluator, (re)created by {@link #initEval()}. */
    protected EvalType c_eval;
    /** Current operating mode; selects train/bootstrap/decode behavior in {@code process}. */
    protected CFlag c_flag;

    /** Creates an empty component; configuration, flag, extractors and models are left unset. */
    public AbstractStatisticalComponent() {
    }

    /**
     * Creates a component in {@link CFlag#COLLECT} mode.
     *
     * @param configuration the component configuration
     */
    public AbstractStatisticalComponent(AbstractConfiguration configuration) {
        setConfiguration(configuration);
        setFlag(CFlag.COLLECT);
    }

    /**
     * Creates a component in {@link CFlag#TRAIN} mode with {@code modelCount} freshly created models.
     *
     * @param configuration the component configuration
     * @param extractors    feature extractors used to build feature vectors
     * @param lexicons      lexicon object handed to {@link #setLexicons(Object)}
     * @param modelFlag     flag forwarded to each {@code new StringModel(boolean)}; NOTE(review): its
     *                      meaning is defined by StringModel (presumably binary vs. multi-class) — confirm
     * @param modelCount    number of models to create
     */
    public AbstractStatisticalComponent(AbstractConfiguration configuration, FeatureType[] extractors, Object lexicons, boolean modelFlag, int modelCount) {
        setConfiguration(configuration);
        setFlag(CFlag.TRAIN);
        setFeatureExtractors(extractors);
        setLexicons(lexicons);
        setModels(createModels(modelFlag, modelCount));
    }

    /**
     * Creates a component around already-trained models, in either bootstrap or evaluate mode.
     *
     * @param configuration the component configuration
     * @param extractors    feature extractors used to build feature vectors
     * @param lexicons      lexicon object handed to {@link #setLexicons(Object)}
     * @param models        existing models to use
     * @param bootstrapMode {@code true} for {@link CFlag#BOOTSTRAP}; {@code false} for
     *                      {@link CFlag#EVALUATE}, in which case the evaluator is initialized
     */
    public AbstractStatisticalComponent(AbstractConfiguration configuration, FeatureType[] extractors, Object lexicons, StringModel[] models, boolean bootstrapMode) {
        setConfiguration(configuration);
        if (bootstrapMode) {
            setFlag(CFlag.BOOTSTRAP);
        } else {
            setFlag(CFlag.EVALUATE);
            initEval();
        }
        setFeatureExtractors(extractors);
        setLexicons(lexicons);
        setModels(models);
    }

    /**
     * Creates a component in {@link CFlag#DECODE} mode by reading a serialized component from a stream.
     * Load failures are printed, not thrown (see {@link #initDecode(ObjectInputStream)}).
     *
     * @param configuration the component configuration
     * @param in            stream positioned at data written by {@link #save(ObjectOutputStream)}
     */
    public AbstractStatisticalComponent(AbstractConfiguration configuration, ObjectInputStream in) {
        setConfiguration(configuration);
        initDecode(in);
    }

    /**
     * Creates a component in {@link CFlag#DECODE} mode from bytes produced by {@link #toByteArray()}.
     * Load failures are printed, not thrown.
     *
     * @param configuration the component configuration
     * @param bytes         XZ-compressed serialized component
     */
    public AbstractStatisticalComponent(AbstractConfiguration configuration, byte[] bytes) {
        setConfiguration(configuration);
        initDecode(bytes);
    }

    /** Creates {@code count} new models, each constructed with {@code new StringModel(modelFlag)}. */
    private StringModel[] createModels(boolean modelFlag, int count) {
        StringModel[] models = new StringModel[count];
        for (int i = 0; i < count; i++) {
            models[i] = new StringModel(modelFlag);
        }
        return models;
    }

    /**
     * Switches to {@link CFlag#DECODE} mode and loads extractors, lexicons and models from {@code in}.
     * Any exception from {@link #load(ObjectInputStream)} is printed and otherwise ignored. This is
     * best-effort behavior kept for compatibility, so the flag is DECODE even when loading fails.
     *
     * @param in stream positioned at data written by {@link #save(ObjectOutputStream)}
     */
    protected void initDecode(ObjectInputStream in) {
        setFlag(CFlag.DECODE);
        try {
            load(in);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Decompresses {@code bytes} (XZ), then delegates to {@link #initDecode(ObjectInputStream)}.
     * The object stream is closed afterwards even if decoding fails. I/O errors while opening or
     * closing the stream are printed, not thrown.
     *
     * @param bytes XZ-compressed serialized component, e.g. from {@link #toByteArray()}
     */
    protected void initDecode(byte[] bytes) {
        try (ObjectInputStream in = new ObjectInputStream(new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes))))) {
            initDecode(in);
        } catch (IOException e) {
            e.printStackTrace();
        }
    }

    /** @param configuration the configuration to store in {@link #t_configuration} */
    public void setConfiguration(AbstractConfiguration configuration) {
        this.t_configuration = configuration;
    }

    /**
     * Reads, in order, the feature extractors, the lexicons and the models. This is the order
     * written by {@link #save(ObjectOutputStream)}.
     *
     * @param in source stream
     * @throws Exception if reading or deserialization fails
     */
    public void load(ObjectInputStream in) throws Exception {
        // save() writes f_extractors (a FeatureType[]) here, so the unchecked cast matches what was written.
        @SuppressWarnings("unchecked")
        FeatureType[] extractors = (FeatureType[]) in.readObject();
        setFeatureExtractors(extractors);
        setLexicons(in.readObject());
        setModels(loadModels(in));
    }

    /**
     * Writes, in order, the feature extractors, the lexicons ({@link #getLexicons()}) and the models
     * (count followed by each model).
     *
     * @param out destination stream
     * @throws Exception if writing fails
     */
    public void save(ObjectOutputStream out) throws Exception {
        out.writeObject(this.f_extractors);
        out.writeObject(getLexicons());
        saveModels(out);
    }

    /** Reads a model count followed by that many models, matching {@link #saveModels}. */
    private StringModel[] loadModels(ObjectInputStream in) throws Exception {
        int count = in.readInt();
        StringModel[] models = new StringModel[count];
        for (int i = 0; i < count; i++) {
            models[i] = new StringModel(in);
        }
        return models;
    }

    /** Writes the model count followed by each model. */
    private void saveModels(ObjectOutputStream out) throws Exception {
        out.writeInt(this.s_models.length);
        for (StringModel model : this.s_models) {
            model.save(out);
        }
    }

    /**
     * Serializes the whole component ({@link #save(ObjectOutputStream)}) into an XZ-compressed byte array.
     *
     * @return compressed bytes readable by {@link #initDecode(byte[])}
     * @throws Exception if serialization fails
     */
    public byte[] toByteArray() throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(new XZOutputStream(new BufferedOutputStream(bytes), new LZMA2Options()))) {
            save(out);
        }
        // Read the buffer only after close(): the XZ stream flushes its remaining data on close.
        return bytes.toByteArray();
    }

    /**
     * Serializes only the models (no count, extractors or lexicons) into an XZ-compressed byte array.
     *
     * @return compressed bytes readable by {@link #byteArrayToModels(byte[])}
     * @throws Exception if serialization fails
     */
    public byte[] modelsToByteArray() throws Exception {
        ByteArrayOutputStream bytes = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(new XZOutputStream(new BufferedOutputStream(bytes), new LZMA2Options()))) {
            for (StringModel model : this.s_models) {
                model.save(out);
            }
        }
        // Read the buffer only after close(): the XZ stream flushes its remaining data on close.
        return bytes.toByteArray();
    }

    /**
     * Loads model contents from bytes produced by {@link #modelsToByteArray()} into the existing
     * {@link #s_models}. No count is stored, so the array must already have the length it had when
     * the bytes were written.
     *
     * @param bytes XZ-compressed model data
     * @throws Exception if reading fails
     */
    public void byteArrayToModels(byte[] bytes) throws Exception {
        try (ObjectInputStream in = new ObjectInputStream(new XZInputStream(new BufferedInputStream(new ByteArrayInputStream(bytes))))) {
            for (StringModel model : this.s_models) {
                model.load(in);
            }
        }
    }

    /** @return the lexicon object serialized by {@link #save(ObjectOutputStream)} */
    public abstract Object getLexicons();

    /** @param lexicons lexicon object, as produced by {@link #getLexicons()} */
    public abstract void setLexicons(Object lexicons);

    /** @return the feature extractors (the internal array, not a copy) */
    public FeatureType[] getFeatureExtractors() {
        return this.f_extractors;
    }

    /** @param extractors feature extractors to use (stored without copying) */
    public void setFeatureExtractors(FeatureType[] extractors) {
        this.f_extractors = extractors;
    }

    /** @return the model at {@code index} */
    public StringModel getModel(int index) {
        return this.s_models[index];
    }

    /** @return the models (the internal array, not a copy) */
    public StringModel[] getModels() {
        return this.s_models;
    }

    /** @param models models to use (stored without copying) */
    public void setModels(StringModel[] models) {
        this.s_models = models;
    }

    /**
     * Runs {@code state} to termination, choosing each label according to the current flag.
     * TRAIN uses {@link #train}, BOOTSTRAP uses {@link #bootstrap}, and any other flag uses
     * {@link #decode}. Each chosen label is fed to {@code state.next}.
     *
     * @param state the state to drive
     * @return the training instances collected in TRAIN/BOOTSTRAP mode; {@code null} in other modes
     */
    public List<StringInstance> process(StateType state) {
        List<StringInstance> instances = isTrainOrBootstrap() ? new ArrayList<StringInstance>() : null;
        while (!state.isTerminate()) {
            LabelType label;
            switch (this.c_flag) {
                case TRAIN:
                    label = train(state, instances);
                    break;
                case BOOTSTRAP:
                    label = bootstrap(state, instances);
                    break;
                default:
                    label = decode(state);
                    break;
            }
            state.next(label);
        }
        return instances;
    }

    /**
     * Adds a training instance for the gold label (only if the feature vector is non-empty) and
     * returns the gold label, so training follows the gold path.
     *
     * @param state     current state
     * @param instances sink for collected instances
     * @return the gold label of {@code state}
     */
    protected LabelType train(StateType state, List<StringInstance> instances) {
        StringFeatureVector vector = createStringFeatureVector(state);
        LabelType gold = (LabelType) state.getGoldLabel();
        if (!vector.isEmpty()) {
            instances.add(new StringInstance(gold.toString(), vector));
        }
        return gold;
    }

    /**
     * Adds a training instance for the gold label (only if the feature vector is non-empty) but
     * returns the model-predicted label, so the state follows the automatic path.
     *
     * @param state     current state
     * @param instances sink for collected instances
     * @return the label chosen by {@link #getAutoLabel}
     */
    protected LabelType bootstrap(StateType state, List<StringInstance> instances) {
        StringFeatureVector vector = createStringFeatureVector(state);
        Object gold = state.getGoldLabel();
        if (!vector.isEmpty()) {
            instances.add(new StringInstance(gold.toString(), vector));
        }
        return getAutoLabel(state, vector);
    }

    /** @return the label chosen by {@link #getAutoLabel} for the features of {@code state} */
    protected LabelType decode(StateType state) {
        return getAutoLabel(state, createStringFeatureVector(state));
    }

    /** Builds the feature vector for the current configuration of {@code state}. */
    protected abstract StringFeatureVector createStringFeatureVector(StateType state);

    /** Predicts a label for {@code state} from {@code vector}. */
    protected abstract LabelType getAutoLabel(StateType state, StringFeatureVector vector);

    /** @return the current evaluator, possibly {@code null} if {@link #initEval()} was never called */
    public EvalType getEval() {
        return this.c_eval;
    }

    /** (Re)creates the evaluator stored in {@link #c_eval}. */
    protected abstract void initEval();

    /** @return the current operating mode */
    public CFlag getFlag() {
        return this.c_flag;
    }

    /** @param flag the new operating mode */
    public void setFlag(CFlag flag) {
        this.c_flag = flag;
    }

    /** @return whether the flag is {@link CFlag#COLLECT} */
    public boolean isCollect() {
        return this.c_flag == CFlag.COLLECT;
    }

    /** @return whether the flag is {@link CFlag#TRAIN} */
    public boolean isTrain() {
        return this.c_flag == CFlag.TRAIN;
    }

    /** @return whether the flag is {@link CFlag#BOOTSTRAP} */
    public boolean isBootstrap() {
        return this.c_flag == CFlag.BOOTSTRAP;
    }

    /** @return whether the flag is {@link CFlag#EVALUATE} */
    public boolean isEvaluate() {
        return this.c_flag == CFlag.EVALUATE;
    }

    /** @return whether the flag is {@link CFlag#DECODE} */
    public boolean isDecode() {
        return this.c_flag == CFlag.DECODE;
    }

    /** @return whether the flag is TRAIN or BOOTSTRAP (the modes that collect instances) */
    public boolean isTrainOrBootstrap() {
        return isTrain() || isBootstrap();
    }

    /** @return whether the flag is DECODE or EVALUATE */
    public boolean isDecodeOrEvaluate() {
        return isDecode() || isEvaluate();
    }

    /** Updates the component's models from the given trees. */
    public abstract void onlineTrain(List<DEPTree> trees);

    /**
     * Online training of {@code s_models[0]} with AdaGrad, repeated while the score on {@code trees} improves.
     *
     * <p>The procedure has four steps:
     * <ol>
     *   <li>Return immediately if the current score already equals {@code 100.0}.</li>
     *   <li>Run a bootstrap pass over the trees.</li>
     *   <li>Repeatedly snapshot the component, train, and re-score.</li>
     *   <li>Once a pass fails to improve the score, restore the snapshot taken before that pass
     *       via {@link #initDecode(byte[])}. This also sets the flag to DECODE.</li>
     * </ol>
     * Any exception during the loop is printed and the method returns. The component is left in
     * whatever state it reached.
     *
     * @param trees trees to score and train on
     */
    public void onlineTrainSingleAdaGrad(List<DEPTree> trees) {
        double score = onlineScore(trees);
        // NOTE(review): assumes scores are percentages, so 100.0 means nothing left to learn.
        if (score == 100.0d) {
            return;
        }
        onlineBootstrap(trees);
        // Hyper-parameters are passed positionally; their meaning is defined by AdaGradSVM.
        AdaGradSVM adaGradSVM = new AdaGradSVM(this.s_models[0], 0, 0, false, 0.01d, 0.1d, 0.0d);
        byte[] snapshot;
        double previousScore;
        do {
            try {
                // Snapshot before each pass so a non-improving pass can be undone.
                snapshot = toByteArray();
                previousScore = score;
                adaGradSVM.train();
                score = onlineScore(trees);
            } catch (Exception e) {
                e.printStackTrace();
                return;
            }
        } while (previousScore < score);
        initDecode(snapshot);
    }

    /**
     * Scores the component on {@code trees} in EVALUATE mode, re-initializing the evaluator first.
     * The previous flag is restored afterwards, even if processing throws.
     *
     * @param trees trees to evaluate
     * @return the evaluator's score
     */
    protected double onlineScore(List<DEPTree> trees) {
        CFlag savedFlag = this.c_flag;
        this.c_flag = CFlag.EVALUATE;
        try {
            initEval();
            for (DEPTree tree : trees) {
                process(tree);
            }
        } finally {
            this.c_flag = savedFlag;
        }
        return this.c_eval.getScore();
    }

    /**
     * Processes every tree in BOOTSTRAP mode, calling {@link #onlineLexicons(DEPTree)} before each.
     * The previous flag is restored afterwards, even if processing throws.
     *
     * @param trees trees to bootstrap on
     */
    protected void onlineBootstrap(List<DEPTree> trees) {
        CFlag savedFlag = this.c_flag;
        this.c_flag = CFlag.BOOTSTRAP;
        try {
            for (DEPTree tree : trees) {
                onlineLexicons(tree);
                process(tree);
            }
        } finally {
            this.c_flag = savedFlag;
        }
    }

    /** Hook for updating lexicons from {@code tree} during {@link #onlineBootstrap(List)}. */
    protected abstract void onlineLexicons(DEPTree tree);
}
