package edu.emory.clir.clearnlp.component.mode.dep;

import edu.emory.clir.clearnlp.classification.instance.StringInstance;
import edu.emory.clir.clearnlp.classification.model.StringModel;
import edu.emory.clir.clearnlp.classification.prediction.StringPrediction;
import edu.emory.clir.clearnlp.classification.vector.StringFeatureVector;
import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair;
import edu.emory.clir.clearnlp.component.AbstractStatisticalComponent;
import edu.emory.clir.clearnlp.component.configuration.AbstractConfiguration;
import edu.emory.clir.clearnlp.component.mode.dep.state.AbstractDEPState;
import edu.emory.clir.clearnlp.component.mode.dep.state.DEPStateBranch;
import edu.emory.clir.clearnlp.dependency.DEPNode;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.feature.AbstractFeatureExtractor;
import java.io.ObjectInputStream;
import java.util.List;

/* loaded from: input_file:edu/emory/clir/clearnlp/component/mode/dep/AbstractDEPParser.class */
/**
 * Common base for dependency parsers built on {@link DEPTransition}.
 *
 * <p>Wires together the parser state ({@link AbstractDEPState}), the single statistical model
 * ({@code s_models[0]}) and the single feature extractor ({@code f_extractors[0]}). After decoding,
 * it runs a pass that attaches every node still lacking a head.
 *
 * <p>NOTE(review): this class is decompiled output. Two decompiler-inserted casts that could not
 * compile were removed: the array up-cast in the 3-argument constructor and the
 * {@code (AbstractDEPParser)} casts in {@link #process(DEPTree)}.
 */
public abstract class AbstractDEPParser extends AbstractStatisticalComponent<DEPLabel, AbstractDEPState, DEPEval, DEPFeatureExtractor> implements DEPTransition {
    /**
     * Row of {@link #label_indices} used to restrict predictions while scanning candidate heads
     * to the LEFT of a headless node (see {@link #processHeadless}).
     */
    private static final int LEFT_SCAN_LABEL_GROUP = 5;

    /**
     * Row of {@link #label_indices} used to restrict predictions while scanning candidate heads
     * to the RIGHT of a headless node (see {@link #processHeadless}).
     */
    private static final int RIGHT_SCAN_LABEL_GROUP = 4;

    /** Parser configuration; assigned in {@link #init} after the superclass constructor returns. */
    private DEPConfiguration d_configuration;

    /** Label-index table built from the labels of {@code s_models[0]}. */
    private int[][] label_indices;

    /**
     * Creates a parser whose model is set up by the superclass from the literal arguments
     * {@code false, 1}; their meaning is defined by the superclass constructor.
     *
     * @param configuration parser configuration
     * @param extractors    feature extractors; only index 0 is used by this class
     * @param lexicons      opaque object forwarded to the superclass (presumably lexicons, cf.
     *                      {@link #getLexicons()}) — TODO confirm
     */
    public AbstractDEPParser(DEPConfiguration configuration, DEPFeatureExtractor[] extractors, Object lexicons) {
        // Arguments are passed without the decompiler's up-casts. This matches the sibling
        // constructor below. An AbstractFeatureExtractor[] cannot be assigned to the
        // DEPFeatureExtractor[] parameter this generic superclass expects.
        super(configuration, extractors, lexicons, false, 1);
        init(configuration);
    }

    /**
     * Creates a parser around pre-built models.
     *
     * @param configuration parser configuration
     * @param extractors    feature extractors; only index 0 is used by this class
     * @param lexicons      opaque object forwarded to the superclass — TODO confirm its role
     * @param models        models forwarded to the superclass; {@code models[0]} is the one used here
     *                      (presumably stored as {@code s_models}) — TODO confirm
     * @param option        boolean forwarded unchanged to the superclass constructor
     */
    public AbstractDEPParser(DEPConfiguration configuration, DEPFeatureExtractor[] extractors, Object lexicons, StringModel[] models, boolean option) {
        super(configuration, extractors, lexicons, models, option);
        init(configuration);
    }

    /**
     * Creates a parser whose state is read by the superclass from {@code in}.
     *
     * @param configuration parser configuration
     * @param in            stream handed to the superclass constructor
     */
    public AbstractDEPParser(DEPConfiguration configuration, ObjectInputStream in) {
        super(configuration, in);
        init(configuration);
    }

