package ml.dmlc.mxnet;

import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import scala.Array$;
import scala.None$;
import scala.Option;
import scala.Option$;
import scala.Predef$;
import scala.StringContext;
import scala.Tuple2;
import scala.Tuple3;
import scala.collection.GenTraversableOnce;
import scala.collection.IndexedSeq;
import scala.collection.Iterable$;
import scala.collection.Seq;
import scala.collection.TraversableOnce;
import scala.collection.immutable.Map;
import scala.collection.immutable.Map$;
import scala.collection.immutable.MapLike;
import scala.collection.immutable.Nil$;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashMap$;
import scala.math.Ordering$Int$;
import scala.reflect.ClassTag$;
import scala.runtime.BoxesRunTime;
import scala.runtime.RichInt$;

/* JADX WARN: Classes with same name are omitted:
  input_file:archive-tmp/mxnet-full_2.10-osx-x86_64-cpu-0.1.1.jar:ml/dmlc/mxnet/Model$.class
 */
/* compiled from: Model.scala */
/* loaded from: input_file:mxnet-full_2.10-osx-x86_64-cpu-0.1.1.jar:ml/dmlc/mxnet/Model$.class */
public final class Model$ {
    /**
     * Singleton instance backing the Scala {@code object Model}.
     * Assigned exactly once, in the static initializer below: a {@code static final} field
     * cannot also be assigned from the constructor.
     */
    public static final Model$ MODULE$;

    /**
     * Threshold (16 * 1024 * 1024) compared against the largest value produced by
     * {@code Model$$anonfun$3} over the parameter arrays when auto-selecting a "local" kvstore.
     * NOTE(review): presumably an element count per array; confirm against the anonfun.
     */
    private static final int LOCAL_UPDATE_CPU_MAX_SIZE = 16777216;

    private final Logger logger;

    static {
        MODULE$ = new Model$();
    }

    private Logger logger() {
        return this.logger;
    }

    /** Returns the checkpoint symbol file name, {@code <prefix>-symbol.json}. */
    private static String symbolFileName(String prefix) {
        return prefix + "-symbol.json";
    }

    /**
     * Returns the checkpoint parameter file name, {@code <prefix>-<epoch as %04d>.params}.
     * Formatted with {@link java.util.Locale#ROOT} so the epoch always uses ASCII digits.
     * This way the same name is produced on save and on load regardless of the JVM default locale.
     */
    private static String paramFileName(String prefix, int epoch) {
        return String.format(java.util.Locale.ROOT, "%s-%04d.params", prefix, epoch);
    }

    /**
     * Saves a checkpoint.
     * The symbol is written to {@code <prefix>-symbol.json}.
     * Both parameter maps are merged into a single {@code <prefix>-NNNN.params} file.
     *
     * @param str    checkpoint file prefix
     * @param i      epoch number, embedded zero-padded (4 digits) in the params file name
     * @param symbol symbol to serialize
     * @param map    first parameter map; entries are transformed by {@code Model$$anonfun$1}
     *               before merging (presumably argument params tagged with a prefix; confirm)
     * @param map2   second parameter map; entries are transformed by {@code Model$$anonfun$2}
     *               before merging (presumably auxiliary states; confirm)
     */
    public void saveCheckpoint(String str, int i, Symbol symbol, Map<String, NDArray> map, Map<String, NDArray> map2) {
        symbol.save(symbolFileName(str));
        // Both maps are written to one file, so each is passed through its own entry mapper first.
        Map<String, NDArray> allParams = ((MapLike) map.map(new Model$$anonfun$1(), Map$.MODULE$.canBuildFrom())).$plus$plus((GenTraversableOnce) map2.map(new Model$$anonfun$2(), Map$.MODULE$.canBuildFrom()));
        String paramFile = paramFileName(str, i);
        NDArray$.MODULE$.save(paramFile, allParams);
        logger().info("Saved checkpoint to {}", paramFile);
    }

    /**
     * Loads a checkpoint written by {@link #saveCheckpoint}.
     *
     * @param str checkpoint file prefix
     * @param i   epoch number of the params file to read
     * @return (symbol, first parameter map, second parameter map). The stored (name, array)
     *         pairs are filtered by {@code Model$$anonfun$loadCheckpoint$1} and split between the
     *         two maps by {@code Model$$anonfun$loadCheckpoint$2} (presumably by arg/aux name
     *         prefix; confirm).
     */
    /* JADX WARN: Multi-variable type inference failed */
    public Tuple3<Symbol, Map<String, NDArray>, Map<String, NDArray>> loadCheckpoint(String str, int i) {
        Symbol symbol = Symbol$.MODULE$.load(symbolFileName(str));
        Tuple2<String[], NDArray[]> namesAndArrays = NDArray$.MODULE$.load(paramFileName(str, i));
        // Raw types: the decompiled anonfun filler is untyped, so the maps stay raw here.
        HashMap firstParams = (HashMap) HashMap$.MODULE$.apply(Nil$.MODULE$);
        HashMap secondParams = (HashMap) HashMap$.MODULE$.apply(Nil$.MODULE$);
        Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(namesAndArrays.mo241_1()).zip(Predef$.MODULE$.wrapRefArray(namesAndArrays.mo240_2()), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).withFilter(new Model$$anonfun$loadCheckpoint$1()).foreach(new Model$$anonfun$loadCheckpoint$2(firstParams, secondParams));
        return new Tuple3<>(symbol, firstParams.toMap(Predef$.MODULE$.conforms()), secondParams.toMap(Predef$.MODULE$.conforms()));
    }

