package cc.factorie.app.uschema.tac;

import cc.factorie.app.uschema.BprUniversalSchemaTrainer;
import cc.factorie.app.uschema.EntityRelationKBMatrix;
import cc.factorie.app.uschema.EntityRelationKBMatrix$;
import cc.factorie.app.uschema.Evaluator$;
import cc.factorie.app.uschema.NormConstrainedBprUniversalSchemaTrainer;
import cc.factorie.app.uschema.RegularizedBprUniversalSchemaTrainer;
import cc.factorie.app.uschema.UniversalSchemaModel;
import cc.factorie.app.uschema.UniversalSchemaModel$;
import scala.MatchError;
import scala.None$;
import scala.Predef$;
import scala.Some;
import scala.Tuple3;
import scala.collection.immutable.Set;
import scala.collection.immutable.StringOps;
import scala.collection.mutable.StringBuilder;
import scala.runtime.BoxesRunTime;
import scala.util.Random;

/* compiled from: TrainTestTacData.scala */
/* loaded from: input_file:cc/factorie/app/uschema/tac/TrainTestTacData$.class */
/**
 * Command-line driver (Java form of the Scala object {@code TrainTestTacData}).
 *
 * <p>The driver does the following, in order:
 * <ol>
 *   <li>Loads an entity/relation matrix from the TSV file named by the {@code tacData} option.</li>
 *   <li>Prunes the matrix and prints basic statistics.</li>
 *   <li>Splits off a test set restricted to {@link #testCols()}.</li>
 *   <li>Trains a BPR universal-schema model.</li>
 *   <li>Prints test-set mean average precision (MAP) before training and after 10, 50, 100 and
 *       200 cumulative training iterations.</li>
 * </ol>
 */
public final class TrainTestTacData$ {
    /**
     * Iterations to run before each MAP report.
     * Cumulatively these give checkpoints at 10, 50, 100 and 200 iterations.
     * Declared before {@link #MODULE$} so it is initialized before the singleton is constructed.
     */
    private static final int[] TRAINING_INCREMENTS = {10, 40, 50, 100};

    /**
     * Singleton instance.
     *
     * <p>It is initialized directly in the declaration. A {@code static final} field cannot also
     * be assigned from the constructor, which the decompiled {@code MODULE$ = this} pattern
     * attempted.
     */
    public static final TrainTestTacData$ MODULE$ = new TrainTestTacData$();

    /** Parsed command-line options (data path, dimensionality, step size, regularization). */
    private final TrainTestTacDataOptions opts;

    /** Relation columns (TAC slot names) from which the test split is drawn. */
    private final Set<String> testCols;

    public TrainTestTacDataOptions opts() {
        return this.opts;
    }

    public Set<String> testCols() {
        return this.testCols;
    }

    /**
     * Parses {@code strArr} into {@link #opts()}, then loads, splits, trains and evaluates.
     *
     * @param strArr command-line arguments
     * @throws MatchError if the train/dev/test split comes back {@code null}
     */
    public void main(String[] strArr) {
        opts().parse(Predef$.MODULE$.wrapRefArray(strArr));

        long startMillis = System.currentTimeMillis();
        // NOTE(review): the meaning of prune(2, 1)'s thresholds is not visible here;
        // confirm it against EntityRelationKBMatrix.prune.
        EntityRelationKBMatrix fullMatrix = (EntityRelationKBMatrix) EntityRelationKBMatrix$.MODULE$
                .fromTsv(opts().tacData().value(), EntityRelationKBMatrix$.MODULE$.fromTsv$default$2())
                .prune(2, 1);
        Predef$.MODULE$.println(new StringOps("Reading from file and pruning took %.2f s").format(
                Predef$.MODULE$.genericWrapArray(new Object[]{
                        BoxesRunTime.boxToDouble((System.currentTimeMillis() - startMillis) / 1000.0d)})));
        printStats(fullMatrix);

        // Fixed seed so the split, initialization and training are reproducible.
        Random random = new Random(0);
        // NOTE(review): presumably (devSize, testSize, devCols, testCols, rng); confirm against
        // EntityRelationKBMatrix.randomTestSplit.
        Tuple3<EntityRelationKBMatrix, EntityRelationKBMatrix, EntityRelationKBMatrix> split =
                fullMatrix.randomTestSplit(0, 10000, None$.MODULE$, new Some(testCols()), random);
        if (split == null) {
            // Same failure the Scala tuple-pattern binding raises on a null scrutinee.
            throw new MatchError(split);
        }
        EntityRelationKBMatrix trainMatrix = (EntityRelationKBMatrix) split._1();
        // The second element of the split is not used by this driver.
        EntityRelationKBMatrix testMatrix = (EntityRelationKBMatrix) split._3();

        int dim = BoxesRunTime.unboxToInt(opts().dim().value());
        // The model is sized from the full matrix, so the train/test matrices index into the same space.
        UniversalSchemaModel model = UniversalSchemaModel$.MODULE$.randomModel(
                fullMatrix.numRows(), fullMatrix.numCols(), dim, random);
        BprUniversalSchemaTrainer trainer = createTrainer(model, trainMatrix, dim, random);

        printMap("Initial MAP: ", model, trainMatrix, testMatrix);
        int iterationsSoFar = 0;
        for (int increment : TRAINING_INCREMENTS) {
            trainer.train(increment);
            iterationsSoFar += increment;
            printMap("MAP after " + iterationsSoFar + " iterations: ", model, trainMatrix, testMatrix);
        }
    }