    /**
     * Creates a parser whose state is rebuilt by the superclass from {@code bytes}.
     *
     * @param configuration parser configuration
     * @param bytes         serialized data handed to the superclass constructor
     */
    public AbstractDEPParser(DEPConfiguration configuration, byte[] bytes) {
        super(configuration, bytes);
        init(configuration);
    }

    /**
     * Finishes construction: builds {@link #label_indices} and stores the configuration.
     * Relies on the superclass constructor having populated {@code s_models}.
     */
    private void init(DEPConfiguration configuration) {
        this.label_indices = AbstractDEPState.initLabelIndices(this.s_models[0].getLabels());
        this.d_configuration = configuration;
    }

    /** This parser keeps no lexicons; always returns {@code null}. */
    @Override
    public Object getLexicons() {
        return null;
    }

    /** No-op: this parser keeps no lexicons, so the argument is ignored. */
    @Override
    public void setLexicons(Object obj) {
    }

    /**
     * Creates the evaluator, honouring the configuration's punctuation setting.
     *
     * <p>Reads the inherited {@code t_configuration} rather than {@link #d_configuration}. If the
     * superclass constructor calls this method, {@code d_configuration} is still unassigned at that
     * point, because {@link #init} only runs after {@code super(...)} returns.
     */
    @Override
    protected void initEval() {
        this.c_eval = new DEPEval(((DEPConfiguration) this.t_configuration).evaluatePunctuation());
    }

    /**
     * Parses (or collects training instances from) one tree.
     *
     * <p>Runs a first decoding pass. If the state starts branching, every branch is decoded and the
     * best one is kept; its instances (if any) are appended. In training/bootstrap mode the instances
     * are added to {@code s_models[0]}. Otherwise headless nodes are attached and, when evaluating,
     * the result is scored against the state's oracle.
     *
     * @param tree the tree to process; its nodes receive heads/labels when not training
     */
    @Override
    public void process(DEPTree tree) {
        DEPStateBranch state = new DEPStateBranch(tree, this.c_flag, this.d_configuration);

        // The state is passed directly. The decompiled "(AbstractDEPParser) state" cast was between
        // unrelated classes and did not compile. Overload resolution picks process(STATE) here,
        // since DEPStateBranch is an AbstractDEPState and not a DEPTree.
        List<StringInstance> instances = process(state);

        if (state.startBranching()) {
            while (state.nextBranch()) {
                state.saveBest(process(state));
            }
            List<StringInstance> bestInstances = state.setBest();
            if (bestInstances != null) {
                instances.addAll(bestInstances);
            }
        }

        if (isTrainOrBootstrap()) {
            this.s_models[0].addInstances(instances);
            return;
        }

        processHeadless(state);
        if (isEvaluate()) {
            ((DEPEval) this.c_eval).countCorrect(tree, state.getOracle());
        }
    }

    /**
     * Extracts the feature vector for the current state using the first feature extractor.
     *
     * <p>NOTE(review): the decompiler reported this was originally {@code protected}. It is kept
     * {@code public} so existing callers still compile.
     */
    @Override
    public StringFeatureVector createStringFeatureVector(AbstractDEPState state) {
        return ((DEPFeatureExtractor[]) this.f_extractors)[0].createStringFeatureVector(state);
    }

    /**
     * Predicts the next label for {@code state}.
     *
     * <p>Returns a label built from the top prediction. If that label's arc is
     * {@link DEPTransition#ARC_NO}, the predictions are also handed to {@code state.save2ndHead}.
     * The predictions are always handed to {@code state.saveBranch}.
     *
     * <p>NOTE(review): the decompiler reported this was originally {@code protected}. It is kept
     * {@code public} so existing callers still compile.
     */
    @Override
    public DEPLabel getAutoLabel(AbstractDEPState state, StringFeatureVector vector) {
        StringPrediction[] predictions = getPredictions(state, vector);
        DEPLabel label = new DEPLabel(predictions[0]);
        if (label.isArc(DEPTransition.ARC_NO)) {
            state.save2ndHead(predictions);
        }
        state.saveBranch(predictions);
        return label;
    }

