package dragon.ml.seqmodel.crf;

import dragon.matrix.DoubleFlatDenseMatrix;
import dragon.ml.seqmodel.crf.LBFGS;
import dragon.ml.seqmodel.data.DataSequence;
import dragon.ml.seqmodel.data.Dataset;
import dragon.ml.seqmodel.feature.Feature;
import dragon.ml.seqmodel.feature.FeatureGenerator;
import dragon.ml.seqmodel.model.ModelGraph;
import dragon.util.MathUtil;
import java.util.Date;

/* loaded from: input_file:dragon/ml/seqmodel/crf/LBFGSBasicTrainer.class */
/**
 * Trains the CRF weight vector {@code lambda} with limited-memory BFGS ({@link LBFGS}).
 *
 * <p>The objective maximised is the sum over training sequences of the conditional
 * log-likelihood minus an L2 (Gaussian-prior) penalty
 * {@code invSigmaSquare * ||lambda||^2 / 2}. Because {@code LBFGS.lbfgs} is driven as a
 * minimiser, {@link #train(Dataset)} passes it the negated objective and gradient.
 */
public class LBFGSBasicTrainer extends AbstractTrainer {
    /** Number of past gradient corrections handed to L-BFGS (default 7). */
    protected int mForHessian;
    /** Convergence tolerance handed to L-BFGS as its accuracy argument (default 0.001). */
    protected double epsForConvergence;
    /** Strength of the L2 penalty, i.e. inverse variance of the Gaussian prior (default 0.01). */
    protected double invSigmaSquare;

    /**
     * Creates a trainer with the default L-BFGS settings.
     *
     * @param modelGraph       the CRF state graph whose weights are trained
     * @param featureGenerator produces the features scored by the weights
     */
    public LBFGSBasicTrainer(ModelGraph modelGraph, FeatureGenerator featureGenerator) {
        super(modelGraph, featureGenerator);
        this.mForHessian = 7;
        this.epsForConvergence = 0.001d;
        this.invSigmaSquare = 0.01d;
    }

    /**
     * Sets how many past gradient corrections L-BFGS keeps.
     *
     * @param i history size passed to {@code LBFGS.lbfgs}
     */
    public void setGradientHistory(int i) {
        this.mForHessian = i;
    }

    /**
     * Sets the L-BFGS convergence tolerance from an integer.
     *
     * <p>Kept for backward compatibility only: an {@code int} cannot express fractional
     * tolerances such as the default 0.001. Prefer {@link #setAccuracy(double)}.
     *
     * @param i convergence tolerance
     */
    public void setAccuracy(int i) {
        this.epsForConvergence = i;
    }

    /**
     * Sets the L-BFGS convergence tolerance.
     *
     * @param eps convergence tolerance passed to {@code LBFGS.lbfgs}, e.g. {@code 0.001}
     */
    public void setAccuracy(double eps) {
        this.epsForConvergence = eps;
    }

    /**
     * Sets the L2 penalty strength from an integer.
     *
     * <p>Kept for backward compatibility only: an {@code int} cannot express fractional
     * values such as the default 0.01. Prefer {@link #setInvSigmaSquare(double)}.
     *
     * @param i inverse variance of the Gaussian prior
     */
    public void setInvSigmaSquare(int i) {
        this.invSigmaSquare = i;
    }

    /**
     * Sets the L2 penalty strength.
     *
     * @param invSigmaSquare inverse variance of the Gaussian prior on the weights
     */
    public void setInvSigmaSquare(double invSigmaSquare) {
        this.invSigmaSquare = invSigmaSquare;
    }

