package ai.libs.mlplan.metamining.similaritymeasures;

import de.jungblut.math.DoubleVector;
import de.jungblut.math.dense.DenseDoubleVector;
import de.jungblut.math.minimize.CostFunction;
import de.jungblut.math.minimize.CostGradientTuple;
import de.jungblut.math.minimize.GradientDescent;
import java.util.Random;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/libs/mlplan/metamining/similaritymeasures/F1Optimizer.class */
/**
 * Learns a matrix {@code U} by gradient descent on the cost
 * {@code F1(U) = ||R R^T - (X U)(X U)^T||_F^2}, where {@code X} and {@code R} are the first and third
 * arguments passed to {@link #build(INDArray, INDArray, INDArray)}.
 *
 * <p>The learning rate is tuned adaptively between probes: it is doubled after a probe that lowers
 * the cost and halved (with the probe discarded) after a probe that raises the cost or produces a
 * non-finite cost. Optimization stops when a probe leaves the cost unchanged or the cost drops to
 * {@link #MAX_DESIRED_ERROR}.
 *
 * <p>{@link #computeSimilarity(INDArray, INDArray)} is not implemented by this class and always
 * returns {@code 0}.
 */
public class F1Optimizer implements IHeterogenousSimilarityMeasureComputer {
    /** Learning rate used for the first gradient-descent probe. */
    private static final double ALPHA_START = 1.0E-9d;
    /** Upper bound for the adaptively tuned learning rate. */
    private static final double ALPHA_MAX = 1.0E-5d;
    /** Number of gradient-descent iterations per probe before the cost is re-evaluated. */
    private static final int ITERATIONS_PER_PROBE = 100;
    /** Second constructor argument of {@link GradientDescent} (its stopping limit; see library docs). */
    private static final double LIMIT = 1.0d;
    /** Optimization stops once the cost is at or below this value. */
    private static final double MAX_DESIRED_ERROR = 0.0d;
    /** Number of columns of the learned matrix {@code U}. */
    private static final int LATENT_DIMENSION = 1;
    /** Initial entries of {@code U} are drawn as {@code (rand - 0.5) * INIT_SCALE}, i.e. in [-50, 50). */
    private static final double INIT_SCALE = 100.0d;

    private static final Logger LOGGER = LoggerFactory.getLogger(F1Optimizer.class);

    /** {@code R R^T}, computed once in {@link #build}. */
    private INDArray rrt;
    /** The {@code X} matrix passed to {@link #build}. */
    private INDArray x;
    /** The learned matrix; {@code null} until {@link #build} has completed. */
    private INDArray u;
    private final Random rand = new Random();

    /**
     * Learns {@code U} by minimizing {@link #getCost(INDArray)} with adaptive-step gradient descent.
     *
     * @param xMatrix the {@code X} matrix; {@code U} gets {@code xMatrix.columns()} rows
     * @param wMatrix not used by this optimizer
     * @param rMatrix the {@code R} matrix; {@code R R^T} must have the same shape as
     *     {@code (X U)(X U)^T}
     */
    @Override // ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer
    public void build(final INDArray xMatrix, final INDArray wMatrix, final INDArray rMatrix) {
        this.rrt = rMatrix.mmul(rMatrix.transpose());
        this.x = xMatrix;
        final int m = xMatrix.columns();

        DoubleVector currentSolution = randomVector(m * LATENT_DIMENSION);
        final INDArray initialU = vector2matrix(currentSolution, m, LATENT_DIMENSION);
        double currentCost = getCost(initialU);
        LOGGER.debug("X = {}", xMatrix);
        LOGGER.debug("randomly initialized U = {}", initialU);
        LOGGER.debug("loss of randomly initialized U: {}", currentCost);

        final CostFunction costFunction = candidate -> {
            final INDArray candidateU = vector2matrix(candidate, m, LATENT_DIMENSION);
            return new CostGradientTuple(getCost(candidateU), matrix2vector(getGradientAsMatrix(candidateU)));
        };

        double alpha = ALPHA_START;
        while (currentCost > MAX_DESIRED_ERROR) {
            final DoubleVector probe =
                new GradientDescent(alpha, LIMIT).minimize(costFunction, currentSolution, ITERATIONS_PER_PROBE, false);
            final double probeCost = getCost(vector2matrix(probe, m, LATENT_DIMENSION));

            if (!Double.isFinite(probeCost) || probeCost > currentCost) {
                // The probe diverged or got worse: keep the previous solution and shrink the step.
                alpha /= 2.0d;
            } else if (probeCost < currentCost) {
                currentSolution = probe;
                currentCost = probeCost;
                alpha *= 2.0d;
            } else {
                // No change in cost: treat as converged.
                currentSolution = probe;
                break;
            }
            alpha = Math.min(alpha, ALPHA_MAX);
            LOGGER.debug("Current Cost {} (alpha = {})", currentCost, alpha);
        }
        // Derive U from the last accepted vector so it never reflects a discarded probe.
        this.u = vector2matrix(currentSolution, m, LATENT_DIMENSION);
    }

    /** Creates a vector of the given length with entries drawn uniformly from [-INIT_SCALE/2, INIT_SCALE/2). */
    private DoubleVector randomVector(final int length) {
        final double[] values = new double[length];
        for (int idx = 0; idx < length; idx++) {
            values[idx] = (this.rand.nextDouble() - 0.5d) * INIT_SCALE;
        }
        return new DenseDoubleVector(values);
    }