    /**
     * Creates a kvstore of the requested type.
     * The type "local" is auto-resolved based on the parameter arrays.
     *
     * @param str requested kvstore type
     * @param i   NOTE(review): presumably the number of devices. When it is 1 and {@code str}
     *            does not contain "dist", no kvstore is created.
     * @param map parameter arrays; only consulted when {@code str} is "local"
     * @return (kvstore option, boxed Boolean flag). The result is (None, false) when no kvstore
     *         is needed. Otherwise the flag is true unless the resolved type contains
     *         "local_allreduce".
     */
    /* JADX WARN: Type inference failed for: r0v6, types: [scala.collection.Iterable] */
    public Tuple2<Option<KVStore>, Object> createKVStore(String str, int i, Map<String, NDArray> map) {
        if (i == 1 && !str.contains("dist")) {
            return new Tuple2<>(None$.MODULE$, BoxesRunTime.boxToBoolean(false));
        }
        String kvType = str;
        // Null-safe equality, matching Scala's `kvType == "local"`.
        if ("local".equals(kvType)) {
            int maxSize = BoxesRunTime.unboxToInt(((TraversableOnce) map.values().map(new Model$$anonfun$3(), Iterable$.MODULE$.canBuildFrom())).mo395max(Ordering$Int$.MODULE$));
            kvType = maxSize < LOCAL_UPDATE_CPU_MAX_SIZE ? "local_update_cpu" : "local_allreduce_cpu";
            logger().info("Auto - select kvstore type = {}", kvType);
        }
        return new Tuple2<>(Option$.MODULE$.apply(KVStore$.MODULE$.create(kvType)), BoxesRunTime.boxToBoolean(!kvType.contains("local_allreduce")));
    }

    /**
     * Wraps an existing (possibly null) kvstore.
     *
     * @param kVStore kvstore to wrap; {@code null} yields {@code None}
     * @return (Option(kVStore), boxed Boolean flag). The flag is true only when the kvstore is
     *         non-null and its type does not contain "local_allreduce".
     */
    public Tuple2<Option<KVStore>, Object> createKVStore(KVStore kVStore) {
        boolean flag = kVStore != null && !kVStore.type().contains("local_allreduce");
        return new Tuple2<>(Option$.MODULE$.apply(kVStore), BoxesRunTime.boxToBoolean(flag));
    }

