package dragon.ml.seqmodel.crf;

import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;

/* loaded from: input_file:dragon/ml/seqmodel/crf/CollinsBasicTrainer.class */
/**
 * Trains CRF feature weights with Collins-style perceptron updates and weight averaging.
 *
 * <p>For each training sequence the labeler decodes up to {@link #topSolutions} best solutions. A
 * decoded solution counts as a mistake when two conditions hold:
 * <ul>
 *   <li>its labels differ from the correct labels;</li>
 *   <li>its score is at least {@code (1 - beta)} times the score of the correct labelling.</li>
 * </ul>
 * Walking the correct sequence segment by segment, wherever a disagreement is detected the
 * features of the correct labelling receive {@code +1} and the features of each wrong labelling
 * receive {@code -1 / numWrong}. The weights returned in {@code lambda} are the average of the
 * weight vector taken after every processed sequence. Training stops early once a full pass over
 * the data produces no mistakes.
 */
public class CollinsBasicTrainer extends AbstractTrainer {
    /** Maximum number of top-ranked decoded solutions examined per training sequence. */
    protected int topSolutions;
    /**
     * Margin factor: a decoded solution is only examined while its score is at least
     * {@code (1 - beta)} times the score of the correct labelling.
     */
    protected double beta;
    /**
     * When {@code true}, sequences are decoded with the running average of the weights seen so far
     * instead of the current weights.
     */
    protected boolean useUpdated;

    /**
     * Creates a trainer with these defaults:
     * <ul>
     *   <li>examines up to three top solutions, capped by the number of model states;</li>
     *   <li>{@code beta = 0.05};</li>
     *   <li>decodes with the current (non-averaged) weights.</li>
     * </ul>
     *
     * @param modelGraph       model whose states the labels are mapped onto
     * @param featureGenerator generator providing the features whose weights are trained
     */
    public CollinsBasicTrainer(ModelGraph modelGraph, FeatureGenerator featureGenerator) {
        super(modelGraph, featureGenerator);
        this.topSolutions = Math.min(3, modelGraph.getStateNum());
        this.beta = 0.05d;
        this.useUpdated = false;
    }

    /**
     * Trains {@code lambda} on {@code dataset}.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Map the labels of every sequence to model states.</li>
     *   <li>Train the feature generator.</li>
     *   <li>Run up to {@code maxIteration} perceptron passes, stopping early after a pass with no
     *       mistakes.</li>
     * </ol>
     * Progress is printed to standard output once per pass. If no sequence is ever processed
     * (empty dataset or no iterations), {@code lambda} is left at all zeros.
     *
     * @param dataset training sequences; scanned several times
     * @return {@code false} if the feature generator fails to train, otherwise {@code true}
     */
    @Override // dragon.ml.seqmodel.crf.Trainer
    public boolean train(Dataset dataset) {
        dataset.startScan();
        while (dataset.hasNext()) {
            this.model.mapLabelToState(dataset.next());
        }
        if (!this.featureGenerator.train(dataset)) {
            return false;
        }

        int featureNum = this.featureGenerator.getFeatureNum();
        this.lambda = new double[featureNum];
        double[] avgLambda = new double[featureNum]; // running average, used only if useUpdated
        double[] cumLambda = new double[featureNum]; // sum of lambda after each processed sequence
        MathUtil.initArray(this.lambda, 0.0d);
        MathUtil.initArray(avgLambda, 0.0d);
        MathUtil.initArray(cumLambda, 0.0d);

        Labeler labeler = getLabeler();
        DataSequence[] wrongSolutions = new DataSequence[this.topSolutions];
        int[] wrongPos = new int[this.topSolutions];
        int numVisited = 0;

        for (int iter = 0; iter < this.maxIteration; iter++) {
            int numErrs = 0;
            dataset.startScan();
            while (dataset.hasNext()) {
                // Only materialise the running average when it is actually used for decoding.
                if (this.useUpdated && numVisited > 0) {
                    MathUtil.copyArray(cumLambda, avgLambda);
                    MathUtil.multiArray(avgLambda, 1.0d / numVisited);
                }
                double[] decodeWeights = this.useUpdated ? avgLambda : this.lambda;

                DataSequence seq = dataset.next();
                int numWrong = collectWrongSolutions(labeler, seq, decodeWeights, wrongSolutions);
                if (numWrong > 0) {
                    numErrs += updateOnMistakes(seq, wrongSolutions, numWrong, wrongPos);
                }
                MathUtil.sumArray(cumLambda, this.lambda);
                numVisited++;
            }
            System.out.println("Iteration " + iter + " numErrs " + numErrs);
            if (numErrs == 0) {
                break;
            }
        }

        // Guard against 1.0 / 0: without it an empty run would turn every weight into NaN.
        if (numVisited > 0) {
            MathUtil.multiArray(cumLambda, 1.0d / numVisited);
            MathUtil.copyArray(cumLambda, this.lambda);
        }
        return true;
    }

