package rbms;

import Jama.Matrix;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Random;
import weka.core.Utils;

/**
 * A Restricted Boltzmann Machine (RBM) with one visible and one hidden layer, trained with
 * contrastive-divergence style gradient steps plus weight decay and momentum.
 *
 * <p>The weight matrix {@code W} has one extra row and one extra column for bias terms. Input data
 * are augmented with a bias column via {@code Mat.addBias} before being multiplied by {@code W}.
 * Column 0 of every computed layer is overwritten with 1.0 so it can act as the bias for the next
 * multiplication.
 *
 * <p>No synchronization is performed: training mutates {@code W}, {@code dW_} and, when sampling,
 * the shared {@link Random} {@code m_R}.
 */
public class RBM {
    /** Weight-decay factor; the decay cost {@code COST} is this factor times the learning rate. */
    private static final double COST_FACTOR = 2.0E-4d;

    /** Step size applied to every gradient update (option {@code -r}). */
    protected double LEARNING_RATE = 0.1d;
    /** Multiplier on the previous update that is added to the current one (option {@code -m}). */
    protected double MOMENTUM = 0.1d;
    /** Weight-decay cost; kept in sync with {@code LEARNING_RATE} by {@link #setLearningRate}. */
    protected double COST = COST_FACTOR * this.LEARNING_RATE;
    /** Number of training epochs (option {@code -E}). */
    protected int m_E = 1000;
    /** Number of hidden units (option {@code -H}). */
    protected int m_H = 10;
    /** Whether {@link #train(double[][])} stops early when reconstruction error rises. */
    private boolean m_V = false;
    /** Weights including the bias row (row 0) and bias column (column 0). */
    protected Matrix W = null;
    /** The previous raw update step, used for the momentum term. */
    protected Matrix dW_ = null;
    /** Source of randomness for weight initialisation and sampling. */
    protected Random m_R = new Random(0);

    /** Creates an RBM with default settings. */
    public RBM() {
    }

    /**
     * Creates an RBM configured from Weka-style options.
     *
     * @param options option array, see {@link #setOptions(String[])}
     * @throws Exception if an option is missing or malformed
     */
    public RBM(String[] options) throws Exception {
        setOptions(options);
    }

    /**
     * Configures the RBM from Weka-style options. The options are: {@code -H} hidden units,
     * {@code -E} epochs (a negative value enables early stopping), {@code -r} learning rate and
     * {@code -m} momentum. All four are required.
     *
     * <p>All four values are parsed before any setter runs, so a bad option leaves this RBM's
     * settings unchanged.
     *
     * @param options the option array passed to {@code Utils.getOption}
     * @throws IllegalArgumentException if any option is missing or cannot be parsed
     */
    public void setOptions(String[] options) throws Exception {
        int hidden;
        int epochs;
        double learningRate;
        double momentum;
        try {
            hidden = Integer.parseInt(Utils.getOption('H', options));
            epochs = Integer.parseInt(Utils.getOption('E', options));
            learningRate = Double.parseDouble(Utils.getOption('r', options));
            momentum = Double.parseDouble(Utils.getOption('m', options));
        } catch (Exception e) {
            // Report to the caller rather than terminating the whole JVM from library code.
            throw new IllegalArgumentException(
                    "Missing or malformed option (-H, -E, -r and -m are all required)", e);
        }
        setH(hidden);
        setE(epochs);
        setLearningRate(learningRate);
        setMomentum(momentum);
    }

    /**
     * Returns the current settings as Weka-style options that {@link #setOptions(String[])}
     * accepts. When early stopping is enabled, the epoch count is emitted as a negative number so
     * that the setting survives a round trip.
     *
     * @return the options {@code -r}, {@code -m}, {@code -E} and {@code -H} with their values
     */
    public String[] getOptions() throws Exception {
        int epochs = this.m_V ? -getE() : getE();
        return new String[] {
            "-r", String.valueOf(this.LEARNING_RATE),
            "-m", String.valueOf(this.MOMENTUM),
            "-E", String.valueOf(epochs),
            "-H", String.valueOf(getH())
        };
    }

    /**
     * Computes hidden-unit probabilities for a single visible vector given without a bias entry.
     * The method adds a bias entry, multiplies by {@code W}, applies {@code Mat.sigma}
     * (presumably the logistic function), and strips the bias entry from the result.
     */
    public double[] prob_z(double[] x) {
        return Mat.removeBias(Mat.sigma(new Matrix(Mat.addBias(x), 1).times(this.W).getArray()[0]));
    }

