package edu.emory.clir.clearnlp.component.mode.dep.state;

import com.carrotsearch.hppc.IntCollection;
import edu.emory.clir.clearnlp.classification.instance.StringInstance;
import edu.emory.clir.clearnlp.classification.prediction.StringPrediction;
import edu.emory.clir.clearnlp.collection.map.ObjectIntHashMap;
import edu.emory.clir.clearnlp.collection.stack.IntPStack;
import edu.emory.clir.clearnlp.collection.triple.ObjectObjectDoubleTriple;
import edu.emory.clir.clearnlp.component.mode.dep.DEPConfiguration;
import edu.emory.clir.clearnlp.component.mode.dep.DEPLabel;
import edu.emory.clir.clearnlp.component.mode.dep.DEPTransition;
import edu.emory.clir.clearnlp.component.utils.CFlag;
import edu.emory.clir.clearnlp.dependency.DEPTree;
import edu.emory.clir.clearnlp.util.arc.DEPArc;
import java.util.Collections;
import java.util.List;
import java.util.PriorityQueue;

/* loaded from: input_file:edu/emory/clir/clearnlp/component/mode/dep/state/DEPStateBranch.class */
/**
 * Parsing state that can record alternative transitions ("branches") during a
 * first pass and later replay them, keeping whichever resulting tree scores best.
 *
 * <p>Life cycle as implemented here:
 * <ol>
 *   <li>While {@code save_branch} is set (configured beam size &gt; 1),
 *       {@link #saveBranch(StringPrediction[])} queues a {@link DEPBranch} whenever
 *       the two leading predictions are close in score.</li>
 *   <li>{@link #startBranching()} snapshots the current tree as the best one so far
 *       and stops further branch collection.</li>
 *   <li>{@link #nextBranch()} restores one queued branch per call;
 *       {@link #saveBest(List)} keeps the result if it beats the current best.</li>
 *   <li>{@link #setBest()} writes the winning heads back onto the tree.</li>
 * </ol>
 *
 * <p>{@code best_tree} is assigned only in {@link #startBranching()}, so
 * {@link #saveBest(List)} and {@link #setBest()} must not be called before
 * {@code startBranching()} has returned {@code true}.
 */
public class DEPStateBranch extends AbstractDEPState implements DEPTransition {
    /** Best result so far: head arcs (o1), the instances passed to saveBest (o2), and its score (d). */
    private ObjectObjectDoubleTriple<DEPArc[], List<StringInstance>> best_tree;
    /** Pending branches in reverse natural order; {@code null} unless the configured beam size is &gt; 1. */
    private PriorityQueue<DEPBranch> q_branches;
    /** Whether {@link #saveBranch(StringPrediction[])} is currently allowed to queue branches. */
    private boolean save_branch;
    /** Configured beam size at first; after {@link #startBranching()} the number of branches left to replay. */
    private int beam_size;
    /** Headed-node count of the best tree recorded so far. */
    private int max_heads;
    /**
     * Global counter keyed by the label description of each winning branch.
     * NOTE(review): looks like a debugging/statistics aid; it is static, mutable and
     * shared by every instance with no synchronization in this class.
     */
    public static ObjectIntHashMap<String> mmm = new ObjectIntHashMap<>();
    /** Label description of the branch most recently restored by {@link DEPBranch#reset()}. */
    private String beta_index;
    /** Value of {@code beta_index} for the best tree; {@code null} while the initial tree is still the best. */
    private String best_index;

    /**
     * Snapshot of the enclosing state taken at the moment two competing labels were
     * predicted, together with those two labels.
     */
    public class DEPBranch implements Comparable<DEPBranch> {
        private final DEPArc[] heads;
        private final IntPStack stack;
        private final IntPStack inter;
        private final int input;
        private final DEPLabel fstLabel;
        private final DEPLabel sndLabel;
        private final double totalScore;
        private final int numTransitions;

        /**
         * Captures the enclosing state's heads, stacks, input index and score counters.
         *
         * @param dEPLabel  the label built from the first prediction
         * @param dEPLabel2 the label built from the second prediction; this is the one
         *                  applied when the branch is replayed
         */
        public DEPBranch(DEPLabel dEPLabel, DEPLabel dEPLabel2) {
            this.heads = DEPStateBranch.this.d_tree.getHeads(DEPStateBranch.this.i_input + 1);
            // Copy both stacks so later mutation of the live state does not affect the snapshot.
            this.stack = new IntPStack((IntCollection) DEPStateBranch.this.i_stack);
            this.inter = new IntPStack((IntCollection) DEPStateBranch.this.i_inter);
            this.input = DEPStateBranch.this.i_input;
            this.fstLabel = dEPLabel;
            this.sndLabel = dEPLabel2;
            this.totalScore = DEPStateBranch.this.total_score;
            this.numTransitions = DEPStateBranch.this.num_transitions;
        }

        /**
         * Restores the enclosing state to this snapshot and then applies the second label
         * via {@code next(sndLabel)}. The snapshot's stacks are handed over by reference,
         * not copied, so a branch is meant to be restored once.
         */
        public void reset() {
            DEPStateBranch.this.beta_index = this.fstLabel.getArc() + "-" + this.fstLabel.getList() + " " + this.sndLabel.getArc() + "-" + this.sndLabel.getList();
            DEPStateBranch.this.d_tree.setHeads(this.heads);
            DEPStateBranch.this.i_stack = this.stack;
            DEPStateBranch.this.i_inter = this.inter;
            DEPStateBranch.this.i_input = this.input;
            DEPStateBranch.this.total_score = this.totalScore;
            DEPStateBranch.this.num_transitions = this.numTransitions;
            DEPStateBranch.this.next(this.sndLabel);
        }