    /**
     * Decodes {@code seq} with {@code weights} and gathers the mistaken top-ranked solutions.
     *
     * <p>Solutions are examined in rank order and the scan stops at the first one scoring below
     * {@code (1 - beta)} times the correct labelling's score. A solution is collected only if its
     * state-mapped labels differ from the correct ones.
     *
     * @return the number of entries written to the start of {@code wrongSolutions}
     */
    private int collectWrongSolutions(Labeler labeler, DataSequence seq, double[] weights,
                                      DataSequence[] wrongSolutions) {
        labeler.label(seq.copy(), weights);
        double threshold = getSequenceScore(seq, weights) * (1.0d - this.beta);
        int numWrong = 0;
        for (int rank = 0; rank < this.topSolutions; rank++) {
            DataSequence candidate = seq.copy();
            if (labeler.getBestSolution(candidate, rank) < threshold) {
                break;
            }
            this.model.mapLabelToState(candidate);
            if (!isCorrect(seq, candidate)) {
                wrongSolutions[numWrong++] = candidate;
            }
        }
        return numWrong;
    }

    /**
     * Walks {@code seq} segment by segment and applies perceptron updates wherever any of the
     * first {@code numWrong} wrong solutions is out of step with it.
     *
     * <p>A segment counts as a mistake when, for some wrong solution, any of these holds:
     * <ul>
     *   <li>the solution's current position is not at the segment start;</li>
     *   <li>its segment ends elsewhere;</li>
     *   <li>its label at the segment end differs.</li>
     * </ul>
     * On a mistake, the correct segment's features get {@code +1}, and each wrong solution's
     * segments up to the segment end get {@code -1 / numWrong}.
     *
     * <p>NOTE(review): {@code wrongPos} starts at 0 while {@code seq} is walked from
     * {@code getMarkovOrder() - 1}. For Markov orders above 1 the first segment therefore always
     * counts as a mistake. This matches the original behavior; confirm it is intended.
     *
     * @return the number of segments at which a mistake was detected
     */
    private int updateOnMistakes(DataSequence seq, DataSequence[] wrongSolutions, int numWrong,
                                 int[] wrongPos) {
        MathUtil.initArray(wrongPos, 0);
        int numErrs = 0;
        int segStart = this.model.getMarkovOrder() - 1;
        while (segStart < seq.length()) {
            int segEnd = getSegmentEnd(seq, segStart);

            boolean mistake = false;
            for (int k = 0; k < numWrong && !mistake; k++) {
                mistake = wrongPos[k] != segStart
                        || getSegmentEnd(wrongSolutions[k], wrongPos[k]) != segEnd
                        || seq.getLabel(segEnd) != wrongSolutions[k].getLabel(segEnd);
            }

            if (mistake) {
                numErrs++;
                updateWeights(seq, segStart, segEnd, 1.0d, this.lambda);
                for (int k = 0; k < numWrong; k++) {
                    while (wrongPos[k] <= segEnd) {
                        int wrongEnd = getSegmentEnd(wrongSolutions[k], wrongPos[k]);
                        updateWeights(wrongSolutions[k], wrongPos[k], wrongEnd, -1.0d / numWrong, this.lambda);
                        wrongPos[k] = wrongEnd + 1;
                    }
                }
            }

            // Keep every wrong solution aligned past the current segment (no-op after an update).
            for (int k = 0; k < numWrong; k++) {
                while (wrongPos[k] <= segEnd) {
                    wrongPos[k] = getSegmentEnd(wrongSolutions[k], wrongPos[k]) + 1;
                }
            }
            segStart = segEnd + 1;
        }
        return numErrs;
    }