    /**
     * Reshapes a vector into a matrix, filling it row by row (the inverse of {@link #matrix2vector}).
     *
     * @param vector the values; its length must equal {@code rows * columns}
     * @param rows number of rows of the result
     * @param columns number of columns of the result
     * @return a new {@code rows x columns} matrix
     * @throws IllegalArgumentException if the vector length does not match the requested shape
     */
    public INDArray vector2matrix(final DoubleVector vector, final int rows, final int columns) {
        final int length = vector.getLength();
        if (rows < 0 || columns < 0 || (long) rows * columns != length) {
            throw new IllegalArgumentException(
                "Cannot reshape vector of length " + length + " into a " + rows + "x" + columns + " matrix");
        }
        final double[] values = new double[length];
        for (int idx = 0; idx < length; idx++) {
            values[idx] = vector.get(idx);
        }
        return Nd4j.create(values, new int[] {rows, columns});
    }

    /**
     * Flattens a matrix into a dense vector in row-major order (the inverse of {@link #vector2matrix}).
     *
     * @param matrix the matrix to flatten
     * @return a new vector of length {@code rows * columns}
     */
    public DoubleVector matrix2vector(final INDArray matrix) {
        final int rows = matrix.rows();
        final int columns = matrix.columns();
        final double[] values = new double[rows * columns];
        int idx = 0;
        for (int row = 0; row < rows; row++) {
            for (int col = 0; col < columns; col++) {
                values[idx++] = matrix.getDouble(row, col);
            }
        }
        return new DenseDoubleVector(values);
    }

    /**
     * Computes {@code F1(U) = ||R R^T - (X U)(X U)^T||_F^2}, the sum of squared entries of the residual.
     * Requires {@link #build} (or at least its initialization of X and R R^T) to have run.
     *
     * @param uMatrix candidate {@code U} with {@code x.columns()} rows
     * @return the squared Frobenius norm of the residual
     */
    public double getCost(final INDArray uMatrix) {
        final INDArray residual = residual(this.x.mmul(uMatrix));
        double sum = 0.0d;
        for (int row = 0; row < residual.rows(); row++) {
            for (int col = 0; col < residual.columns(); col++) {
                final double value = residual.getDouble(row, col);
                sum += value * value;
            }
        }
        return sum;
    }

    /**
     * Computes the full gradient of {@link #getCost(INDArray)} with respect to {@code U}.
     * The products {@code X U} and the residual are computed once and shared by all entries.
     *
     * @param uMatrix candidate {@code U} with {@code x.columns()} rows
     * @return a matrix with the same shape as {@code uMatrix}
     */
    public INDArray getGradientAsMatrix(final INDArray uMatrix) {
        final INDArray xu = this.x.mmul(uMatrix);
        final INDArray residual = residual(xu);
        final int rows = this.x.columns();
        final int columns = uMatrix.columns();
        final double[][] gradient = new double[rows][columns];
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < columns; j++) {
                gradient[i][j] = partialDerivative(xu, residual, i, j);
            }
        }
        return Nd4j.create(gradient);
    }

    /**
     * Computes a single entry {@code dF1/dU[i][j]} of the gradient.
     *
     * @param uMatrix candidate {@code U}
     * @param i row index into {@code U}
     * @param j column index into {@code U}
     * @return the partial derivative, narrowed to {@code float}
     */
    public float getFirstDerivative(final INDArray uMatrix, final int i, final int j) {
        final INDArray xu = this.x.mmul(uMatrix);
        return (float) partialDerivative(xu, residual(xu), i, j);
    }

    /** Returns {@code R R^T - (X U)(X U)^T} given the precomputed product {@code X U}. */
    private INDArray residual(final INDArray xu) {
        return this.rrt.sub(xu.mmul(xu.transpose()));
    }

    /**
     * Evaluates {@code -2 * sum_{a,b} S[a][b] * (X[a][i] * XU[b][j] + X[b][i] * XU[a][j])} in double
     * precision, where {@code S} is the residual.
     */
    private double partialDerivative(final INDArray xu, final INDArray residual, final int i, final int j) {
        final int n = this.x.rows();
        double sum = 0.0d;
        for (int a = 0; a < n; a++) {
            final double xAi = this.x.getDouble(a, i);
            final double xuAj = xu.getDouble(a, j);
            for (int b = 0; b < n; b++) {
                sum += residual.getDouble(a, b) * (xAi * xu.getDouble(b, j) + this.x.getDouble(b, i) * xuAj);
            }
        }
        return -2.0d * sum;
    }

    /**
     * Not implemented by this optimizer.
     *
     * @return always {@code 0}, regardless of the arguments
     */
    @Override // ai.libs.mlplan.metamining.similaritymeasures.IHeterogenousSimilarityMeasureComputer
    public double computeSimilarity(final INDArray first, final INDArray second) {
        return 0.0d;
    }

    /** Returns the X matrix passed to {@link #build}, or {@code null} before it was called. */
    public INDArray getX() {
        return this.x;
    }

    /** Returns the learned U matrix, or {@code null} before {@link #build} has completed. */
    public INDArray getU() {
        return this.u;
    }
}