    /**
     * Computes hidden-unit probabilities for every row of {@code X}, where the rows are given
     * without a bias column. The bias column is added for the computation and removed from the
     * result.
     */
    public double[][] prob_Z(double[][] X) {
        return Mat.removeBias(prob_Z(new Matrix(Mat.addBias(X))).getArray());
    }

    /**
     * Computes hidden-unit probabilities for data that already include the bias column. Column 0
     * of the result is set to 1.0 so that it serves as the bias for a following multiplication.
     */
    public Matrix prob_Z(Matrix X) {
        Matrix hidden = Mat.sigma(X.times(this.W));
        Mat.fillCol(hidden.getArray(), 0, 1.0d);
        return hidden;
    }

    /**
     * Performs a deterministic up-pass. It computes hidden probabilities for {@code X} (rows
     * without bias) and binarises them at 0.5 via {@code Mat.threshold}.
     */
    public double[][] propUp(double[][] X) {
        return Mat.threshold(prob_Z(X), 0.5d);
    }

    /** Samples hidden states from {@link #prob_Z(Matrix)} using {@code Mat.sample} and {@code m_R}. */
    public Matrix sample_Z(Matrix X) {
        return Mat.sample(prob_Z(X), this.m_R);
    }

    /** Samples a hidden vector from {@link #prob_z(double[])} using {@code Mat.sample} and {@code m_R}. */
    public double[] sample_z(double[] x) {
        return Mat.sample(prob_z(x), this.m_R);
    }

    /** Samples a visible vector from {@link #prob_x(double[])} using {@code Mat.sample} and {@code m_R}. */
    public double[] sample_x(double[] z) {
        return Mat.sample(prob_x(z), this.m_R);
    }

    /** Samples visible states from {@link #prob_X(Matrix)} using {@code Mat.sample} and {@code m_R}. */
    public Matrix sample_X(Matrix Z) {
        return Mat.sample(prob_X(Z), this.m_R);
    }

    /**
     * Computes visible-unit probabilities for a single hidden vector given without a bias entry.
     * The method adds a bias entry, multiplies by {@code W} transposed, applies {@code Mat.sigma},
     * and strips the bias entry from the result.
     */
    public double[] prob_x(double[] z) {
        return Mat.removeBias(Mat.sigma(new Matrix(Mat.addBias(z), 1).times(this.W.transpose()).getArray()[0]));
    }

    /**
     * Computes visible-unit probabilities for hidden data that already include the bias column.
     * Column 0 of the result is set to 1.0 (the visible bias column).
     */
    public Matrix prob_X(Matrix Z) {
        Matrix visible = new Matrix(Mat.sigma(Z.times(this.W.transpose()).getArray()));
        Mat.fillCol(visible.getArray(), 0, 1.0d);
        return visible;
    }

    /**
     * Creates an initial weight matrix of size {@code (visible + 1) x (hidden + 1)}. The entries
     * are drawn from {@code Mat.randn} (presumably standard normal) and scaled by 0.2. Bias row 0
     * and bias column 0 start at zero.
     *
     * @param visible number of visible units, excluding bias
     * @param hidden number of hidden units, excluding bias
     * @param random source of randomness
     * @return the new weight matrix
     */
    public static Matrix makeW(int visible, int hidden, Random random) {
        double[][] w = Mat.multiply(Mat.randn(visible + 1, hidden + 1, random), 0.2d);
        Mat.fillRow(w, 0, 0.0d);
        Mat.fillCol(w, 0, 0.0d);
        return new Matrix(w);
    }

    /** Same as {@link #makeW(int, int, Random)}, using this RBM's random source {@code m_R}. */
    protected Matrix makeW(int visible, int hidden) {
        return makeW(visible, hidden, this.m_R);
    }

    /**
     * Initialises the weights for data whose rows are given without a bias column.
     *
     * @throws IllegalArgumentException if {@code X} has no rows, so the input size is unknown
     */
    private void initWeights(double[][] X) {
        if (X.length == 0) {
            throw new IllegalArgumentException("Cannot infer the number of visible units from an empty data set");
        }
        initWeights(X[0].length, this.m_H);
    }

    /** Creates fresh weights and resets the momentum buffer {@code dW_} to a zero matrix of the same size. */
    private void initWeights(int visible, int hidden) {
        this.W = makeW(visible, hidden);
        this.dW_ = new Matrix(this.W.getRowDimension(), this.W.getColumnDimension());
    }

    /** Initialises the weights for {@code visible} input units and {@code m_H} hidden units. */
    public void initWeights(int visible) {
        initWeights(visible, this.m_H);
    }

