package cc.factorie.app.uschema;

import cc.factorie.la.DenseTensor;
import cc.factorie.la.DenseTensor1;
import cc.factorie.la.DenseTensorLike1;
import cc.factorie.la.Tensor1;
import scala.math.package$;
import scala.reflect.ScalaSignature;
import scala.util.Random;

/* compiled from: UniversalSchemaTrainer.scala */
@ScalaSignature(bytes = "\u0006\u0001u3A!\u0001\u0002\u0001\u0017\t!#+Z4vY\u0006\u0014\u0018N_3e\u0005B\u0014XK\\5wKJ\u001c\u0018\r\\*dQ\u0016l\u0017\r\u0016:bS:,'O\u0003\u0002\u0004\t\u00059Qo]2iK6\f'BA\u0003\u0007\u0003\r\t\u0007\u000f\u001d\u0006\u0003\u000f!\t\u0001BZ1di>\u0014\u0018.\u001a\u0006\u0002\u0013\u0005\u00111mY\u0002\u0001'\t\u0001A\u0002\u0005\u0002\u000e\u001d5\t!!\u0003\u0002\u0010\u0005\tI\"\t\u001d:V]&4XM]:bYN\u001b\u0007.Z7b)J\f\u0017N\\3s\u0011!\t\u0002A!b\u0001\n\u0003\u0011\u0012a\u0003:fOVd\u0017M]5{KJ,\u0012a\u0005\t\u0003)]i\u0011!\u0006\u0006\u0002-\u0005)1oY1mC&\u0011\u0001$\u0006\u0002\u0007\t>,(\r\\3\t\u0011i\u0001!\u0011!Q\u0001\nM\tAB]3hk2\f'/\u001b>fe\u0002B\u0001\u0002\b\u0001\u0003\u0006\u0004%\tAE\u0001\tgR,\u0007o]5{K\"Aa\u0004\u0001B\u0001B\u0003%1#A\u0005ti\u0016\u00048/\u001b>fA!A\u0001\u0005\u0001BC\u0002\u0013\u0005\u0011%A\u0002eS6,\u0012A\t\t\u0003)\rJ!\u0001J\u000b\u0003\u0007%sG\u000f\u0003\u0005'\u0001\t\u0005\t\u0015!\u0003#\u0003\u0011!\u0017.\u001c\u0011\t\u0011!\u0002!Q1A\u0005\u0002%\na!\\1ue&DX#\u0001\u0016\u0011\u00055Y\u0013B\u0001\u0017\u0003\u0005)\u0019un\\2NCR\u0014\u0018\u000e\u001f\u0005\t]\u0001\u0011\t\u0011)A\u0005U\u00059Q.\u0019;sSb\u0004\u0003\u0002\u0003\u0019\u0001\u0005\u000b\u0007I\u0011A\u0019\u0002\u000b5|G-\u001a7\u0016\u0003I\u0002\"!D\u001a\n\u0005Q\u0012!\u0001F+oSZ,'o]1m'\u000eDW-\\1N_\u0012,G\u000e\u0003\u00057\u0001\t\u0005\t\u0015!\u00033\u0003\u0019iw\u000eZ3mA!A\u0001\b\u0001BC\u0002\u0013\u0005\u0011(\u0001\u0004sC:$w.\\\u000b\u0002uA\u00111HP\u0007\u0002y)\u0011Q(F\u0001\u0005kRLG.\u0003\u0002@y\t1!+\u00198e_6D\u0001\"\u0011\u0001\u0003\u0002\u0003\u0006IAO\u0001\be\u0006tGm\\7!\u0011\u0015\u0019\u0005\u0001\"\u0001E\u0003\u0019a\u0014N\\5u}Q9QIR$I\u0013*[\u0005CA\u0007\u0001\u0011\u0015\t\"\t1\u0001\u0014\u0011\u0015a\"\t1\u0001\u0014\u0011\u0015\u0001#\t1\u0001#\u0011\u0015A#\t1\u0001+\u0011\u0015\u0001$\t1\u00013\u0011\u0015A$\t1\u0001;\u0011\u001di\u0005A1A\u0005\u0002I\taB]8x
%\u0016<W\u000f\\1sSj,'\u000f\u0003\u0004P\u0001\u0001\u0006IaE\u0001\u0010e><(+Z4vY\u0006\u0014\u0018N_3sA!9\u0011\u000b\u0001b\u0001\n\u0003\u0011\u0012AD2pYJ+w-\u001e7be&TXM\u001d\u0005\u0007'\u0002\u0001\u000b\u0011B\n\u0002\u001f\r|GNU3hk2\f'/\u001b>fe\u0002BQ!\u0016\u0001\u0005BY\u000ba\"\u001e9eCR,'\t\u001d:DK2d7\u000f\u0006\u0003\u0014/f[\u0006\"\u0002-U\u0001\u0004\u0011\u0013\u0001\u0004:po&sG-\u001a=UeV,\u0007\"\u0002.U\u0001\u0004\u0011\u0013!\u0004:po&sG-\u001a=GC2\u001cX\rC\u0003])\u0002\u0007!%\u0001\u0005d_2Le\u000eZ3y\u0001")
/* loaded from: input_file:cc/factorie/app/uschema/RegularizedBprUniversalSchemaTrainer.class */
/*
 * BPR trainer for a universal-schema model that also shrinks every vector it updates.
 *
 * Each update multiplies the touched vectors by (1 - stepsize * regularizer) before adding the
 * gradient term. This is the SGD form of an L2 penalty. The single constructor argument
 * `regularizer` is used for both the row and the column vectors.
 */
public class RegularizedBprUniversalSchemaTrainer extends BprUniversalSchemaTrainer {
    private final double regularizer;
    private final double stepsize;
    private final int dim;
    private final CoocMatrix matrix;
    private final UniversalSchemaModel model;
    private final Random random;
    private final double rowRegularizer;
    private final double colRegularizer;

    /**
     * Creates a trainer whose row and column shrinkage both equal {@code regularizer}.
     *
     * @param regularizer shrinkage strength copied into both {@link #rowRegularizer()} and
     *     {@link #colRegularizer()}
     * @param baseStepsize learning rate returned by {@link #stepsize()}
     * @param dimension value returned by {@link #dim()} (not read by this class itself)
     * @param coocMatrix matrix returned by {@link #matrix()} (not read by this class itself)
     * @param universalSchemaModel model whose row/column vectors are updated in place
     * @param rng random source returned by {@link #random()} (not read by this class itself)
     */
    public RegularizedBprUniversalSchemaTrainer(double regularizer, double baseStepsize, int dimension, CoocMatrix coocMatrix, UniversalSchemaModel universalSchemaModel, Random rng) {
        this.regularizer = regularizer;
        this.stepsize = baseStepsize;
        this.dim = dimension;
        this.matrix = coocMatrix;
        this.model = universalSchemaModel;
        this.random = rng;
        this.rowRegularizer = regularizer;
        this.colRegularizer = regularizer;
    }

    /** Returns the shrinkage strength supplied at construction time. */
    public double regularizer() {
        return regularizer;
    }

    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public double stepsize() {
        return stepsize;
    }

    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public int dim() {
        return dim;
    }

    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public CoocMatrix matrix() {
        return matrix;
    }

    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public UniversalSchemaModel model() {
        return model;
    }

    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public Random random() {
        return random;
    }

    /** Returns the shrinkage strength applied to row vectors (equal to {@link #regularizer()}). */
    public double rowRegularizer() {
        return rowRegularizer;
    }

    /** Returns the shrinkage strength applied to column vectors (equal to {@link #regularizer()}). */
    public double colRegularizer() {
        return colRegularizer;
    }

    /**
     * Performs one regularized BPR gradient step on rows {@code i}, {@code i2} and column {@code i3}.
     *
     * <p>The column moves along {@code row(i) - row(i2)}. Row {@code i} moves toward the column
     * and row {@code i2} moves away from it. All moves are scaled by
     * {@code stepsize() * (1 - calculateProb(margin))}, and each vector is shrunk before its
     * gradient term is added.
     *
     * @param i index of the row whose score for column {@code i3} is pushed up
     * @param i2 index of the row whose score for column {@code i3} is pushed down
     * @param i3 column index shared by both rows
     * @return {@code log(calculateProb(score(i, i3) - score(i2, i3)))}, computed before any update
     */
    @Override // cc.factorie.app.uschema.BprUniversalSchemaTrainer
    public double updateBprCells(int i, int i2, int i3) {
        UniversalSchemaModel m = model();

        double margin = m.score(i, i3) - m.score(i2, i3);
        double prob = UniversalSchemaModel$.MODULE$.calculateProb(margin);
        double logProb = package$.MODULE$.log(prob);
        double step = stepsize() * (1 - prob);

        DenseTensor1 column = (DenseTensor1) m.colVectors().apply(i3);
        DenseTensorLike1 upRow = (DenseTensorLike1) m.rowVectors().apply(i);
        DenseTensorLike1 downRow = (DenseTensorLike1) m.rowVectors().apply(i2);

        // Both row updates must see the column as it was before this step.
        DenseTensor1 columnBefore = column.copy();

        // The column gradient is taken from the rows before they are modified below.
        // The explicit upcasts keep the same receiver types (and overloads) as the compiled Scala.
        ((DenseTensor) column).$times$eq(1 - (stepsize() * colRegularizer()));
        ((DenseTensorLike1) column).$plus$eq(((Tensor1) upRow).$minus((Tensor1) downRow), step);

        double rowShrink = 1 - (stepsize() * rowRegularizer());
        ((DenseTensor) upRow).$times$eq(rowShrink);
        upRow.$plus$eq(columnBefore, step);
        ((DenseTensor) downRow).$times$eq(rowShrink);
        downRow.$plus$eq(columnBefore, -step);

        return logProb;
    }
}
