package ml.dmlc.mxnet.spark;

import java.util.ArrayList;
import java.util.List;
import ml.dmlc.mxnet.Context;
import ml.dmlc.mxnet.Shape;
import ml.dmlc.mxnet.Symbol;
import ml.dmlc.mxnet.spark.utils.Network$;
import org.apache.spark.SparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.rdd.RDD;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Predef$;
import scala.Serializable;
import scala.reflect.ClassTag$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* compiled from: MXNet.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005]b\u0001B\u0001\u0003\u0001-\u0011Q!\u0014-OKRT!a\u0001\u0003\u0002\u000bM\u0004\u0018M]6\u000b\u0005\u00151\u0011!B7y]\u0016$(BA\u0004\t\u0003\u0011!W\u000e\\2\u000b\u0003%\t!!\u001c7\u0004\u0001M\u0019\u0001\u0001\u0004\n\u0011\u00055\u0001R\"\u0001\b\u000b\u0003=\tQa]2bY\u0006L!!\u0005\b\u0003\r\u0005s\u0017PU3g!\ti1#\u0003\u0002\u0015\u001d\ta1+\u001a:jC2L'0\u00192mK\")a\u0003\u0001C\u0001/\u00051A(\u001b8jiz\"\u0012\u0001\u0007\t\u00033\u0001i\u0011A\u0001\u0005\b7\u0001\u0011\r\u0011\"\u0003\u001d\u0003\u0019awnZ4feV\tQ\u0004\u0005\u0002\u001fG5\tqD\u0003\u0002!C\u0005)1\u000f\u001c45U*\t!%A\u0002pe\u001eL!\u0001J\u0010\u0003\r1{wmZ3s\u0011\u00191\u0003\u0001)A\u0005;\u00059An\\4hKJ\u0004\u0003b\u0002\u0015\u0001\u0005\u0004%I!K\u0001\u0007a\u0006\u0014\u0018-\\:\u0016\u0003)\u0002\"!G\u0016\n\u00051\u0012!aC'Y\u001d\u0016$\b+\u0019:b[NDaA\f\u0001!\u0002\u0013Q\u0013a\u00029be\u0006l7\u000f\t\u0005\u0006a\u0001!\t!M\u0001\rg\u0016$()\u0019;dQNK'0\u001a\u000b\u0003eMj\u0011\u0001\u0001\u0005\u0006i=\u0002\r!N\u0001\nE\u0006$8\r[*ju\u0016\u0004\"!\u0004\u001c\n\u0005]r!aA%oi\")\u0011\b\u0001C\u0001u\u0005Y1/\u001a;Ok6,\u0005o\\2i)\t\u00114\bC\u0003=q\u0001\u0007Q'\u0001\u0005ok6,\u0005o\\2i\u0011\u0015q\u0004\u0001\"\u0001@\u00031\u0019X\r\u001e#j[\u0016t7/[8o)\t\u0011\u0004\tC\u0003B{\u0001\u0007!)A\u0005eS6,gn]5p]B\u00111\tR\u0007\u0002\t%\u0011Q\t\u0002\u0002\u0006'\"\f\u0007/\u001a\u0005\u0006\u000f\u0002!\t\u0001S\u0001\u000bg\u0016$h*\u001a;x_J\\GC\u0001\u001aJ\u0011\u0015Qe\t1\u0001L\u0003\u001dqW\r^<pe.\u0004\"a\u0011'\n\u00055#!AB*z[\n|G\u000eC\u0003P\u0001\u0011\u0005\u0001+\u0001\u0006tKR\u001cuN\u001c;fqR$\"AM)\t\u000bIs\u0005\u0019A*\u0002\u0007\r$\b\u0010E\u0002\u000e)ZK!!\u0016\b\u0003\u000b\u0005\u0013(/Y=\u0011\u0005\r;\u0016B\u0001-\u0005\u0005\u001d\u0019uN\u001c;fqRDQA\u0017\u0001\u0005\u0002m\u000bAb]3u\u001dVlwk\u001c:lKJ$\"A\r/\t\u000buK\u0006\u0019A\u001b\u0002\u00139,XnV8sW\u0016\u0014\b\"B0\u00
01\t\u0003\u0001\u0017\u0001D:fi:+XnU3sm\u0016\u0014HC\u0001\u001ab\u0011\u0015\u0011g\f1\u00016\u0003%qW/\\*feZ,'\u000fC\u0003e\u0001\u0011\u0005Q-A\u0006tKR$\u0015\r^1OC6,GC\u0001\u001ag\u0011\u001597\r1\u0001i\u0003\u0011q\u0017-\\3\u0011\u0005%dgBA\u0007k\u0013\tYg\"\u0001\u0004Qe\u0016$WMZ\u0005\u0003[:\u0014aa\u0015;sS:<'BA6\u000f\u0011\u0015\u0001\b\u0001\"\u0001r\u00031\u0019X\r\u001e'bE\u0016dg*Y7f)\t\u0011$\u000fC\u0003h_\u0002\u0007\u0001\u000eC\u0003u\u0001\u0011\u0005Q/\u0001\u0006tKR$\u0016.\\3pkR$\"A\r<\t\u000b]\u001c\b\u0019A\u001b\u0002\u000fQLW.Z8vi\")\u0011\u0010\u0001C\u0001u\u0006y1/\u001a;Fq\u0016\u001cW\u000f^8s\u0015\u0006\u00148\u000f\u0006\u00023w\")A\u0010\u001fa\u0001Q\u0006!!.\u0019:t\u0011\u0015q\b\u0001\"\u0001��\u0003\u001d\u0019X\r\u001e&bm\u0006$2AMA\u0001\u0011\u0019\t\u0019! a\u0001Q\u0006!!.\u0019<b\u0011\u001d\t9\u0001\u0001C\u0001\u0003\u0013\t1AZ5u)\u0011\tY!!\u0005\u0011\u0007e\ti!C\u0002\u0002\u0010\t\u0011!\"\u0014-OKRlu\u000eZ3m\u0011!\t\u0019\"!\u0002A\u0002\u0005U\u0011\u0001\u00023bi\u0006\u0004b!a\u0006\u0002$\u0005\u001dRBAA\r\u0015\u0011\tY\"!\b\u0002\u0007I$GMC\u0002\u0004\u0003?Q1!!\t\"\u0003\u0019\t\u0007/Y2iK&!\u0011QEA\r\u0005\r\u0011F\t\u0012\t\u0005\u0003S\t\u0019$\u0004\u0002\u0002,)!\u0011QFA\u0018\u0003)\u0011Xm\u001a:fgNLwN\u001c\u0006\u0005\u0003c\ti\"A\u0003nY2L'-\u0003\u0003\u00026\u0005-\"\u0001\u0004'bE\u0016dW\r\u001a)pS:$\b")
/* loaded from: input_file:ml/dmlc/mxnet/spark/MXNet.class */
public class MXNet implements Serializable {
    private final Logger ml$dmlc$mxnet$spark$MXNet$$logger = LoggerFactory.getLogger(MXNet.class);
    private final MXNetParams ml$dmlc$mxnet$spark$MXNet$$params = new MXNetParams();