    /**
     * Runs the per-parameter kvstore initialization for each index in
     * {@code [0, indexedSeq.length())}, delegating to
     * {@code Model$$anonfun$ml$dmlc$mxnet$Model$$initializeKVStore$1}.
     *
     * @throws IllegalArgumentException (from {@code Predef.require}) if the two index sequences
     *         differ in length
     */
    public void ml$dmlc$mxnet$Model$$initializeKVStore(KVStore kVStore, IndexedSeq<NDArray[]> indexedSeq, Map<String, NDArray> map, IndexedSeq<String> indexedSeq2, boolean z) {
        Predef$.MODULE$.require(indexedSeq.length() == indexedSeq2.length());
        RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(0), indexedSeq.length()).foreach$mVc$sp(new Model$$anonfun$ml$dmlc$mxnet$Model$$initializeKVStore$1(kVStore, indexedSeq, map, indexedSeq2, z));
    }

    /**
     * Zips the two array-of-arrays element-wise and with their index.
     * Each ((a, b), index) triple is handed to
     * {@code Model$$anonfun$ml$dmlc$mxnet$Model$$updateParamsOnKVStore$1} together with the
     * kvstore option.
     */
    public void ml$dmlc$mxnet$Model$$updateParamsOnKVStore(NDArray[][] nDArrayArr, NDArray[][] nDArrayArr2, Option<KVStore> option) {
        Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(nDArrayArr).zip(Predef$.MODULE$.wrapRefArray(nDArrayArr2), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).foreach(new Model$$anonfun$ml$dmlc$mxnet$Model$$updateParamsOnKVStore$1(option));
    }

    /**
     * Zips the two array-of-arrays element-wise and with their index.
     * Each ((a, b), index) triple is handed to
     * {@code Model$$anonfun$ml$dmlc$mxnet$Model$$updateParams$1} together with the updater,
     * the int argument and the kvstore option.
     */
    public void ml$dmlc$mxnet$Model$$updateParams(NDArray[][] nDArrayArr, NDArray[][] nDArrayArr2, MXKVStoreUpdater mXKVStoreUpdater, int i, Option<KVStore> option) {
        Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps((Object[]) Predef$.MODULE$.refArrayOps(nDArrayArr).zip(Predef$.MODULE$.wrapRefArray(nDArrayArr2), Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).zipWithIndex(Array$.MODULE$.canBuildFrom(ClassTag$.MODULE$.apply(Tuple2.class)))).foreach(new Model$$anonfun$ml$dmlc$mxnet$Model$$updateParams$1(mXKVStoreUpdater, i, option));
    }

    /** Scala default value for the fifth (kvstore) argument of {@code updateParams}: {@code None}. */
    private Option<KVStore> updateParams$default$5() {
        return None$.MODULE$;
    }

    /**
     * Trains a symbol across the given contexts using a {@code DataParallelExecutorManager}.
     *
     * Sequence of work:
     * <ol>
     *   <li>Create the manager and attach the optional monitor.</li>
     *   <li>Load parameters and create the optimizer's updater.</li>
     *   <li>Run the kvstore setup callbacks. The optimizer is installed on the kvstore only
     *       when {@code z} is true.</li>
     *   <li>Run the epoch loop over {@code [i, i2)}.</li>
     * </ol>
     * The updater and the manager are always disposed, in that order, even if training throws.
     *
     * @param i  first epoch (inclusive)
     * @param i2 end epoch (exclusive)
     */
    public void trainMultiDevice(Symbol symbol, Context[] contextArr, Seq<String> seq, Seq<String> seq2, Seq<String> seq3, Map<String, NDArray> map, Map<String, NDArray> map2, int i, int i2, int i3, Optimizer optimizer, Option<KVStore> option, boolean z, DataIter dataIter, Option<DataIter> option2, EvalMetric evalMetric, Option<EpochEndCallback> option3, Option<BatchEndCallback> option4, Logger logger, Seq<Object> seq4, Option<Monitor> option5) {
        DataParallelExecutorManager executorManager = new DataParallelExecutorManager(symbol, contextArr, seq2, seq, seq3, dataIter, seq4, logger);
        try {
            option5.foreach(new Model$$anonfun$trainMultiDevice$2(executorManager));
            executorManager.setParams(map, map2);
            MXKVStoreUpdater updater = Optimizer$.MODULE$.getUpdater(optimizer);
            try {
                option.foreach(new Model$$anonfun$trainMultiDevice$3(map, z, executorManager));
                if (z) {
                    option.foreach(new Model$$anonfun$trainMultiDevice$4(optimizer));
                }
                RichInt$.MODULE$.until$extension0(Predef$.MODULE$.intWrapper(i), i2).foreach$mVc$sp(new Model$$anonfun$trainMultiDevice$1(symbol, contextArr, map, map2, i2, i3, option, z, dataIter, option2, evalMetric, option3, option4, logger, option5, executorManager, updater));
            } finally {
                // Release the native updater even when training fails part-way.
                updater.dispose();
            }
        } finally {
            executorManager.dispose();
        }
    }

    /** Scala default for {@code trainMultiDevice}'s 14th argument (training data): {@code null}. */
    public DataIter trainMultiDevice$default$14() {
        return null;
    }

    /** Scala default for {@code trainMultiDevice}'s 15th argument (eval data): {@code None}. */
    public Option<DataIter> trainMultiDevice$default$15() {
        return None$.MODULE$;
    }

    /** Scala default for {@code trainMultiDevice}'s 17th argument (epoch-end callback): {@code None}. */
    public Option<EpochEndCallback> trainMultiDevice$default$17() {
        return None$.MODULE$;
    }

    /** Scala default for {@code trainMultiDevice}'s 18th argument (batch-end callback): {@code None}. */
    public Option<BatchEndCallback> trainMultiDevice$default$18() {
        return None$.MODULE$;
    }

    /** Scala default for {@code trainMultiDevice}'s 19th argument: this object's own logger. */
    public Logger trainMultiDevice$default$19() {
        return logger();
    }

    /** Scala default for {@code trainMultiDevice}'s 20th argument: the empty sequence. */
    public Seq<Object> trainMultiDevice$default$20() {
        return Nil$.MODULE$;
    }

    /** Scala default for {@code trainMultiDevice}'s 21st argument (monitor): {@code None}. */
    public Option<Monitor> trainMultiDevice$default$21() {
        return None$.MODULE$;
    }

    private Model$() {
        this.logger = LoggerFactory.getLogger(Model.class);
    }
}
