package ml.dmlc.mxnet.spark;

import ml.dmlc.mxnet.Accuracy;
import ml.dmlc.mxnet.DataIter;
import ml.dmlc.mxnet.FeedForward;
import ml.dmlc.mxnet.FeedForward$;
import ml.dmlc.mxnet.KVStore;
import ml.dmlc.mxnet.KVStore$;
import ml.dmlc.mxnet.KVStoreServer$;
import ml.dmlc.mxnet.NDArray;
import ml.dmlc.mxnet.Xavier;
import ml.dmlc.mxnet.Xavier$;
import ml.dmlc.mxnet.optimizer.SGD;
import ml.dmlc.mxnet.optimizer.SGD$;
import ml.dmlc.mxnet.spark.io.LabeledPointIter;
import org.apache.spark.mllib.regression.LabeledPoint;
import scala.Predef$;
import scala.Serializable;
import scala.collection.Iterator;
import scala.collection.immutable.Map;
import scala.package$;
import scala.runtime.AbstractFunction1;
import scala.runtime.BoxesRunTime;

/* compiled from: MXNet.scala */
/* loaded from: input_file:ml/dmlc/mxnet/spark/MXNet$$anonfun$1.class */
/**
 * Per-partition training function used by {@link MXNet} (compiler-generated anonymous function).
 *
 * <p>For one Spark partition of {@link LabeledPoint}s it:
 * <ol>
 *   <li>counts the examples;</li>
 *   <li>joins the distributed key-value store as a {@code "worker"}, using the scheduler address
 *       captured at construction;</li>
 *   <li>trains a {@link FeedForward} model against a {@code "dist_async"} {@link KVStore};</li>
 *   <li>returns a single-element iterator holding the resulting {@link MXNetModel}.</li>
 * </ol>
 */
public final class MXNet$$anonfun$1 extends AbstractFunction1<Iterator<LabeledPoint>, Iterator<MXNetModel>> implements Serializable {
    public static final long serialVersionUID = 0;
    /** Enclosing {@link MXNet} instance; supplies the training parameters and the logger. */
    private final /* synthetic */ MXNet $outer;
    /** Scheduler host handed to {@code ParameterServer.buildEnv} when this worker joins. */
    private final String schedulerIP$1;
    /** Scheduler port handed to {@code ParameterServer.buildEnv} when this worker joins. */
    private final int schedulerPort$1;

    /**
     * Trains a model on one partition and returns it.
     *
     * <p>The data iterator and the KVStore are always released, even if any step throws.
     * The worker enters the KVStore exit barrier only after training completes successfully.
     * A failing worker disposes its store without waiting on that barrier.
     *
     * @param iterator the labeled points of this partition
     * @return an iterator containing exactly one trained {@link MXNetModel}
     * @throws IllegalStateException if the thread is interrupted during the start-up delay
     */
    public final Iterator<MXNetModel> apply(Iterator<LabeledPoint> iterator) {
        LabeledPointIter labeledPointIter = new LabeledPointIter(iterator, this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().dimension(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().batchSize(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().dataName(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().labelName());
        KVStore create = null;
        FeedForward feedForward;
        boolean trained = false;
        try {
            // First pass over the data only to count examples, then rewind for training.
            int numExamples = countExamples(labeledPointIter);
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().debug("Number of samples: {}", BoxesRunTime.boxToInteger(numExamples));
            labeledPointIter.reset();
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().info("Launching worker ...");
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().info("Batch {}", BoxesRunTime.boxToInteger(this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().batchSize()));
            sleepBeforeJoining();
            KVStoreServer$.MODULE$.init(ParameterServer$.MODULE$.buildEnv("worker", this.schedulerIP$1, this.schedulerPort$1, this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().numServer(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().numWorker()));
            create = KVStore$.MODULE$.create("dist_async");
            // Keep the exit barrier off until training has succeeded (turned on in finally).
            create.setBarrierBeforeExit(false);
            SGD sgd = new SGD(0.01f, 0.9f, 1.0E-5f, SGD$.MODULE$.$lessinit$greater$default$4(), SGD$.MODULE$.$lessinit$greater$default$5());
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().debug("Define model");
            // Epoch size is this partition's batch count split evenly across all workers (integer division).
            feedForward = new FeedForward(this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().getNetwork(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().context(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().numEpoch(), (numExamples / this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().batchSize()) / create.numWorkers(), sgd, new Xavier(Xavier$.MODULE$.$lessinit$greater$default$1(), "in", 2.34f), FeedForward$.MODULE$.$lessinit$greater$default$7(), (Map) null, (Map) null, FeedForward$.MODULE$.$lessinit$greater$default$10(), 0);
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().info("Start training ...");
            feedForward.fit(labeledPointIter, (DataIter) null, new Accuracy(), create);
            this.$outer.ml$dmlc$mxnet$spark$MXNet$$logger().info("Training finished, waiting for other workers ...");
            trained = true;
        } finally {
            // On success this is the same release order as before:
            // iterator, then barrier on, then KVStore.
            // On failure, dispose without joining the exit barrier.
            labeledPointIter.dispose();
            if (create != null) {
                if (trained) {
                    create.setBarrierBeforeExit(true);
                }
                create.dispose();
            }
        }
        return package$.MODULE$.Iterator().apply(Predef$.MODULE$.wrapRefArray(new MXNetModel[]{new MXNetModel(feedForward, this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().dimension(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().batchSize(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().dataName(), this.$outer.ml$dmlc$mxnet$spark$MXNet$$params().labelName())}));
    }

    /**
     * Counts the examples remaining in {@code dataIter}.
     *
     * <p>For each batch, it adds the first label array's dimension 0. The iterator is left
     * exhausted; callers must {@code reset()} it before reuse.
     */
    private static int countExamples(LabeledPointIter dataIter) {
        int total = 0;
        while (dataIter.hasNext()) {
            total += ((NDArray) dataIter.m19next().label().head()).shape().apply(0);
        }
        return total;
    }

    /**
     * Waits a fixed 20 seconds before this worker joins the KVStore cluster.
     *
     * <p>NOTE(review): the reason for the delay is not visible here. Presumably it gives the
     * parameter-server scheduler/servers time to settle; confirm before changing the value.
     *
     * @throws IllegalStateException if interrupted; the interrupt flag is restored first
     */
    private static void sleepBeforeJoining() {
        try {
            Thread.sleep(20000L);
        } catch (InterruptedException e) {
            // apply() cannot declare checked exceptions, so re-flag the interrupt and wrap it.
            Thread.currentThread().interrupt();
            throw new IllegalStateException("Interrupted before launching MXNet worker", e);
        }
    }

    /**
     * Captures the enclosing {@link MXNet} instance and the scheduler address.
     *
     * @param mXNet enclosing instance; must not be null
     * @param str scheduler host
     * @param i scheduler port
     * @throws NullPointerException if {@code mXNet} is null
     */
    public MXNet$$anonfun$1(MXNet mXNet, String str, int i) {
        if (mXNet == null) {
            // Equivalent to the compiler-emitted "throw null".
            throw new NullPointerException();
        }
        this.$outer = mXNet;
        this.schedulerIP$1 = str;
        this.schedulerPort$1 = i;
    }
}