    public Logger ml$dmlc$mxnet$spark$MXNet$$logger() {
        return this.ml$dmlc$mxnet$spark$MXNet$$logger;
    }

    public MXNetParams ml$dmlc$mxnet$spark$MXNet$$params() {
        return this.ml$dmlc$mxnet$spark$MXNet$$params;
    }

    public MXNet setBatchSize(int i) {
        ml$dmlc$mxnet$spark$MXNet$$params().batchSize_$eq(i);
        return this;
    }

    public MXNet setNumEpoch(int i) {
        ml$dmlc$mxnet$spark$MXNet$$params().numEpoch_$eq(i);
        return this;
    }

    public MXNet setDimension(Shape shape) {
        ml$dmlc$mxnet$spark$MXNet$$params().dimension_$eq(shape);
        return this;
    }

    public MXNet setNetwork(Symbol symbol) {
        ml$dmlc$mxnet$spark$MXNet$$params().setNetwork(symbol);
        return this;
    }

    public MXNet setContext(Context[] contextArr) {
        ml$dmlc$mxnet$spark$MXNet$$params().context_$eq(contextArr);
        return this;
    }

    public MXNet setNumWorker(int i) {
        ml$dmlc$mxnet$spark$MXNet$$params().numWorker_$eq(i);
        return this;
    }

    public MXNet setNumServer(int i) {
        ml$dmlc$mxnet$spark$MXNet$$params().numServer_$eq(i);
        return this;
    }

