package rbms;

import Jama.Matrix;
import java.util.Arrays;

/* loaded from: input_file:rbms/DBM.class */
/**
 * A stack of {@link RBM} layers trained greedily, one layer at a time.
 *
 * <p>Layer {@code k} is trained on the output of {@code prob_Z} of layer {@code k - 1}; layer 0 is
 * trained on the raw input. Inference ({@link #prob_z(double[])}, {@link #prob_Z(double[][])})
 * pushes the input through every trained layer in order. Configure the layer count and sizes with
 * one of the {@code setH} overloads before calling {@code train}.
 */
public class DBM extends RBM {
    /** Trained layers, bottom (input side) first; {@code null} until {@code train} completes. */
    protected RBM[] rbm = null;

    /**
     * Per-layer value passed to {@code RBM.setH(int)} (presumably the hidden-unit count — confirm
     * against RBM), bottom first. Its length is the number of layers {@code train} builds.
     */
    protected int[] h = null;

    /**
     * Creates a DBM configured from command-line style options.
     *
     * @param strArr options forwarded to {@code RBM.setOptions}
     * @throws Exception propagated from {@code setOptions}
     */
    public DBM(String[] strArr) throws Exception {
        super.setOptions(strArr);
    }

    /**
     * Returns the internal layer array (not a copy), bottom layer first.
     *
     * @return the trained layers, or {@code null} if {@code train} has not completed
     */
    public RBM[] getRBMs() {
        return this.rbm;
    }

    /**
     * Propagates one input vector through every trained layer.
     *
     * @param dArr input vector for the bottom layer
     * @return the top layer's {@code prob_z} output, or {@code null} if the DBM is untrained
     */
    @Override // rbms.RBM
    public double[] prob_z(double[] dArr) {
        if (this.rbm == null) {
            return null;
        }
        double[] z = dArr;
        // Walk the trained layers rather than h: h may have been changed via setH after training.
        for (RBM layer : this.rbm) {
            z = layer.prob_z(z);
        }
        return z;
    }

    /**
     * Propagates a batch of input rows through every trained layer.
     *
     * @param dArr input rows for the bottom layer
     * @return the top layer's {@code prob_Z} output, or {@code null} if the DBM is untrained
     */
    @Override // rbms.RBM
    public double[][] prob_Z(double[][] dArr) {
        if (this.rbm == null) {
            return (double[][]) null;
        }
        double[][] z = dArr;
        // Walk the trained layers rather than h: h may have been changed via setH after training.
        for (RBM layer : this.rbm) {
            z = layer.prob_Z(z);
        }
        return z;
    }

    /**
     * Sets the per-layer sizes explicitly; the array length is the number of layers.
     *
     * @param iArr per-layer sizes, bottom first; copied, so later changes by the caller have no
     *     effect. May be {@code null} to clear the configuration.
     */
    public void setH(int[] iArr) {
        this.h = (iArr == null) ? null : iArr.clone();
    }

    /**
     * Configures {@code i3} layers: the first {@code i3 - 1} with size {@code i} and the last with
     * size {@code i2}.
     *
     * @param i size of every layer except the last
     * @param i2 size of the last (top) layer
     * @param i3 number of layers
     * @throws IllegalArgumentException if {@code i3 < 1}
     */
    public void setH(int i, int i2, int i3) {
        if (i3 < 1) {
            throw new IllegalArgumentException("number of layers must be >= 1, got " + i3);
        }
        int[] sizes = new int[i3];
        Arrays.fill(sizes, i);
        sizes[i3 - 1] = i2;
        this.h = sizes;
    }

    /**
     * Configures {@code i2} layers, all with size {@code i}.
     *
     * @param i size of every layer
     * @param i2 number of layers
     * @throws IllegalArgumentException if {@code i2 < 0}
     */
    public void setH(int i, int i2) {
        if (i2 < 0) {
            throw new IllegalArgumentException("number of layers must be >= 0, got " + i2);
        }
        int[] sizes = new int[i2];
        Arrays.fill(sizes, i);
        this.h = sizes;
    }

    /**
     * Configures two layers, both with size {@code i}.
     *
     * @param i size of each of the two layers
     */
    @Override // rbms.RBM
    public void setH(int i) {
        setH(i, 2);
    }

    /**
     * Returns each trained layer's {@code getW()} matrix, bottom layer first.
     *
     * @return one matrix per layer
     * @throws IllegalStateException if the DBM has not been trained
     */
    @Override // rbms.RBM
    public Matrix[] getWs() {
        RBM[] layers = trainedLayers();
        Matrix[] ws = new Matrix[layers.length];
        for (int k = 0; k < layers.length; k++) {
            ws[k] = layers[k].getW();
        }
        return ws;
    }

