package cc.factorie.app.mf;

import cc.factorie.la.DenseTensor2;
import cc.factorie.la.Tensor1;
import cc.factorie.la.Tensor2;
import cc.factorie.la.Tensor3;
import cc.factorie.la.Tensor4;
import cc.factorie.la.WeightsMapAccumulator;
import cc.factorie.model.Parameters;
import cc.factorie.model.Weights1;
import cc.factorie.model.Weights2;
import cc.factorie.model.Weights3;
import cc.factorie.model.Weights4;
import cc.factorie.model.WeightsSet;
import cc.factorie.optimize.Example;
import cc.factorie.util.DoubleAccumulator;
import cc.factorie.variable.DiscreteDomain;
import java.util.Random;
import scala.Function0;
import scala.collection.Seq;
import scala.math.Ordering$Double$;
import scala.reflect.ScalaSignature;
import scala.runtime.BoxesRunTime;

/* compiled from: WSabie.scala */
@ScalaSignature(bytes = "\u0006\u0001\u0005]s!B\u0001\u0003\u0011\u0003Y\u0011AB,TC\nLWM\u0003\u0002\u0004\t\u0005\u0011QN\u001a\u0006\u0003\u000b\u0019\t1!\u00199q\u0015\t9\u0001\"\u0001\u0005gC\u000e$xN]5f\u0015\u0005I\u0011AA2d\u0007\u0001\u0001\"\u0001D\u0007\u000e\u0003\t1QA\u0004\u0002\t\u0002=\u0011aaV*bE&,7CA\u0007\u0011!\t\tB#D\u0001\u0013\u0015\u0005\u0019\u0012!B:dC2\f\u0017BA\u000b\u0013\u0005\u0019\te.\u001f*fM\")q#\u0004C\u00011\u00051A(\u001b8jiz\"\u0012a\u0003\u0004\u000555\u00011DA\u0006X'\u0006\u0014\u0017.Z'pI\u0016d7cA\r\u00119A\u0011Q\u0004I\u0007\u0002=)\u0011qDB\u0001\u0006[>$W\r\\\u0005\u0003Cy\u0011!\u0002U1sC6,G/\u001a:t\u0011!\u0019\u0013D!b\u0001\n\u0003!\u0013A\u00023p[\u0006Lg.F\u0001&!\t1\u0013&D\u0001(\u0015\tAc!\u0001\u0005wCJL\u0017M\u00197f\u0013\tQsE\u0001\bESN\u001c'/\u001a;f\t>l\u0017-\u001b8\t\u00111J\"\u0011!Q\u0001\n\u0015\nq\u0001Z8nC&t\u0007\u0005\u0003\u0005/3\t\u0015\r\u0011\"\u00010\u00035qW/\\#nE\u0016$G-\u001b8hgV\t\u0001\u0007\u0005\u0002\u0012c%\u0011!G\u0005\u0002\u0004\u0013:$\b\u0002\u0003\u001b\u001a\u0005\u0003\u0005\u000b\u0011\u0002\u0019\u0002\u001d9,X.R7cK\u0012$\u0017N\\4tA!Aa'\u0007BC\u0002\u0013\u0005q'A\u0002s]\u001e,\u0012\u0001\u000f\t\u0003syj\u0011A\u000f\u0006\u0003wq\nA!\u001e;jY*\tQ(\u0001\u0003kCZ\f\u0017BA 
;\u0005\u0019\u0011\u0016M\u001c3p[\"A\u0011)\u0007B\u0001B\u0003%\u0001(\u0001\u0003s]\u001e\u0004\u0003\"B\f\u001a\t\u0003\u0019E\u0003\u0002#G\u000f\"\u0003\"!R\r\u000e\u00035AQa\t\"A\u0002\u0015BQA\f\"A\u0002ABQA\u000e\"A\u0002aBqAS\rC\u0002\u0013\u00051*A\u0004xK&<\u0007\u000e^:\u0016\u00031\u0003\"!H'\n\u00059s\"\u0001C,fS\u001eDGo\u001d\u001a\t\rAK\u0002\u0015!\u0003M\u0003!9X-[4iiN\u0004\u0003\"\u0002*\u001a\t\u0003\u0019\u0016aC:fiR{'+\u00198e_6$2\u0001\u0016.]!\t)\u0006,D\u0001W\u0015\t9f!\u0001\u0002mC&\u0011\u0011L\u0016\u0002\r\t\u0016t7/\u001a+f]N|'O\r\u0005\u00067F\u0003\r\u0001V\u0001\u0002i\")a'\u0015a\u0001q!)a,\u0007C\u0001?\u0006)1oY8sKR\u0019\u0001m\u00195\u0011\u0005E\t\u0017B\u00012\u0013\u0005\u0019!u.\u001e2mK\")A-\u0018a\u0001K\u0006)\u0011/^3ssB\u0011QKZ\u0005\u0003OZ\u0013q\u0001V3og>\u0014\u0018\u0007C\u0003j;\u0002\u0007Q-\u0001\u0004wK\u000e$xN\u001d\u0005\u0006Wf!\t\u0001\\\u0001\u0005e\u0006t7\u000eF\u0002nsj\u00042A\u001c<f\u001d\tyGO\u0004\u0002qg6\t\u0011O\u0003\u0002s\u0015\u00051AH]8pizJ\u0011aE\u0005\u0003kJ\tq\u0001]1dW\u0006<W-\u0003\u0002xq\n\u00191+Z9\u000b\u0005U\u0014\u0002\"\u00023k\u0001\u0004)\u0007\"B>k\u0001\u0004i\u0017a\u0002<fGR|'o\u001d\u0004\u0005{6\u0001aPA\u0007X'\u0006\u0014\u0017.Z#yC6\u0004H.Z\n\u0004yBy\b\u0003BA\u0001\u0003\u000fi!!a\u0001\u000b\u0007\u0005\u0015a!\u0001\u0005paRLW.\u001b>f\u0013\u0011\tI!a\u0001\u0003\u000f\u0015C\u0018-\u001c9mK\"Aq\u0004 B\u0001B\u0003%A\tC\u0005ey\n\u0015\r\u0011\"\u0001\u0002\u0010U\tQ\rC\u0005\u0002\u0014q\u0014\t\u0011)A\u0005K\u00061\u0011/^3ss\u0002B!\"a\u0006}\u0005\u000b\u0007I\u0011AA\b\u0003!\u0001xn]5uSZ,\u0007\"CA\u000ey\n\u0005\t\u0015!\u0003f\u0003%\u0001xn]5uSZ,\u0007\u0005\u0003\u0006\u0002 
q\u0014)\u0019!C\u0001\u0003\u001f\t\u0001B\\3hCRLg/\u001a\u0005\n\u0003Ga(\u0011!Q\u0001\n\u0015\f\u0011B\\3hCRLg/\u001a\u0011\t\r]aH\u0011AA\u0014))\tI#a\u000b\u0002.\u0005=\u0012\u0011\u0007\t\u0003\u000brDaaHA\u0013\u0001\u0004!\u0005B\u00023\u0002&\u0001\u0007Q\rC\u0004\u0002\u0018\u0005\u0015\u0002\u0019A3\t\u000f\u0005}\u0011Q\u0005a\u0001K\"9\u0011Q\u0007?\u0005\u0002\u0005]\u0012AG1dGVlW\u000f\\1uKZ\u000bG.^3B]\u0012<%/\u00193jK:$HCBA\u001d\u0003\u007f\ti\u0005E\u0002\u0012\u0003wI1!!\u0010\u0013\u0005\u0011)f.\u001b;\t\u0011\u0005\u0005\u00131\u0007a\u0001\u0003\u0007\nQA^1mk\u0016\u0004B!!\u0012\u0002J5\u0011\u0011q\t\u0006\u0003w\u0019IA!a\u0013\u0002H\t\tBi\\;cY\u0016\f5mY;nk2\fGo\u001c:\t\u0011\u0005=\u00131\u0007a\u0001\u0003#\n\u0001b\u001a:bI&,g\u000e\u001e\t\u0004+\u0006M\u0013bAA+-\n)r+Z5hQR\u001cX*\u00199BG\u000e,X.\u001e7bi>\u0014\b")
/* loaded from: input_file:cc/factorie/app/mf/WSabie.class */
/**
 * Holder for a shared-embedding ranking model ({@link WSabieModel}) and its pairwise
 * margin-ranking training example ({@link WSabieExample}).
 *
 * <p>A single weight matrix {@code W} embeds every vector {@code v} as {@code W * v}. The score
 * between a query and a vector is the dot product of their two embeddings.
 */