    /**
     * Returns whether {@code candidate} carries the same label as {@code reference} at every
     * position of {@code reference}.
     */
    protected boolean isCorrect(DataSequence reference, DataSequence candidate) {
        for (int pos = 0; pos < reference.length(); pos++) {
            if (reference.getLabel(pos) != candidate.getLabel(pos)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Adds {@code delta * value} to {@code weights[index]} for every feature on the segment
     * {@code [segStart, segEnd]} that matches {@code seq}'s labels.
     *
     * @param seq      labelled sequence the features are scanned from
     * @param segStart first position of the segment
     * @param segEnd   last position of the segment
     * @param delta    scale applied to each matching feature's value
     * @param weights  weight vector updated in place
     */
    protected void updateWeights(DataSequence seq, int segStart, int segEnd, double delta, double[] weights) {
        this.featureGenerator.startScanFeaturesAt(seq, segStart, segEnd);
        while (this.featureGenerator.hasNext()) {
            Feature feature = this.featureGenerator.next();
            if (matchesSegment(feature, seq, segStart, segEnd)) {
                int index = feature.getIndex();
                weights[index] += delta * feature.getValue();
            }
        }
    }

    /**
     * Computes the weighted feature score of {@code seq}'s own labelling.
     *
     * <p>The sum runs over segments starting at {@code getMarkovOrder() - 1} and covers every
     * feature that matches the labels.
     *
     * @param seq     labelled sequence to score
     * @param weights feature weights
     * @return the sum of {@code weights[index] * value} over all matching features
     */
    protected double getSequenceScore(DataSequence seq, double[] weights) {
        double score = 0.0d;
        int segStart = this.model.getMarkovOrder() - 1;
        while (segStart < seq.length()) {
            int segEnd = getSegmentEnd(seq, segStart);
            this.featureGenerator.startScanFeaturesAt(seq, segStart, segEnd);
            while (this.featureGenerator.hasNext()) {
                Feature feature = this.featureGenerator.next();
                if (matchesSegment(feature, seq, segStart, segEnd)) {
                    score += weights[feature.getIndex()] * feature.getValue();
                }
            }
            segStart = segEnd + 1;
        }
        return score;
    }

    /**
     * Returns whether {@code feature} fires for the segment {@code [segStart, segEnd]} of
     * {@code seq}.
     *
     * <p>Two conditions must hold:
     * <ul>
     *   <li>its label equals the label at {@code segEnd};</li>
     *   <li>when it has a previous label ({@code >= 0}), that previous label equals the label at
     *       {@code segStart - 1}.</li>
     * </ul>
     */
    private static boolean matchesSegment(Feature feature, DataSequence seq, int segStart, int segEnd) {
        if (seq.getLabel(segEnd) != feature.getLabel()) {
            return false;
        }
        int prevLabel = feature.getPrevLabel();
        return prevLabel < 0 || prevLabel == seq.getLabel(segStart - 1);
    }

    /** Creates the labeler used to decode candidate solutions during training. */
    protected Labeler getLabeler() {
        return new ViterbiBasicLabeler(this.model, this.featureGenerator);
    }

    /**
     * Returns the last position of the segment starting at {@code start}. Every segment here is a
     * single position; subclasses may override this.
     */
    protected int getSegmentEnd(DataSequence seq, int start) {
        return start;
    }
}
