package meka.classifiers.multilabel.meta;

import Jama.Matrix;
import meka.classifiers.multilabel.BR;
import meka.classifiers.multilabel.NN.AbstractDeepNeuralNet;
import meka.classifiers.multilabel.ProblemTransformationMethod;
import meka.core.M;
import meka.core.MLUtils;
import rbms.DBM;
import rbms.RBM;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.TechnicalInformation;
import weka.core.TechnicalInformationHandler;

/* loaded from: input_file:meka/classifiers/multilabel/meta/DeepML.class */
public class DeepML extends AbstractDeepNeuralNet implements TechnicalInformationHandler {
    private static final long serialVersionUID = 3388606529764305098L;

    /** The RBM (when {@code m_N == 1}) or DBM trained by {@link #buildClassifier}; {@code null} until then. */
    protected RBM dbm = null;

    /** Wall-clock time in milliseconds spent in {@code dbm.train(...)} during the last build. */
    protected long rbm_time = 0;

    /**
     * Creates the Boltzmann machine used to construct the new feature space.
     *
     * <p>A single {@link RBM} is created when {@code m_N == 1}; otherwise a {@link DBM} is created.
     * Both are configured from this classifier's current {@link #getOptions()}.
     * NOTE(review): {@code m_N} is inherited from {@link AbstractDeepNeuralNet} and presumably
     * holds the number of hidden layers; confirm there.
     *
     * @param i the number of input features computed by {@link #buildClassifier}; not used by
     *          this implementation, but available to subclasses that override this factory
     * @return a new, untrained machine
     * @throws Exception propagated from the {@link RBM} / {@link DBM} constructors
     */
    protected RBM createDBM(int i) throws Exception {
        return this.m_N == 1 ? new RBM(getOptions()) : new DBM(getOptions());
    }

    /** Creates a DeepML that uses {@link BR} as the base multi-label classifier. */
    public DeepML() {
        this.m_Classifier = new BR();
    }

    @Override
    protected String defaultClassifierString() {
        return "meka.classifiers.multilabel.BR";
    }

    /**
     * Trains the Boltzmann machine on the input features. It then replaces those features with the
     * machine's hidden-unit probabilities and trains the base classifier on the transformed data.
     *
     * @param instances the training data; its class index is used as the label count, following
     *                  MEKA's labels-first attribute layout
     * @throws Exception if the data fails the capability test, or if building the machine or the
     *                   base classifier fails
     */
    @Override
    public void buildClassifier(Instances instances) throws Exception {
        testCapabilities(instances);
        int numLabels = instances.classIndex();
        int numFeatures = instances.numAttributes() - numLabels;
        // Presumably the non-label (feature) columns of every instance; see MLUtils.getXfromD.
        double[][] features = MLUtils.getXfromD(instances);

        this.dbm = createDBM(numFeatures);
        this.dbm.setSeed(this.m_Seed);
        this.dbm.setE(this.m_E);

        long start = System.currentTimeMillis();
        // NOTE(review): the label count is passed as train()'s second argument, as in the original
        // code; confirm what that parameter means in rbms.RBM before changing it.
        this.dbm.train(features, numLabels);
        this.rbm_time = System.currentTimeMillis() - start;

        // The new feature space: hidden-unit probabilities for every training instance.
        double[][] prob_Z = this.dbm.prob_Z(features);
        if (getDebug()) {
            System.out.println("X = \n" + M.toString(features));
            // Print every layer's weights. Printing only the first would hide the upper layers
            // of a stacked DBM.
            for (Matrix w : this.dbm.getWs()) {
                System.out.println("W = \n" + M.toString(w.getArray()));
            }
            System.out.println("Y = \n" + M.toString(MLUtils.getYfromD(instances), 0));
            System.out.println("Z = \n" + M.toString(M.threshold(prob_Z, 0.5d), 0));
        }

        this.m_InstancesTemplate = new Instances(MLUtils.replaceZasAttributes(instances, prob_Z, numLabels));
        this.m_Classifier.buildClassifier(this.m_InstancesTemplate);
    }

    /**
     * Maps {@code instance} into the learned feature space and returns the base classifier's
     * distribution for it.
     *
     * @param instance the instance to classify, with the same attribute layout as the training data
     * @return the base classifier's output for the transformed instance
     * @throws IllegalStateException if {@link #buildClassifier} has not been called yet
     * @throws Exception propagated from the transformation or the base classifier
     */
    @Override
    public double[] distributionForInstance(Instance instance) throws Exception {
        if (this.dbm == null || this.m_InstancesTemplate == null) {
            throw new IllegalStateException("DeepML has not been built yet: call buildClassifier() first");
        }
        int numLabels = instance.classIndex();
        double[] prob_z = this.dbm.prob_z(MLUtils.getxfromInstance(instance));
        // Copy a transformed training row so that the new instance has the transformed attribute
        // layout. Then write this instance's hidden-unit probabilities into it.
        Instance transformed = (Instance) this.m_InstancesTemplate.firstInstance().copy();
        MLUtils.setValues(transformed, prob_z, numLabels);
        transformed.setDataset(this.m_InstancesTemplate);
        return this.m_Classifier.distributionForInstance(transformed);
    }

    /** Returns the superclass description extended with the RBM/DBM training time in milliseconds. */
    @Override
    public String toString() {
        return super.toString() + ", RBM-Build_Time=" + this.rbm_time;
    }

    @Override
    public String globalInfo() {
        return "Create a new feature space using a stack of RBMs, then employ a multi-label classifier on top. For more information see:\n" + getTechnicalInformation().toString();
    }

    /**
     * Returns the bibliographic reference for this method.
     *
     * @return the IDA 2014 paper "A Deep Interpretation of Classifier Chains"
     */
    @Override
    public TechnicalInformation getTechnicalInformation() {
        TechnicalInformation technicalInformation = new TechnicalInformation(TechnicalInformation.Type.INPROCEEDINGS);
        technicalInformation.setValue(TechnicalInformation.Field.AUTHOR, "Jesse Read and Jaakko Hollm\u00e9n");
        technicalInformation.setValue(TechnicalInformation.Field.TITLE, "A Deep Interpretation of Classifier Chains");
        technicalInformation.setValue(TechnicalInformation.Field.BOOKTITLE, "Advances in Intelligent Data Analysis {XIII} - 13th International Symposium, {IDA} 2014");
        technicalInformation.setValue(TechnicalInformation.Field.PAGES, "251--262");
        technicalInformation.setValue(TechnicalInformation.Field.YEAR, "2014");
        return technicalInformation;
    }

    /**
     * Command-line entry point; runs MEKA's standard evaluation with a new DeepML.
     *
     * @param strArr MEKA/Weka command-line options
     * @throws Exception propagated from the evaluation
     */
    public static void main(String[] strArr) throws Exception {
        ProblemTransformationMethod.evaluation(new DeepML(), strArr);
    }
}