        /** Orders branches by their second label, delegating to {@code DEPLabel.compareTo}. */
        @Override
        public int compareTo(DEPBranch dEPBranch) {
            return this.sndLabel.compareTo(dEPBranch.sndLabel);
        }
    }

    /** Creates an uninitialised state; no branch queue is allocated. */
    public DEPStateBranch() {
    }

    /**
     * Creates a state for the given tree and sets up branching from the configuration.
     *
     * @param dEPTree          the tree handed to the superclass
     * @param cFlag            the component flag handed to the superclass
     * @param dEPConfiguration the configuration handed to the superclass
     */
    public DEPStateBranch(DEPTree dEPTree, CFlag cFlag, DEPConfiguration dEPConfiguration) {
        super(dEPTree, cFlag, dEPConfiguration);
        init(dEPConfiguration);
    }

    /**
     * Reads the beam size and allocates the branch queue when branching is enabled.
     * The argument is not consulted; the value is read from {@code t_configuration}.
     */
    private void init(DEPConfiguration dEPConfiguration) {
        this.beam_size = this.t_configuration.getBeamSize();
        this.save_branch = this.beam_size > 1;
        if (this.save_branch) {
            // Reverse natural order: poll() yields the branch that compares greatest.
            this.q_branches = new PriorityQueue<>(Collections.reverseOrder());
        }
    }

    /**
     * Switches from collecting branches to replaying them.
     *
     * @return {@code false} if no branch was collected (nothing to replay), otherwise
     *         {@code true} after recording the current tree as the best so far
     */
    @Override
    public boolean startBranching() {
        if (this.q_branches == null || this.q_branches.isEmpty()) {
            return false;
        }
        this.best_tree = new ObjectObjectDoubleTriple<>(this.d_tree.getHeads(), null, getScore());
        // One slot of the beam is taken by the tree just recorded; the remainder is
        // capped by the number of branches actually queued.
        this.beam_size = Math.min(this.beam_size - 1, this.q_branches.size());
        this.max_heads = this.d_tree.countHeaded();
        this.save_branch = false;
        this.best_index = null;
        return true;
    }

    /**
     * Restores the next queued branch, if the replay budget allows it.
     *
     * @return {@code true} if a branch was restored; {@code false} when the budget is
     *         used up or no branch is available
     */
    @Override
    public boolean nextBranch() {
        int remaining = this.beam_size;
        this.beam_size = remaining - 1;
        if (remaining <= 0) {
            return false;
        }
        // poll() returns null on an empty queue, and the queue itself is null when
        // branching was never enabled; report "no more branches" instead of throwing.
        DEPBranch branch = (this.q_branches == null) ? null : this.q_branches.poll();
        if (branch == null) {
            return false;
        }
        branch.reset();
        return true;
    }

    /**
     * Records the current tree as the new best when it has at least as many headed
     * nodes as the current best and a strictly higher score.
     *
     * @param list the instances to store alongside the tree
     */
    @Override
    public void saveBest(List<StringInstance> list) {
        int countHeaded = this.d_tree.countHeaded();
        double score = getScore();
        if (countHeaded < this.max_heads || score <= this.best_tree.d) {
            return;
        }
        this.best_tree.set(this.d_tree.getHeads(), list, score);
        this.max_heads = countHeaded;
        this.best_index = this.beta_index;
    }

    /**
     * Applies the best recorded heads to the tree and, when a replayed branch won,
     * registers its label description in {@link #mmm}.
     *
     * @return the instances stored with the best tree; {@code null} if the initial
     *         tree remained the best
     */
    @Override
    public List<StringInstance> setBest() {
        this.d_tree.setHeads(this.best_tree.o1);
        if (this.best_index != null) {
            mmm.add(this.best_index);
        }
        return this.best_tree.o2;
    }

    /**
     * Scores the current tree. With {@code CFlag.BOOTSTRAP} the score is element 1 of
     * {@code getScoreCounts} computed against {@code g_oracle}; otherwise it is
     * {@code total_score / num_transitions}.
     */
    private double getScore() {
        if (this.c_flag == CFlag.BOOTSTRAP) {
            return this.d_tree.getScoreCounts((DEPArc[]) this.g_oracle, this.t_configuration.evaluatePunctuation())[1];
        }
        return this.total_score / this.num_transitions;
    }

    /**
     * Queues a branch for the second prediction when its score is within 1.0 of the
     * first. Does nothing while branch saving is disabled or when fewer than two
     * predictions are supplied.
     *
     * @param stringPredictionArr predictions for the current state; elements 0 and 1 are read
     */
    @Override
    public void saveBranch(StringPrediction[] stringPredictionArr) {
        if (!this.save_branch) {
            return;
        }
        // Without a second prediction there is no alternative to branch on.
        if (stringPredictionArr == null || stringPredictionArr.length < 2) {
            return;
        }
        StringPrediction first = stringPredictionArr[0];
        StringPrediction second = stringPredictionArr[1];
        if (first.getScore() - second.getScore() < 1.0d) {
            addBranch(new DEPLabel(first), new DEPLabel(second));
        }
    }

    /**
     * Queues a branch unless the two labels agree on both arc and list, in which case
     * the branch is skipped.
     */
    private void addBranch(DEPLabel dEPLabel, DEPLabel dEPLabel2) {
        if (dEPLabel.isArc(dEPLabel2) && dEPLabel.isList(dEPLabel2)) {
            return;
        }
        this.q_branches.add(new DEPBranch(dEPLabel, dEPLabel2));
    }
}