public final class WSabie {

    /**
     * One pairwise ranking example: a query, a positive vector that should score higher against the
     * query, and a negative vector that should score lower by at least {@link #MARGIN}.
     */
    public static class WSabieExample implements Example {
        /** Required score gap between the positive and the negative vector. */
        private static final double MARGIN = 1.0d;

        private final WSabieModel model;
        private final Tensor1 query;
        private final Tensor1 positive;
        private final Tensor1 negative;

        /** @return the query vector of this example */
        public Tensor1 query() {
            return this.query;
        }

        /** @return the vector that should rank above {@link #negative()} for this query */
        public Tensor1 positive() {
            return this.positive;
        }

        /** @return the vector that should rank below {@link #positive()} for this query */
        public Tensor1 negative() {
            return this.negative;
        }

        /**
         * Accumulates the margin-ranking objective and its gradient for this example.
         *
         * <p>Let {@code e(v) = W * v}, {@code sPos = e(query) . e(positive)} and
         * {@code sNeg = e(query) . e(negative)}. If {@code sPos < sNeg + MARGIN}, the margin is
         * violated. In that case this method accumulates the (negative) value
         * {@code sPos - sNeg - MARGIN} and its gradient with respect to {@code W}:
         * {@code e(p) q^T - e(n) q^T + e(q) p^T - e(q) n^T}. If the margin is satisfied, nothing is
         * accumulated.
         *
         * @param doubleAccumulator receives the objective value; skipped when {@code null}
         * @param weightsMapAccumulator receives the gradient for the model weights; skipped when
         *     {@code null}
         */
        @Override // cc.factorie.optimize.Example
        public void accumulateValueAndGradient(DoubleAccumulator doubleAccumulator, WeightsMapAccumulator weightsMapAccumulator) {
            Tensor2 w = (Tensor2) this.model.weights().mo1466value();
            Tensor1 queryEmbedding = w.$times(query());
            Tensor1 positiveEmbedding = w.$times(positive());
            Tensor1 negativeEmbedding = w.$times(negative());
            double positiveScore = queryEmbedding.mo1562dot(positiveEmbedding);
            double negativeScore = queryEmbedding.mo1562dot(negativeEmbedding);
            // Written as "<" (not an inverted ">=" early return) so a NaN score still skips the update.
            if (positiveScore < negativeScore + MARGIN) {
                if (doubleAccumulator != null) {
                    doubleAccumulator.accumulate(BoxesRunTime.boxToDouble((positiveScore - negativeScore) - MARGIN));
                }
                if (weightsMapAccumulator != null) {
                    // d(sPos - sNeg)/dW, one term per side of each dot product.
                    weightsMapAccumulator.accumulate(this.model.weights(), positiveEmbedding.outer(query()));
                    weightsMapAccumulator.accumulate(this.model.weights(), negativeEmbedding.outer(query()), -1.0d);
                    weightsMapAccumulator.accumulate(this.model.weights(), queryEmbedding.outer(positive()));
                    weightsMapAccumulator.accumulate(this.model.weights(), queryEmbedding.outer(negative()), -1.0d);
                }
            }
        }