    /**
     * Learns {@code lambda} from a labelled dataset.
     *
     * <p>The method proceeds in three steps:
     * <ol>
     *   <li>Maps every training label to a model state.</li>
     *   <li>Trains the feature generator.</li>
     *   <li>Repeatedly evaluates the penalised log-likelihood and its gradient, feeding
     *       both (negated) to {@code LBFGS.lbfgs}. It makes at most
     *       {@code maxIteration + 1} calls.</li>
     * </ol>
     *
     * @param dataset training sequences; scanned several times
     * @return {@code false} if feature training fails or L-BFGS throws; {@code true} if
     *     L-BFGS reports convergence or the iteration budget is exhausted
     */
    @Override // dragon.ml.seqmodel.crf.Trainer
    public boolean train(Dataset dataset) {
        // Register every label that occurs in the training data with the model graph.
        dataset.startScan();
        while (dataset.hasNext()) {
            this.model.mapLabelToState(dataset.next());
        }
        if (!this.featureGenerator.train(dataset)) {
            return false;
        }

        int featureNum = this.featureGenerator.getFeatureNum();
        this.lambda = new double[featureNum]; // Java zero-fills: optimisation starts from all-zero weights
        double[] gradient = new double[featureNum];
        double[] diag = new double[featureNum]; // diagonal workspace required by LBFGS.lbfgs (diagco is false here)
        int[] iprint = {-1, 0};                 // LBFGS print control; presumably -1 silences it — TODO confirm
        int[] iflag = {0};                      // LBFGS status; training succeeds when it comes back as 0

        int iteration = 0;
        do {
            double logLikelihood = computeFunctionGradient(dataset, this.lambda, gradient);
            System.out.println(new Date().toString() + " Iteration: " + iteration + " log likelihood " + logLikelihood + " norm(grad logli) " + norm(gradient) + " norm(x) " + norm(this.lambda));

            // LBFGS minimises, so hand it the negated objective and gradient.
            double objective = -logLikelihood;
            for (int k = 0; k < gradient.length; k++) {
                gradient[k] = -gradient[k];
            }

            try {
                LBFGS.lbfgs(featureNum, this.mForHessian, this.lambda, objective, gradient, false, diag, iprint, this.epsForConvergence, xtol, iflag);
            } catch (LBFGS.ExceptionWithIflag e) {
                System.err.println("CRF: lbfgs failed.\n" + e);
                if (e.iflag == -1) {
                    System.err.println("Possible reasons could be: \n \t 1. Bug in the feature generation or data handling code\n\t 2. Not enough features to make observed feature value==expected value\n");
                }
                return false;
            }
            iteration++;
            if (iflag[0] == 0) {
                return true;
            }
        } while (iteration <= this.maxIteration);
        return true;
    }

    /**
     * Returns the Euclidean (L2) norm of a vector.
     *
     * @param dArr the vector
     * @return {@code sqrt(sum(v_i^2))}
     */
    protected double norm(double[] dArr) {
        double sumOfSquares = 0.0d;
        for (double v : dArr) {
            sumOfSquares += v * v;
        }
        return Math.sqrt(sumOfSquares);
    }