    /** Prints row, column and non-zero-cell counts of {@code matrix}. */
    private static void printStats(EntityRelationKBMatrix matrix) {
        Predef$.MODULE$.println("Stats:");
        Predef$.MODULE$.println(new StringBuilder().append("Num Rows:").append(BoxesRunTime.boxToInteger(matrix.numRows())).toString());
        Predef$.MODULE$.println(new StringBuilder().append("Num Cols:").append(BoxesRunTime.boxToInteger(matrix.numCols())).toString());
        Predef$.MODULE$.println(new StringBuilder().append("Num cells:").append(BoxesRunTime.boxToInteger(matrix.nnz())).toString());
    }

    /**
     * Builds the BPR trainer selected by the {@code useMaxNorm} option.
     *
     * <p>It returns a norm-constrained trainer when {@code useMaxNorm} is set, and a regularized
     * trainer otherwise. It prints which variant was chosen.
     */
    private BprUniversalSchemaTrainer createTrainer(UniversalSchemaModel model, EntityRelationKBMatrix trainMatrix,
                                                    int dim, Random random) {
        double stepsize = BoxesRunTime.unboxToDouble(opts().stepsize().value());
        if (BoxesRunTime.unboxToBoolean(opts().useMaxNorm().value())) {
            Predef$.MODULE$.println("use norm constraint");
            return new NormConstrainedBprUniversalSchemaTrainer(
                    BoxesRunTime.unboxToDouble(opts().maxNorm().value()), stepsize, dim,
                    trainMatrix.matrix(), model, random);
        }
        Predef$.MODULE$.println("use regularization");
        return new RegularizedBprUniversalSchemaTrainer(
                BoxesRunTime.unboxToDouble(opts().regularizer().value()), stepsize, dim,
                trainMatrix.matrix(), model, random);
    }

    /** Computes MAP of {@code model} on the test cells and prints it after {@code label}. */
    private static void printMap(String label, UniversalSchemaModel model,
                                 EntityRelationKBMatrix trainMatrix, EntityRelationKBMatrix testMatrix) {
        double map = Evaluator$.MODULE$.meanAveragePrecision(model.similaritiesAndLabels(
                trainMatrix.matrix(), testMatrix.matrix(), model.similaritiesAndLabels$default$3()));
        Predef$.MODULE$.println(new StringBuilder().append(label).append(BoxesRunTime.boxToDouble(map)).toString());
    }

    private TrainTestTacData$() {
        this.opts = new TrainTestTacDataOptions();
        this.testCols = Predef$.MODULE$.Set().apply(Predef$.MODULE$.wrapRefArray(new String[]{
                "org:alternate_names", "org:city_of_headquarters", "org:country_of_headquarters",
                "org:date_dissolved", "org:date_founded", "org:founded_by", "org:member_of", "org:members",
                "org:number_of_employees_members", "org:parents", "org:political_religious_affiliation",
                "org:shareholders", "org:stateorprovince_of_headquarters", "org:subsidiaries",
                "org:top_members_employees", "org:website",
                "per:age", "per:alternate_names", "per:cause_of_death", "per:charges", "per:children",
                "per:cities_of_residence", "per:city_of_birth", "per:city_of_death", "per:countries_of_residence",
                "per:country_of_birth", "per:country_of_death", "per:date_of_birth", "per:date_of_death",
                "per:employee_or_member_of", "per:origin", "per:other_family", "per:parents", "per:religion",
                "per:schools_attended", "per:siblings", "per:spouse", "per:stateorprovince_of_birth",
                "per:stateorprovince_of_death", "per:statesorprovinces_of_residence", "per:title"}));
    }
}