        /**
         * @param wSabieModel model whose weights are scored and updated
         * @param tensor1 query vector
         * @param tensor12 positive vector
         * @param tensor13 negative vector
         */
        public WSabieExample(WSabieModel wSabieModel, Tensor1 tensor1, Tensor1 tensor12, Tensor1 tensor13) {
            this.model = wSabieModel;
            this.query = tensor1;
            this.positive = tensor12;
            this.negative = tensor13;
        }
    }

    /**
     * Embedding model with a single {@link Weights2} matrix, registered in its own
     * {@link WeightsSet}.
     */
    public static class WSabieModel implements Parameters {
        private final DiscreteDomain domain;
        private final int numEmbeddings;
        private final Random rng;
        private final Weights2 weights;
        // Not final: it is assigned through the Parameters setter below, and javac rejects
        // assignments to a final field from outside a constructor/initializer.
        private WeightsSet parameters;

        @Override // cc.factorie.model.Parameters
        public WeightsSet parameters() {
            return this.parameters;
        }

        @Override // cc.factorie.model.Parameters
        public void cc$factorie$model$Parameters$_setter_$parameters_$eq(WeightsSet weightsSet) {
            this.parameters = weightsSet;
        }

        @Override // cc.factorie.model.Parameters
        public Weights1 Weights(Function0<Tensor1> function0) {
            return Parameters.Cclass.Weights((Parameters) this, (Function0) function0);
        }