    public MXNet setDataName(String str) {
        ml$dmlc$mxnet$spark$MXNet$$params().dataName_$eq(str);
        return this;
    }

    public MXNet setLabelName(String str) {
        ml$dmlc$mxnet$spark$MXNet$$params().labelName_$eq(str);
        return this;
    }

    public MXNet setTimeout(int i) {
        ml$dmlc$mxnet$spark$MXNet$$params().timeout_$eq(i);
        return this;
    }

    public MXNet setExecutorJars(String str) {
        ml$dmlc$mxnet$spark$MXNet$$params().jars_$eq(str.split(",|:"));
        return this;
    }

    public MXNet setJava(String str) {
        ml$dmlc$mxnet$spark$MXNet$$params().javabin_$eq(str);
        return this;
    }

    public MXNetModel fit(RDD<LabeledPoint> rdd) {
        RDD<LabeledPoint> rdd2;
        SparkContext context = rdd.context();
        Predef$.MODULE$.refArrayOps(ml$dmlc$mxnet$spark$MXNet$$params().jars()).foreach(new MXNet$$anonfun$fit$1(this, context));
        if (ml$dmlc$mxnet$spark$MXNet$$params().numWorker() > rdd.partitions().length) {
            ml$dmlc$mxnet$spark$MXNet$$logger().info("repartitioning training set to {} partitions", BoxesRunTime.boxToInteger(ml$dmlc$mxnet$spark$MXNet$$params().numWorker()));
            int numWorker = ml$dmlc$mxnet$spark$MXNet$$params().numWorker();
            rdd2 = rdd.repartition(numWorker, rdd.repartition$default$2(numWorker));
        } else if (ml$dmlc$mxnet$spark$MXNet$$params().numWorker() < rdd.partitions().length) {
            ml$dmlc$mxnet$spark$MXNet$$logger().info("repartitioning training set to {} partitions", BoxesRunTime.boxToInteger(ml$dmlc$mxnet$spark$MXNet$$params().numWorker()));
            int numWorker2 = ml$dmlc$mxnet$spark$MXNet$$params().numWorker();
            boolean coalesce$default$2 = rdd.coalesce$default$2();
            rdd2 = rdd.coalesce(numWorker2, coalesce$default$2, rdd.coalesce$default$3(numWorker2, coalesce$default$2));
        } else {
            rdd2 = rdd;
        }
        RDD<LabeledPoint> rdd3 = rdd2;
        String ipAddress = Network$.MODULE$.ipAddress();
        int availablePort = Network$.MODULE$.availablePort();
        ml$dmlc$mxnet$spark$MXNet$$logger().info("Starting scheduler on {}:{}", ipAddress, BoxesRunTime.boxToInteger(availablePort));
        ParameterServer parameterServer = new ParameterServer(ml$dmlc$mxnet$spark$MXNet$$params().runtimeClasspath(), "scheduler", ipAddress, availablePort, ml$dmlc$mxnet$spark$MXNet$$params().numServer(), ml$dmlc$mxnet$spark$MXNet$$params().numWorker(), ml$dmlc$mxnet$spark$MXNet$$params().timeout(), ml$dmlc$mxnet$spark$MXNet$$params().javabin(), ParameterServer$.MODULE$.$lessinit$greater$default$9());
        Predef$.MODULE$.require(parameterServer.startProcess(), new MXNet$$anonfun$fit$2(this));
        context.parallelize(RichInt$.MODULE$.to$extension0(Predef$.MODULE$.intWrapper(1), ml$dmlc$mxnet$spark$MXNet$$params().numServer()), ml$dmlc$mxnet$spark$MXNet$$params().numServer(), ClassTag$.MODULE$.Int()).foreachPartition(new MXNet$$anonfun$fit$3(this, ipAddress, availablePort));
        RDD cache = rdd3.mapPartitions(new MXNet$$anonfun$1(this, ipAddress, availablePort), rdd3.mapPartitions$default$2(), ClassTag$.MODULE$.apply(MXNetModel.class)).cache();
        cache.foreachPartition(new MXNet$$anonfun$fit$4(this));
        MXNetModel mXNetModel = (MXNetModel) cache.first();
        ml$dmlc$mxnet$spark$MXNet$$logger().info("Waiting for scheduler ...");
        parameterServer.waitFor();
        return mXNetModel;
    }
}