    /**
     * Performs one gradient step on a batch that already includes the bias column. The step is
     * computed as {@code (epoch(X) - COST * W) * LEARNING_RATE} and applied together with the
     * momentum term.
     */
    public void update(Matrix X) {
        applyStep(epoch(X).minusEquals(this.W.times(this.COST)).timesEquals(this.LEARNING_RATE));
    }

    /**
     * Performs the same step as {@link #update(Matrix)}, with the step additionally multiplied by
     * {@code scale}. The scale is used to weight mini-batches.
     */
    public void update(Matrix X, double scale) {
        applyStep(epoch(X).minusEquals(this.W.times(this.COST)).timesEquals(this.LEARNING_RATE).timesEquals(scale));
    }

    /**
     * Adds {@code step + MOMENTUM * dW_} to {@code W} and then stores the raw {@code step}
     * (without the momentum term) as {@code dW_} for the next update.
     */
    private void applyStep(Matrix step) {
        this.W.plusEquals(step.plus(this.dW_.timesEquals(this.MOMENTUM)));
        this.dW_ = step;
    }

    /** Performs one gradient step on rows given without a bias column. */
    public void update(double[][] X) {
        update(new Matrix(Mat.addBias(X)));
    }

    /** Performs one gradient step on a single example given without a bias entry. */
    public void update(double[] x) {
        update(new double[][] {x});
    }

    /** Performs one gradient step, scaled by {@code scale}, on a single example given without a bias entry. */
    public void update(double[] x, double scale) {
        update(new Matrix(Mat.addBias(new double[][] {x})), scale);
    }

    /**
     * Initialises the weights and trains on all of {@code X} (rows without bias) for up to
     * {@code m_E} full-batch epochs.
     *
     * <p>When early stopping is enabled (see {@link #setE(int)}), the reconstruction error is
     * measured before each update. Training stops as soon as the error increases.
     *
     * @param X training data, one example per row, without a bias column
     * @return the last measured reconstruction error, or {@link Double#MAX_VALUE} when early
     *     stopping is disabled
     */
    public double train(double[][] X) throws Exception {
        initWeights(X);
        Matrix data = new Matrix(Mat.addBias(X));
        double error = Double.MAX_VALUE;
        for (int e = 0; e < this.m_E; e++) {
            if (this.m_V) {
                double current = calculateError(data);
                if (error < current) {
                    System.out.println("broken out @" + e);
                    break;
                }
                error = current;
            }
            update(data);
        }
        return error;
    }

    /**
     * Initialises the weights and trains with consecutive mini-batches of {@code batchSize} rows.
     * Each of the {@code m_E} epochs visits every batch in order. Each update is scaled by
     * {@code 1 / numberOfBatches}.
     *
     * @param X training data, one example per row, without a bias column
     * @param batchSize rows per mini-batch; if it equals the number of rows, this delegates to
     *     {@link #train(double[][])}
     * @return the full-batch result in the delegating case, otherwise always 1.0
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public double train(double[][] X, int batchSize) throws Exception {
        requirePositiveBatchSize(batchSize);
        if (batchSize == X.length) {
            // Pass the raw data: train(double[][]) adds the bias column and initialises weights itself.
            return train(X);
        }
        initWeights(X);
        Matrix[] batches = toBatches(Mat.addBias(X), batchSize);
        double scale = 1.0d / batches.length;
        for (int e = 0; e < this.m_E; e++) {
            for (Matrix batch : batches) {
                update(batch, scale);
            }
        }
        return 1.0d;
    }

    /**
     * Initialises the weights and trains with mini-batches of {@code batchSize} rows. In each of
     * the {@code m_E} epochs, as many updates are made as there are batches. Each update uses a
     * batch chosen uniformly at random (with replacement) from {@code random}.
     *
     * @param X training data, one example per row, without a bias column
     * @param batchSize rows per mini-batch
     * @param random source used to pick batches
     * @return always 1.0
     * @throws IllegalArgumentException if {@code batchSize} is not positive
     */
    public double train(double[][] X, int batchSize, Random random) throws Exception {
        requirePositiveBatchSize(batchSize);
        initWeights(X);
        Matrix[] batches = toBatches(Mat.addBias(X), batchSize);
        for (int e = 0; e < this.m_E; e++) {
            for (int k = 0; k < batches.length; k++) {
                update(batches[random.nextInt(batches.length)]);
            }
        }
        return 1.0d;
    }

    /** Rejects non-positive batch sizes, which would otherwise never terminate batch splitting. */
    private static void requirePositiveBatchSize(int batchSize) {
        if (batchSize <= 0) {
            throw new IllegalArgumentException("batchSize must be positive, got " + batchSize);
        }
    }