    /**
     * Trains every layer using {@code RBM.train(double[][])}; equivalent to {@code train(dArr, 0)}.
     *
     * @param dArr training rows for the bottom layer
     * @return always {@code 1.0}
     * @throws Exception propagated from layer construction or training
     */
    @Override // rbms.RBM
    public double train(double[][] dArr) throws Exception {
        return train(dArr, 0);
    }

    /**
     * Greedily trains a fresh {@link RBM} per entry of {@code h}. Each layer is trained on the
     * {@code prob_Z} output of the layer below it.
     *
     * @param dArr training rows for the bottom layer
     * @param i when {@code 0}, each layer is trained with {@code RBM.train(double[][])}; otherwise
     *     it is forwarded to {@code RBM.train(double[][], int)} (presumably a batch size — confirm)
     * @return always {@code 1.0}; per-layer training results are discarded
     * @throws IllegalStateException if no layer sizes have been set via {@code setH}
     * @throws Exception propagated from layer construction or training
     */
    @Override // rbms.RBM
    public double train(double[][] dArr, int i) throws Exception {
        if (this.h == null) {
            throw new IllegalStateException("layer sizes not set; call setH before train");
        }
        RBM[] layers = new RBM[this.h.length];
        double[][] input = dArr;
        for (int k = 0; k < layers.length; k++) {
            // getOptions() is fetched per layer on purpose: the RBM constructor may consume the array.
            RBM layer = new RBM(getOptions());
            layer.setH(this.h[k]);
            if (i == 0) {
                layer.train(input);
            } else {
                layer.train(input, i);
            }
            layers[k] = layer;
            input = layer.prob_Z(input);
        }
        // Publish only once every layer is trained, so a failure part-way never leaves a
        // half-populated layer array behind.
        this.rbm = layers;
        return 1.0d;
    }

    /**
     * Calls {@code RBM.update(Matrix)} on each layer in turn. Each layer is given the
     * {@code prob_Z} output of the layer below it.
     *
     * @param matrix input for the bottom layer
     * @throws IllegalStateException if the DBM has not been trained
     */
    @Override // rbms.RBM
    public void update(Matrix matrix) {
        RBM[] layers = trainedLayers();
        Matrix input = matrix;
        for (RBM layer : layers) {
            layer.update(input);
            try {
                input = layer.prob_Z(input);
            } catch (Exception e) {
                System.err.println("AHH!!");
                e.printStackTrace();
                // Feeding the stale input upward would hand the next layer data of the wrong
                // shape, so abandon the remaining layers instead.
                return;
            }
        }
    }

    /**
     * Calls {@code RBM.update(Matrix, double)} on each layer in turn. Each layer is given the
     * {@code prob_Z} output of the layer below it.
     *
     * @param matrix input for the bottom layer
     * @param d forwarded unchanged to every layer's {@code update}
     * @throws IllegalStateException if the DBM has not been trained
     */
    @Override // rbms.RBM
    public void update(Matrix matrix, double d) {
        RBM[] layers = trainedLayers();
        Matrix input = matrix;
        for (RBM layer : layers) {
            layer.update(input, d);
            try {
                input = layer.prob_Z(input);
            } catch (Exception e) {
                System.err.println("AHH!!");
                e.printStackTrace();
                // Feeding the stale input upward would hand the next layer data of the wrong
                // shape, so abandon the remaining layers instead.
                return;
            }
        }
    }

    /**
     * Calls {@code RBM.update(double[][])} on each layer in turn. Each layer is given the
     * {@code prob_Z} output of the layer below it.
     *
     * @param dArr input rows for the bottom layer
     * @throws IllegalStateException if the DBM has not been trained
     */
    @Override // rbms.RBM
    public void update(double[][] dArr) {
        RBM[] layers = trainedLayers();
        double[][] input = dArr;
        for (RBM layer : layers) {
            layer.update(input);
            try {
                input = layer.prob_Z(input);
            } catch (Exception e) {
                System.err.println("AHH!!");
                e.printStackTrace();
                // Feeding the stale input upward would hand the next layer data of the wrong
                // shape, so abandon the remaining layers instead.
                return;
            }
        }
    }

    /** Returns the trained layer array, or throws if {@code train} has not completed. */
    private RBM[] trainedLayers() {
        if (this.rbm == null) {
            throw new IllegalStateException("DBM has not been trained");
        }
        return this.rbm;
    }
}
