package meka.classifiers.multilabel;

import Jama.Matrix;
import java.util.Random;
import meka.classifiers.multilabel.NN.AbstractDeepNeuralNet;
import meka.core.M;
import meka.core.MLUtils;
import rbms.DBM;
import rbms.RBM;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;

/* loaded from: input_file:meka/classifiers/multilabel/DBPNN.class */
/**
 * Deep Back-Propagation Neural Network.
 *
 * <p>First pre-trains a {@link DBM} on the input features. It then strips the bias terms from the
 * learned weight matrices, re-initialises the final weight matrix, and fine-tunes a {@link BPNN}
 * base classifier starting from those weights. Predictions are delegated to the fine-tuned base
 * classifier.
 */
public class DBPNN extends AbstractDeepNeuralNet implements TechnicalInformationHandler {
    private static final long serialVersionUID = 5007534249445210725L;

    /** The RBM stack pre-trained on the input features; assigned by {@link #buildClassifier}. */
    protected RBM dbm = null;

    /** Wall-clock time, in milliseconds, spent in {@code dbm.train(...)} during the last build. */
    protected long rbm_time = 0;

    /**
     * Builds the classifier in two phases.
     *
     * <p>Phase one is unsupervised pre-training of a {@link DBM} on the feature matrix. Phase two
     * is supervised fine-tuning of a {@link BPNN} initialised with the pre-trained weights (biases
     * removed, final layer re-randomised with seed 1).
     *
     * @param instances the multi-label training data
     * @throws Exception if the data fails the capability test or training fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        testCapabilities(instances);
        // NOTE(review): the class index is handed to BPNN.presetWeights as its label count
        // (MEKA keeps the labels as the first attributes) — confirm against BPNN.
        int numLabels = instances.classIndex();
        double[][] x = MLUtils.getXfromD(instances);
        double[][] y = MLUtils.getYfromD(instances);

        if (getDebug()) {
            System.out.println("Build RBM(s) ... ");
        }
        this.dbm = new DBM(getOptions());
        this.dbm.setE(this.m_E);
        ((DBM) this.dbm).setH(this.m_H, this.m_N);

        long start = System.currentTimeMillis();
        this.dbm.train(x, this.m_H);
        this.rbm_time = System.currentTimeMillis() - start;

        if (getDebug()) {
            System.out.println("X = \n" + M.toString(x));
            for (Matrix w : this.dbm.getWs()) {
                System.out.println("W = \n" + M.toString(w.getArray()));
            }
            System.out.println("Y = \n" + M.toString(y));
        }

        // trimBiases(Matrix[]) overwrites array slots in place; trim a shallow copy so the
        // DBM's own weight array (if getWs() exposes it) is left untouched.
        Matrix[] weights = trimBiases(this.dbm.getWs().clone());

        BPNN bpnn = ensureBPNN();

        // Replace the final weight matrix with a freshly initialised one (fixed seed for
        // reproducibility) before supervised fine-tuning.
        int last = weights.length - 1;
        weights[last] = RBM.makeW(
                weights[last].getRowDimension() - 1,
                weights[last].getColumnDimension() - 1,
                new Random(1L));
        bpnn.presetWeights(weights, numLabels);
        bpnn.train(x, y);

        if (getDebug()) {
            // Print every layer rather than assuming exactly two (the old code indexed [0] and [1]
            // and failed when the network had a single weight matrix).
            for (Matrix w : weights) {
                System.out.println("W = \n" + M.toString(w.getArray()));
            }
            System.out.println("Y = \n" + M.toString(M.threshold(bpnn.popY(x), 0.5d)));
        }
    }

    /**
     * Returns the base classifier as a {@link BPNN}. If the configured base classifier is not a
     * BPNN, it is replaced by a default-constructed one and a warning is printed.
     *
     * @return the base classifier, guaranteed to be a BPNN
     */
    private BPNN ensureBPNN() {
        if (this.m_Classifier instanceof BPNN) {
            if (getDebug()) {
                System.out.println("You have chosen to use BPNN (good!)");
            }
        } else {
            System.err.println("[WARNING] Was expecting BPNN as the base classifier (will set it now, with default parameters) ...");
            this.m_Classifier = new BPNN();
        }
        return (BPNN) this.m_Classifier;
    }

    /**
     * Delegates prediction to the fine-tuned base classifier.
     *
     * @param instance the instance to classify
     * @return the base classifier's output for {@code instance}
     * @throws Exception if the base classifier fails
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        return this.m_Classifier.distributionForInstance(instance);
    }

    /**
     * Returns a new matrix built from {@code M.removeBias} applied to {@code matrix}'s backing
     * array.
     *
     * <p>NOTE(review): presumably this drops the bias row/column that RBM weights carry — confirm
     * against {@code M.removeBias}.
     *
     * @param matrix a weight matrix
     * @return a new matrix without the bias terms
     */
    protected static Matrix trimBiases(Matrix matrix) {
        return new Matrix(M.removeBias(matrix.getArray()));
    }

    /**
     * Applies {@link #trimBiases(Matrix)} to every element, overwriting the slots of
     * {@code matrixArr} in place.
     *
     * @param matrixArr weight matrices; this array is modified
     * @return the same array instance, now holding the trimmed matrices
     */
    protected static Matrix[] trimBiases(Matrix[] matrixArr) {
        for (int i = 0; i < matrixArr.length; i++) {
            matrixArr[i] = trimBiases(matrixArr[i]);
        }
        return matrixArr;
    }

    /**
     * Returns a short description of the method followed by its bibliographic reference.
     *
     * @return the global info string
     */
    @Override
    public String globalInfo() {
        return "A Deep Back-Propagation Neural Network. For more information see:\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the reference for this method: Hinton &amp; Salakhutdinov, Science 313(5786), 2006.
     *
     * @return the technical information
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.ARTICLE);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Geoffrey Hinton and Ruslan Salakhutdinov");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "Reducing the Dimensionality of Data with Neural Networks");
        technicalInformation.setValue(TechnicalInformation.Field.JOURNAL, "Science");
        technicalInformation.setValue(TechnicalInformation.Field.VOLUME, "313");
        technicalInformation.setValue(TechnicalInformation.Field.NUMBER, "5786");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "504-507");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2006");
        return technicalInformation;
    }

    /**
     * Command-line entry point; runs the standard MEKA evaluation with a fresh {@code DBPNN}.
     *
     * @param strArr command-line options
     * @throws Exception if evaluation fails
     */
    public static void main(String[] strArr) throws Exception {
        ProblemTransformationMethod.evaluation(new DBPNN(), strArr);
    }
}
