package cc.factorie.la;

import cc.factorie.la.WeightsMapAccumulator;
import cc.factorie.model.Weights;
import cc.factorie.model.WeightsMap;
import scala.MatchError;
import scala.Predef$;
import scala.StringContext;
import scala.collection.immutable.Nil$;
import scala.collection.mutable.HashMap;
import scala.collection.mutable.HashMap$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxedUnit;
import scala.runtime.BoxesRunTime;

/* compiled from: TensorAccumulator.scala */
@ScalaSignature(bytes = "\u0006\u0001y3A!\u0001\u0002\u0001\u0013\tA2+\\1si\u001e\u0013\u0018\rZ5f]R\f5mY;nk2\fGo\u001c:\u000b\u0005\r!\u0011A\u00017b\u0015\t)a!\u0001\u0005gC\u000e$xN]5f\u0015\u00059\u0011AA2d\u0007\u0001\u00192\u0001\u0001\u0006\u0011!\tYa\"D\u0001\r\u0015\u0005i\u0011!B:dC2\f\u0017BA\b\r\u0005\u0019\te.\u001f*fMB\u0011\u0011CE\u0007\u0002\u0005%\u00111C\u0001\u0002\u0016/\u0016Lw\r\u001b;t\u001b\u0006\u0004\u0018iY2v[Vd\u0017\r^8s\u0011\u0015)\u0002\u0001\"\u0001\u0017\u0003\u0019a\u0014N\\5u}Q\tq\u0003\u0005\u0002\u0012\u0001!9\u0011\u0004\u0001b\u0001\n\u0003Q\u0012aA7baV\t1\u0004\u0005\u0002\u001d?5\tQD\u0003\u0002\u001f\t\u0005)Qn\u001c3fY&\u0011\u0001%\b\u0002\u000b/\u0016Lw\r\u001b;t\u001b\u0006\u0004\bB\u0002\u0012\u0001A\u0003%1$\u0001\u0003nCB\u0004\u0003b\u0002\u0013\u0001\u0005\u0004%\t!J\u0001\tgR\fG/Z'baV\ta\u0005\u0005\u0003(Y9\nT\"\u0001\u0015\u000b\u0005%R\u0013aB7vi\u0006\u0014G.\u001a\u0006\u0003W1\t!bY8mY\u0016\u001cG/[8o\u0013\ti\u0003FA\u0004ICNDW*\u00199\u0011\u0005qy\u0013B\u0001\u0019\u001e\u0005\u001d9V-[4iiN\u0004\"a\u0003\u001a\n\u0005Mb!aA%oi\"1Q\u0007\u0001Q\u0001\n\u0019\n\u0011b\u001d;bi\u0016l\u0015\r\u001d\u0011\t\u000f]\u0002!\u0019!C\u0001q\u0005)Q)\u0014)U3V\t\u0011\u0007\u0003\u0004;\u0001\u0001\u0006I!M\u0001\u0007\u000b6\u0003F+\u0017\u0011\t\u000fq\u0002!\u0019!C\u0001q\u0005i1+\u0013(H\u0019\u0016{F+\u0012(T\u001fJCaA\u0010\u0001!\u0002\u0013\t\u0014AD*J\u001d\u001ecUi\u0018+F\u001dN{%\u000b\t\u0005\b\u0001\u0002\u0011\r\u0011\"\u00019\u0003-\t5iQ+N+2\u000bEk\u0014*\t\r\t\u0003\u0001\u0015!\u00032\u00031\t5iQ+N+2\u000bEk\u0014*!\u0011\u0015!\u0005\u0001\"\u0001F\u0003\u0015\u0019G.Z1s)\u00051\u0005CA\u0006H\u0013\tAEB\u0001\u0003V]&$\b\"\u0002&\u0001\t\u0003Q\u0012AB4fi6\u000b\u0007\u000fC\u0003M\u0001\u0011\u0005Q*\u0001\u0006bG\u000e,X.\u001e7bi\u0016$BA\u0012(Q+\")qj\u0013a\u0001]\u0005\u00191.Z=\t\u000bE[\u0005\u0019\u0001*\u0002\u0003Q\u0004\"!E*\n\u0005Q\u0013!A\u0002+f]N|'\u000fC\u0003W\u0017\u0002\
u0007q+A\u0001e!\tY\u0001,\u0003\u0002Z\u0019\t1Ai\\;cY\u0016DQ\u0001\u0014\u0001\u0005\u0002m#2A\u0012/^\u0011\u0015y%\f1\u0001/\u0011\u0015\t&\f1\u0001S\u0001")
/* loaded from: input_file:cc/factorie/la/SmartGradientAccumulator.class */
/*
 * Weights-keyed gradient accumulator that chooses the storage tensor for each key from the
 * runtime type of the first contribution it receives, and upgrades that storage lazily.
 *
 * Each key's state is kept in stateMap() as a boxed int:
 *   EMPTY()         = 0  nothing stored yet for the key;
 *   SINGLE_TENSOR() = 1  the stored tensor came from a single contribution and is replaced by a
 *                        freshly allocated accumulator before anything else is added to it;
 *   ACCUMULATOR()   = 3  the stored tensor is added to in place with +=.
 *
 * NOTE(review): this is decompiled from Scala (TensorAccumulator.scala); the
 * SmartGradientAccumulator$$anonfun$N classes are compiler-generated closures defined elsewhere.
 * Nothing in this class is synchronized; presumably an instance is confined to one thread -
 * confirm with callers.
 */