    /**
     * Returns the model's {@code predictTop2} predictions for {@code vector}, each score squashed
     * into (0, 1) with the logistic function {@code 1 / (1 + e^-score)}.
     *
     * <p>When {@code state.getLabelIndices(label_indices)} is non-null, prediction is restricted to
     * those label indices; otherwise all labels are considered.
     */
    protected StringPrediction[] getPredictions(AbstractDEPState state, StringFeatureVector vector) {
        int[] allowed = state.getLabelIndices(this.label_indices);
        StringPrediction[] top;
        if (allowed != null) {
            top = this.s_models[0].predictTop2(vector, allowed);
        } else {
            top = this.s_models[0].predictTop2(vector);
        }
        for (StringPrediction prediction : top) {
            prediction.setScore(1.0d / (1.0d + Math.exp(-prediction.getScore())));
        }
        return top;
    }

    /**
     * Attaches every node (IDs {@code 1 .. size-1}) that has no head after decoding.
     *
     * <p>For each such node where {@code state.find2ndHead(node)} returns {@code false}, every
     * non-descendant position is scored as a candidate head: first scanning left, then right. The
     * node is attached to the best-scoring candidate with that prediction's deprel. If no candidate
     * could be scored, it is attached to node 0 (presumably the artificial root) with the
     * configuration's root label.
     */
    private void processHeadless(AbstractDEPState state) {
        int size = state.getTreeSize();
        for (int id = 1; id < size; id++) {
            DEPNode node = state.getNode(id);
            if (node.hasHead() || state.find2ndHead(node)) {
                continue;
            }
            // The int (-1000) is only a placeholder. It is overwritten together with the prediction,
            // and it is never read while the prediction is still null.
            ObjectIntPair<StringPrediction> best = new ObjectIntPair<>(null, -1000);
            processHeadlessAll(state, node, best, this.label_indices[LEFT_SCAN_LABEL_GROUP], -1);
            processHeadlessAll(state, node, best, this.label_indices[RIGHT_SCAN_LABEL_GROUP], 1);

            if (best.o == null) {
                node.setHead(state.getNode(0), this.d_configuration.getRootLabel());
            } else {
                node.setHead(state.getNode(best.i), new DEPLabel(best.o).getDeprel());
            }
        }
    }

    /**
     * Scores candidate heads for {@code node}, stepping outward from it one position at a time in
     * {@code direction} up to the tree boundary (position 0 inclusive, tree size exclusive).
     *
     * @param state        parser state; {@code reset(...)} is called on it for every candidate,
     *                     so it is mutated
     * @param node         the headless node
     * @param best         running best (prediction, head ID). It is replaced only by a strictly
     *                     greater prediction, so on ties the earlier-scanned candidate is kept.
     * @param labelIndices label indices the model's prediction is restricted to
     * @param direction    {@code -1} to scan leftwards, {@code 1} to scan rightwards; must be
     *                     non-zero or the loop never terminates
     */
    private void processHeadlessAll(AbstractDEPState state, DEPNode node, ObjectIntPair<StringPrediction> best, int[] labelIndices, int direction) {
        int nodeId = node.getID();
        int size = state.getTreeSize();
        for (int head = nodeId + direction; 0 <= head && head < size; head += direction) {
            // A descendant of the node cannot become its head without creating a cycle.
            if (state.getNode(head).isDescendantOf(node)) {
                continue;
            }
            // reset(...) always receives the left-most position first.
            if (direction < 0) {
                state.reset(head, nodeId);
            } else {
                state.reset(nodeId, head);
            }
            StringPrediction prediction = this.s_models[0].predictBest(createStringFeatureVector(state), labelIndices);
            if (best.o == null || best.o.compareTo(prediction) < 0) {
                best.set(prediction, head);
            }
        }
    }

    /** Trains online on {@code trees} by delegating to the superclass's single-model AdaGrad routine. */
    @Override
    public void onlineTrain(List<DEPTree> trees) {
        onlineTrainSingleAdaGrad(trees);
    }

    /** No-op: this parser keeps no lexicons to update. */
    @Override
    protected void onlineLexicons(DEPTree tree) {
    }
}