    /**
     * Splits the rows of {@code withBias} into consecutive batches of at most {@code batchSize}
     * rows. The last batch may be smaller.
     */
    private static Matrix[] toBatches(double[][] withBias, int batchSize) {
        int rows = withBias.length;
        int count = rows / batchSize + (rows % batchSize == 0 ? 0 : 1);
        Matrix[] batches = new Matrix[count];
        for (int b = 0; b < count; b++) {
            int from = b * batchSize;
            // Math.min on the remaining length avoids int overflow of from + batchSize.
            int to = from + Math.min(batchSize, rows - from);
            batches[b] = new Matrix(Arrays.copyOfRange(withBias, from, to));
        }
        return batches;
    }

    /**
     * Returns the error computed by {@code Mat.meanSquaredError} between {@code X} (which already
     * includes the bias column) and its one-step reconstruction {@code prob_X(prob_Z(X))}.
     */
    public double calculateError(Matrix X) {
        return Mat.meanSquaredError(X.getArray(), prob_X(prob_Z(X)).getArray());
    }

    /**
     * Computes a CD-1 style gradient estimate using probabilities rather than samples:
     * {@code (X^T P(Z|X) - R^T P(Z|R)) / N}, where {@code R = prob_X(prob_Z(X))} and {@code N} is
     * the number of rows of {@code X}. {@code X} must already include the bias column.
     */
    public Matrix epoch(Matrix X) {
        int n = X.getRowDimension();
        Matrix hidden = prob_Z(X);
        Matrix positive = X.transpose().times(hidden);
        Matrix reconstruction = prob_X(hidden);
        Matrix negative = reconstruction.transpose().times(prob_Z(reconstruction));
        return positive.minusEquals(negative).times(1.0d / n);
    }

    /**
     * Sampling variant of {@link #epoch(Matrix)}. It uses sampled hidden states for the positive
     * phase and a sampled reconstruction for the negative phase. It also prints the error between
     * {@code X} and the sampled reconstruction to standard output.
     */
    public Matrix sample_epoch(Matrix X) {
        int n = X.getRowDimension();
        Matrix hidden = sample_Z(X);
        Matrix positive = X.transpose().times(hidden);
        Matrix reconstruction = sample_X(hidden);
        Matrix negative = reconstruction.transpose().times(prob_Z(reconstruction));
        System.out.println("" + Mat.meanSquaredError(X.getArray(), reconstruction.getArray()));
        return positive.minusEquals(negative).times(1.0d / n);
    }

    /** Sets the number of hidden units used by the next weight initialisation. */
    public void setH(int hidden) {
        this.m_H = hidden;
    }

    /** Returns the number of hidden units. */
    public int getH() {
        return this.m_H;
    }

    /**
     * Sets the number of epochs. A non-negative value disables early stopping. A negative value
     * uses its magnitude as the epoch count and enables early stopping in
     * {@link #train(double[][])}.
     */
    public void setE(int epochs) {
        this.m_V = epochs < 0;
        this.m_E = epochs < 0 ? -epochs : epochs;
    }

    /** Returns the (non-negative) number of epochs. */
    public int getE() {
        return this.m_E;
    }

    /** Sets the learning rate and recomputes the weight-decay cost derived from it. */
    public void setLearningRate(double learningRate) {
        this.LEARNING_RATE = learningRate;
        this.COST = COST_FACTOR * this.LEARNING_RATE;
    }

    /** Returns the learning rate. */
    public double getLearningRate() {
        return this.LEARNING_RATE;
    }

    /** Sets the momentum multiplier. */
    public void setMomentum(double momentum) {
        this.MOMENTUM = momentum;
    }

    /** Returns the momentum multiplier. */
    public double getMomentum() {
        return this.MOMENTUM;
    }

    /** Replaces the random source with a new {@link Random} seeded with {@code seed}. */
    public void setSeed(int seed) {
        this.m_R = new Random(seed);
    }

    /** Returns a one-element array holding the live (not copied) weight matrix. */
    public Matrix[] getWs() {
        return new Matrix[] {this.W};
    }

    /** Returns the live (not copied) weight matrix, or {@code null} before initialisation. */
    public Matrix getW() {
        return this.W;
    }

    /** Renders the weight matrix via {@code Mat.toString}. */
    @Override
    public String toString() {
        return Mat.toString(getW());
    }

    /** No-op entry point. */
    public static void main(String[] args) throws Exception {
    }
}