public class SmartGradientAccumulator implements WeightsMapAccumulator {
    /** Gradient tensor stored for each Weights key. */
    private final WeightsMap map;
    /** Per-key state; values are boxed Integers equal to EMPTY, SINGLE_TENSOR or ACCUMULATOR. */
    private final HashMap<Weights, Object> stateMap;
    private final int EMPTY;
    private final int SINGLE_TENSOR;
    private final int ACCUMULATOR;

    /**
     * Delegates to the {@code WeightsMapAccumulator} trait implementation (presumably accumulating
     * each entry of {@code weightsMap} - see the trait).
     */
    @Override // cc.factorie.la.WeightsMapAccumulator
    public void accumulate(WeightsMap weightsMap) {
        WeightsMapAccumulator.Cclass.accumulate(this, weightsMap);
    }

    /**
     * Delegates to the {@code WeightsMapAccumulator} trait implementation with scale factor
     * {@code d} (presumably accumulating each entry of {@code weightsMap} scaled by {@code d}).
     */
    @Override // cc.factorie.la.WeightsMapAccumulator
    public void accumulate(WeightsMap weightsMap, double d) {
        WeightsMapAccumulator.Cclass.accumulate(this, weightsMap, d);
    }

    /** Returns the per-key gradient storage. */
    public WeightsMap map() {
        return this.map;
    }

    /** Returns the per-key state table (boxed EMPTY / SINGLE_TENSOR / ACCUMULATOR values). */
    public HashMap<Weights, Object> stateMap() {
        return this.stateMap;
    }

    /** State code 0: nothing has been accumulated for the key. */
    public int EMPTY() {
        return this.EMPTY;
    }

    /** State code 1: the stored tensor must be copied into a fresh accumulator before reuse. */
    public int SINGLE_TENSOR() {
        return this.SINGLE_TENSOR;
    }

    /** State code 3: the stored tensor is accumulated into in place. */
    public int ACCUMULATOR() {
        return this.ACCUMULATOR;
    }

    /** Discards all stored tensors and all per-key states. */
    public void clear() {
        map().clear();
        stateMap().clear();
    }

    /** Returns the accumulated gradients; same object as {@link #map()}. */
    public WeightsMap getMap() {
        return map();
    }

    /**
     * Adds {@code d * tensor} to the gradient stored for {@code weights}.
     *
     * <p>The key's current state selects the strategy:
     * <ul>
     *   <li>ACCUMULATOR: add into the stored tensor in place;</li>
     *   <li>SINGLE_TENSOR: copy the stored tensor into a newly allocated accumulator, then add;</li>
     *   <li>EMPTY: pick storage based on the runtime type of {@code tensor}.</li>
     * </ul>
     *
     * @param weights key whose gradient is updated
     * @param tensor contribution to add; when it is an {@link Outer2Tensor} arriving for an EMPTY key
     *     it is scaled in place and stored by reference
     * @param d scale factor applied to {@code tensor}
     * @throws MatchError if the stored state is not one of the three state codes, or if
     *     {@code tensor} is {@code null} for an EMPTY key
     * @throws Error if a SINGLE_TENSOR-state tensor is not a Tensor1, Tensor2, Tensor3 or Tensor4
     */
    @Override // cc.factorie.la.WeightsMapAccumulator
    public void accumulate(Weights weights, Tensor tensor, double d) {
        // NOTE(review): anonfun$1 presumably supplies EMPTY() as the default for unseen keys.
        int state = BoxesRunTime.unboxToInt(stateMap().getOrElse(weights, new SmartGradientAccumulator$$anonfun$1(this)));
        if (state == ACCUMULATOR()) {
            map().apply(weights).$plus$eq(tensor, d);
        } else if (state == SINGLE_TENSOR()) {
            promoteToAccumulator(weights, tensor, d);
        } else if (state == EMPTY()) {
            storeFirstContribution(weights, tensor, d);
        } else {
            throw new MatchError(BoxesRunTime.boxToInteger(state));
        }
    }

    /** Replaces the SINGLE_TENSOR-state tensor with a fresh accumulator holding it plus {@code d * tensor}. */
    private void promoteToAccumulator(Weights weights, Tensor tensor, double d) {
        Tensor current = map().apply(weights);
        Tensor accumulator = newAccumulatorLike(current);
        accumulator.$plus$eq(current);
        accumulator.$plus$eq(tensor, d);
        map().update(weights, accumulator);
        stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
    }