        @Override // cc.factorie.model.Parameters
        /* renamed from: Weights */
        public Weights2 mo153Weights(Function0<Tensor2> function0) {
            return Parameters.Cclass.m1765Weights((Parameters) this, (Function0) function0);
        }

        @Override // cc.factorie.model.Parameters
        /* renamed from: Weights */
        public Weights3 mo154Weights(Function0<Tensor3> function0) {
            return Parameters.Cclass.m1766Weights((Parameters) this, (Function0) function0);
        }

        @Override // cc.factorie.model.Parameters
        /* renamed from: Weights */
        public Weights4 mo155Weights(Function0<Tensor4> function0) {
            return Parameters.Cclass.m1767Weights((Parameters) this, (Function0) function0);
        }

        /** @return the domain passed at construction */
        public DiscreteDomain domain() {
            return this.domain;
        }

        /** @return the embedding count passed at construction */
        public int numEmbeddings() {
            return this.numEmbeddings;
        }

        /** @return the random source passed at construction */
        public Random rng() {
            return this.rng;
        }

        /** @return the model's weight matrix */
        public Weights2 weights() {
            return this.weights;
        }

        /**
         * Overwrites every entry of {@code denseTensor2} with a standard Gaussian sample.
         *
         * @param denseTensor2 tensor filled in place
         * @param random source of the Gaussian samples
         * @return the same tensor instance, for chaining
         */
        public DenseTensor2 setToRandom(DenseTensor2 denseTensor2, Random random) {
            for (int i = 0; i < denseTensor2.length(); i++) {
                denseTensor2.update(i, random.nextGaussian());
            }
            return denseTensor2;
        }

        /**
         * Scores two vectors as the dot product of their embeddings: {@code (W a) . (W b)}.
         *
         * @param tensor1 first vector (the query, in ranking use)
         * @param tensor12 second vector
         * @return the embedding dot product
         */
        public double score(Tensor1 tensor1, Tensor1 tensor12) {
            Tensor2 w = (Tensor2) weights().mo1466value();
            return w.$times(tensor1).mo1562dot(w.$times(tensor12));
        }

        /**
         * Sorts {@code seq} by a per-vector key in ascending order, using {@code Ordering.Double}.
         *
         * <p>NOTE(review): the key is computed by {@code WSabie$WSabieModel$$anonfun$rank$1}, which is
         * presumably {@code score(tensor1, v)}. If so, the best-scoring vector comes LAST. Confirm that
         * callers expect ascending order.
         *
         * @param tensor1 query vector
         * @param seq candidate vectors
         * @return a new sequence with the candidates in ascending key order
         */
        public Seq<Tensor1> rank(Tensor1 tensor1, Seq<Tensor1> seq) {
            return (Seq) seq.sortBy(new WSabie$WSabieModel$$anonfun$rank$1(this, tensor1), Ordering$Double$.MODULE$);
        }

        /**
         * @param discreteDomain domain of the vectors being embedded
         * @param i number of embedding dimensions
         * @param random random source kept for initialization
         */
        public WSabieModel(DiscreteDomain discreteDomain, int i, Random random) {
            this.domain = discreteDomain;
            this.numEmbeddings = i;
            this.rng = random;
            // Must run before the Weights(...) call below, which goes through the Parameters trait
            // and may rely on the WeightsSet already being installed.
            cc$factorie$model$Parameters$_setter_$parameters_$eq(new WeightsSet());
            // The initial tensor comes from this anonymous supplier (not visible here); presumably a
            // setToRandom-filled DenseTensor2 sized from numEmbeddings/domain — confirm.
            this.weights = mo153Weights((Function0<Tensor2>) new WSabie$WSabieModel$$anonfun$1(this));
        }
    }
}