    /**
     * Computes the penalised conditional log-likelihood and its gradient.
     *
     * <p>For each non-empty sequence the method does the following:
     * <ul>
     *   <li>A scaled backward pass fills {@code beta}.</li>
     *   <li>A forward pass accumulates {@code alpha} together with the observed and
     *       (unnormalised) expected feature counts.</li>
     *   <li>The gradient receives {@code observed - expected / sum(alpha)}.</li>
     * </ul>
     * The L2 prior contributes {@code -invSigmaSquare * w} to the gradient and
     * {@code -invSigmaSquare * ||w||^2 / 2} to the objective.
     *
     * @param weights  current weight vector
     * @param gradient output buffer overwritten with the gradient; at least as long as {@code weights}
     * @return the penalised log-likelihood at {@code weights}
     * @throws IllegalArgumentException if {@code gradient} is shorter than {@code weights}
     */
    protected double computeFunctionGradient(Dataset dataset, double[] weights, double[] gradient) {
        if (gradient.length < weights.length) {
            throw new IllegalArgumentException("gradient buffer length " + gradient.length
                    + " is shorter than weight vector length " + weights.length);
        }

        int markovOrder = this.model.getMarkovOrder();
        int stateNum = this.model.getStateNum();
        double[] alpha = new double[stateNum];     // forward vector at the current position
        double[] newAlpha = new double[stateNum];  // forward vector for the position being processed
        double[][] beta = null;                    // backward vectors per position; grown on demand
        double[] scale = null;                     // per-position scale factors (applied to beta and alpha)
        double[] expected = new double[this.featureGenerator.getFeatureNum()]; // unnormalised expected counts for one sequence
        DoubleFlatDenseMatrix transMatrix = new DoubleFlatDenseMatrix(stateNum, stateNum);

        // Gaussian prior on the weights.
        double logLikelihood = 0.0d;
        for (int f = 0; f < weights.length; f++) {
            gradient[f] = -weights[f] * this.invSigmaSquare;
            logLikelihood -= weights[f] * weights[f] * this.invSigmaSquare / 2.0d;
        }

        dataset.startScan();
        while (dataset.hasNext()) {
            DataSequence seq = dataset.next();
            int len = seq.length();
            if (len == 0) {
                // An empty sequence has exactly one (empty) labelling, so log P = 0.
                // It is skipped because scale[len - 1] below would be out of bounds.
                continue;
            }

            MathUtil.initArray(alpha, 1.0d);
            for (int f = 0; f < expected.length; f++) {
                expected[f] = 0.0d;
            }
            if (beta == null || beta.length < len) {
                // Over-allocate so that slightly longer later sequences reuse the buffers.
                beta = new double[2 * len][stateNum];
                scale = new double[2 * len];
            }

            // Backward pass, from the last position towards the first.
            int last = len - 1;
            scale[last] = this.doScaling ? stateNum : 1.0d;
            MathUtil.initArray(beta[last], 1.0d / scale[last]);
            for (int pos = last; pos > markovOrder - 1; pos--) {
                computeTransMatrix(weights, seq, pos, pos, transMatrix, true);
                MathUtil.initArray(beta[pos - 1], 0.0d);
                genStateVector(transMatrix, beta[pos], beta[pos - 1], false);
                double s = this.doScaling ? MathUtil.sumArray(beta[pos - 1]) : 1.0d;
                if (s < 1.0d && s > -1.0d) {
                    s = 1.0d; // only shrink: vectors with |sum| < 1 are left unscaled
                }
                scale[pos - 1] = s;
                MathUtil.multiArray(beta[pos - 1], 1.0d / s);
            }

            // Forward pass, accumulating observed and expected feature counts.
            double observedScore = 0.0d; // sum of weight * value over features firing on the gold labelling
            for (int pos = markovOrder - 1; pos < len; pos++) {
                computeTransMatrix(weights, seq, pos, pos, transMatrix, true);
                MathUtil.initArray(newAlpha, 0.0d);
                genStateVector(transMatrix, alpha, newAlpha, true);
                this.featureGenerator.startScanFeaturesAt(seq, pos, pos);
                while (this.featureGenerator.hasNext()) {
                    Feature feature = this.featureGenerator.next();
                    int index = feature.getIndex();
                    int label = feature.getLabel();
                    int prevLabel = feature.getPrevLabel();
                    double value = feature.getValue();

                    // Observed counts: the feature matches the gold label(s) at this position.
                    if (seq.getLabel(pos) == label
                            && ((pos - 1 >= 0 && prevLabel == seq.getLabel(pos - 1)) || prevLabel < 0)) {
                        gradient[index] += value;
                        observedScore += value * weights[index];
                    }

                    // Expected counts: state features (prevLabel < 0) use the new forward vector;
                    // edge features use the previous forward vector times the transition entry.
                    if (prevLabel < 0) {
                        expected[index] += value * newAlpha[label] * beta[pos][label];
                    } else {
                        expected[index] += value * alpha[prevLabel] * transMatrix.getDouble(prevLabel, label) * beta[pos][label];
                    }
                }
                MathUtil.copyArray(newAlpha, alpha);
                MathUtil.multiArray(alpha, 1.0d / scale[pos]);
            }

            // Sum of the scaled forward vector; with the scale factors it gives log Z.
            double partition = MathUtil.sumArray(alpha);
            double seqLogLikelihood = observedScore - Math.log(partition);
            for (int pos = markovOrder - 1; pos < len; pos++) {
                seqLogLikelihood -= Math.log(scale[pos]); // undo the per-position scaling of alpha
            }
            logLikelihood += seqLogLikelihood;
            for (int f = 0; f < gradient.length; f++) {
                gradient[f] -= expected[f] / partition;
            }
        }
        return logLikelihood;
    }
}