    /** Allocates an empty tensor shaped like {@code current}: dense for an all-dense Outer1Tensor2, sparse-indexed otherwise. */
    private Tensor newAccumulatorLike(Tensor current) {
        if (current instanceof Outer1Tensor2) {
            Outer1Tensor2 outer = (Outer1Tensor2) current;
            if (outer.tensor1().mo1657isDense() && outer.tensor2().mo1657isDense()) {
                return new DenseTensor2(outer.dim1(), outer.dim2());
            }
        }
        if (current instanceof Tensor1) {
            return new SparseIndexedTensor1(((Tensor1) current).dim1());
        }
        if (current instanceof Tensor2) {
            Tensor2 t2 = (Tensor2) current;
            return new SparseIndexedTensor2(t2.dim1(), t2.dim2());
        }
        if (current instanceof Tensor3) {
            Tensor3 t3 = (Tensor3) current;
            return new SparseIndexedTensor3(t3.dim1(), t3.dim2(), t3.dim3());
        }
        if (current instanceof Tensor4) {
            Tensor4 t4 = (Tensor4) current;
            return new SparseIndexedTensor4(t4.dim1(), t4.dim2(), t4.dim3(), t4.dim4());
        }
        throw new Error("Any concrete tensor should be either a Tensor1, Tensor2, Tensor3, or Tensor4. Offending class: "
                + current.getClass().getName());
    }

    /** Stores the first contribution for an EMPTY key, choosing storage and next state from {@code tensor}'s runtime type. */
    private void storeFirstContribution(Weights weights, Tensor tensor, double d) {
        if (tensor instanceof SparseTensor && !(tensor instanceof SparseIndexedTensor)) {
            // Non-indexed sparse tensor: build a new sparse copy but stay SINGLE_TENSOR so the next
            // contribution promotes it via newAccumulatorLike.
            SparseTensor sparse = (SparseTensor) tensor;
            stateMap().update(weights, BoxesRunTime.boxToInteger(SINGLE_TENSOR()));
            Tensor newSparse = Tensor$.MODULE$.newSparse(sparse);
            newSparse.$plus$eq(sparse, d);
            map().update(weights, newSparse);
            // Stop here: previously control fell through to the SparseTensor branch below, which
            // overwrote this state and tensor with a scaled copy.
            return;
        }
        if (tensor instanceof Singleton2BinaryLayeredTensor3) {
            Singleton2BinaryLayeredTensor3 layered = (Singleton2BinaryLayeredTensor3) tensor;
            stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
            Tensor newSparse = Tensor$.MODULE$.newSparse(layered);
            newSparse.$plus$eq(layered, d);
            map().update(weights, newSparse);
        } else if (tensor instanceof DenseTensor) {
            stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
            Tensor copy = ((DenseTensor) tensor).copy();
            copy.$times$eq(d);
            map().update(weights, copy);
        } else if (tensor instanceof SparseTensor) {
            // Only SparseIndexedTensor reaches here; other sparse tensors returned above.
            stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
            Tensor copy = ((SparseTensor) tensor).copy();
            copy.$times$eq(d);
            map().update(weights, copy);
        } else if (tensor instanceof Outer2Tensor) {
            // NOTE(review): the caller's tensor is scaled in place and stored by reference; the
            // SINGLE_TENSOR state makes the next contribution copy it before accumulating.
            Outer2Tensor outer = (Outer2Tensor) tensor;
            stateMap().update(weights, BoxesRunTime.boxToInteger(SINGLE_TENSOR()));
            outer.$times$eq(d);
            map().update(weights, outer);
        } else if (tensor instanceof ReadOnlyTensor) {
            ReadOnlyTensor readOnly = (ReadOnlyTensor) tensor;
            stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
            Tensor newSparse = Tensor$.MODULE$.newSparse(readOnly);
            newSparse.$plus$eq(readOnly, d);
            map().update(weights, newSparse);
        } else if (tensor == null) {
            // Mirrors the Scala match: a null scrutinee matches no type pattern.
            throw new MatchError(tensor);
        } else {
            stateMap().update(weights, BoxesRunTime.boxToInteger(ACCUMULATOR()));
            Tensor newDense = Tensor$.MODULE$.newDense(tensor);
            newDense.$plus$eq(tensor, d);
            map().update(weights, newDense);
        }
    }

    /** Equivalent to {@code accumulate(weights, tensor, 1.0)}. */
    @Override // cc.factorie.la.WeightsMapAccumulator
    public void accumulate(Weights weights, Tensor tensor) {
        accumulate(weights, tensor, 1.0d);
    }

    /** Creates an accumulator with empty storage; state codes are EMPTY=0, SINGLE_TENSOR=1, ACCUMULATOR=3. */
    public SmartGradientAccumulator() {
        WeightsMapAccumulator.Cclass.$init$(this);
        // NOTE(review): anonfun$2 is presumably the WeightsMap default for missing keys - confirm.
        this.map = new WeightsMap(new SmartGradientAccumulator$$anonfun$2(this));
        this.stateMap = HashMap$.MODULE$.apply(Nil$.MODULE$);
        this.EMPTY = 0;
        this.SINGLE_TENSOR = 1;
        this.ACCUMULATOR = 3;
    }
}
