package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import com.googlecode.cqengine.index.support.CloseableIterator;
import de.uni_mannheim.informatik.dws.melt.matching_base.FileUtil;
import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.CorrespondenceRelation;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.Iterator;
import java.util.Properties;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.jena.ext.xerces.impl.xs.SchemaSymbols;
import org.apache.jena.ontology.OntModel;
import org.apache.jena.rdf.model.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Generates training data for a transformers model from alignments and fine-tunes the model via
 * the {@link PythonServer}.
 *
 * <p>Each training example is written as one line of comma-separated values. The columns are:
 * <ol>
 *   <li>the CSV-escaped text of the first entity;</li>
 *   <li>the CSV-escaped text of the second entity;</li>
 *   <li>a label: {@code 1} for {@link CorrespondenceRelation#EQUIVALENCE} correspondences and
 *       {@code 0} for all other relations.</li>
 * </ol>
 * Texts are obtained from the {@link TextExtractor} configured in the base class.
 *
 * <p>Typical usage: call {@link #match} one or more times (e.g. in a pipeline) to append examples
 * to an internal training file, then call {@link #finetuneModel()}.
 */
public class TransformersFineTuner extends TransformersBase implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFineTuner.class);

    /** Separator written after every training example (platform dependent). */
    private static final String NEWLINE = System.lineSeparator();

    /** Label written for examples derived from equivalence correspondences. */
    private static final String POSITIVE_LABEL = "1";

    /** Label written for examples derived from any other correspondence relation. */
    private static final String NEGATIVE_LABEL = "0";

    /** Location returned by {@link #finetuneModel(File)} as the fine-tuned model. */
    protected File resultingModelLocation;

    /** File to which {@link #match} appends training examples. */
    protected File trainingFile;

    /**
     * Creates a fine-tuner.
     *
     * @param textExtractor extractor used to obtain the texts of the matched resources
     * @param modelName passed unchanged to {@code TransformersBase} (presumably the name of the
     *     model to start from — confirm against the base class)
     * @param resultingModelLocation location reported as the fine-tuned model
     * @param tmpDir directory in which the training file is created
     */
    public TransformersFineTuner(TextExtractor textExtractor, String modelName, File resultingModelLocation, File tmpDir) {
        super(textExtractor, modelName);
        this.tmpDir = tmpDir;
        this.resultingModelLocation = resultingModelLocation;
        this.trainingFile = FileUtil.createFileWithRandomNumber(this.tmpDir, "alignment_transformers_train", ".txt");
    }

    /**
     * Creates a fine-tuner whose training file lives in {@link FileUtil#SYSTEM_TMP_FOLDER}.
     *
     * @param textExtractor extractor used to obtain the texts of the matched resources
     * @param modelName passed unchanged to {@code TransformersBase}
     * @param resultingModelLocation location reported as the fine-tuned model
     */
    public TransformersFineTuner(TextExtractor textExtractor, String modelName, File resultingModelLocation) {
        this(textExtractor, modelName, resultingModelLocation, FileUtil.SYSTEM_TMP_FOLDER);
    }

    /**
     * Appends one training example per usable correspondence (or per text pair, see
     * {@code multipleTextsToMultipleExamples}) to the internal training file.
     *
     * @param source model containing the first entity of each correspondence
     * @param target model containing the second entity of each correspondence
     * @param inputAlignment correspondences used as training examples
     * @param properties unused
     * @return {@code inputAlignment}, unchanged
     * @throws Exception if writing the training file fails
     */
    @Override
    public Alignment match(OntModel source, OntModel target, Alignment inputAlignment, Properties properties) throws Exception {
        LOGGER.info("Append text to training file: {}", this.trainingFile);
        writeTrainingFile(source, target, inputAlignment, this.trainingFile, true);
        return inputAlignment;
    }

    /**
     * Writes the given alignment into a fresh training file in {@code tmpDir}.
     *
     * @param source model containing the first entity of each correspondence
     * @param target model containing the second entity of each correspondence
     * @param alignment correspondences used as training examples
     * @return the created file, or {@code null} if no example could be written (the empty file is
     *     deleted in that case)
     * @throws IOException if writing the file fails
     */
    public File createTrainingFile(OntModel source, OntModel target, Alignment alignment) throws IOException {
        File file = FileUtil.createFileWithRandomNumber(this.tmpDir, "alignment_transformers_train", ".txt");
        if (writeTrainingFile(source, target, alignment, file, false) != 0) {
            return file;
        }
        LOGGER.warn("No training file is created because no correspondences have enough text.");
        if (!file.delete()) {
            LOGGER.warn("Could not delete empty training file {}", file);
        }
        return null;
    }

    /**
     * Writes training examples for all correspondences of {@code alignment} to {@code file}.
     *
     * <p>If {@code multipleTextsToMultipleExamples} is set, every combination of non-blank texts
     * of the two entities becomes its own example. Otherwise the texts of each entity are joined
     * into one string, and the correspondence is skipped if either joined text is blank.
     *
     * @param source model containing the first entity of each correspondence
     * @param target model containing the second entity of each correspondence
     * @param alignment correspondences used as training examples
     * @param file destination file (UTF-8)
     * @param append {@code true} to append to {@code file}, {@code false} to overwrite it
     * @return number of examples written
     * @throws IOException if opening, writing or closing the file fails
     */
    public int writeTrainingFile(OntModel source, OntModel target, Alignment alignment, File file, boolean append) throws IOException {
        final boolean pairEveryText = this.multipleTextsToMultipleExamples;
        int positive = 0;
        int negative = 0;
        int unused = 0;
        // Both the writer and the (closeable) alignment iterator must be released.
        try (BufferedWriter writer = new BufferedWriter(
                     new OutputStreamWriter(new FileOutputStream(file, append), StandardCharsets.UTF_8));
             CloseableIterator<Correspondence> correspondences = alignment.iterator()) {
            while (correspondences.hasNext()) {
                Correspondence correspondence = correspondences.next();
                boolean isPositive = correspondence.getRelation() == CorrespondenceRelation.EQUIVALENCE;
                String label = isPositive ? POSITIVE_LABEL : NEGATIVE_LABEL;
                Resource left = source.getResource(correspondence.getEntityOne());
                Resource right = target.getResource(correspondence.getEntityTwo());
                int written = pairEveryText
                        ? writeAllTextPairs(writer, left, right, label)
                        : writeJoinedTexts(writer, left, right, label);
                if (written == 0) {
                    unused++;
                } else if (isPositive) {
                    positive += written;
                } else {
                    negative += written;
                }
            }
        }
        LOGGER.info("Wrote {} training examples. {} positive, {} negative, {} correspondences not used due to insufficient textual data.",
                positive + negative, positive, negative, unused);
        return positive + negative;
    }

    /** Writes one example per pair of non-blank texts of {@code left} and {@code right}; returns the count. */
    private int writeAllTextPairs(BufferedWriter writer, Resource left, Resource right, String label) throws IOException {
        Iterable<String> leftTexts = this.extractor.extract(left);
        // Extracted once per correspondence instead of once per left text.
        Iterable<String> rightTexts = this.extractor.extract(right);
        int written = 0;
        for (String leftText : leftTexts) {
            if (StringUtils.isBlank(leftText)) {
                continue;
            }
            for (String rightText : rightTexts) {
                if (StringUtils.isBlank(rightText)) {
                    continue;
                }
                writeExample(writer, leftText, rightText, label);
                written++;
            }
        }
        return written;
    }

    /** Writes one example from the joined texts of both resources; returns 1, or 0 if either text is blank. */
    private int writeJoinedTexts(BufferedWriter writer, Resource left, Resource right, String label) throws IOException {
        String leftText = getTextFromResource(left);
        String rightText = getTextFromResource(right);
        if (StringUtils.isBlank(leftText) || StringUtils.isBlank(rightText)) {
            return 0;
        }
        writeExample(writer, leftText, rightText, label);
        return 1;
    }

    /** Writes a single CSV line: escaped left text, escaped right text, label. */
    private static void writeExample(BufferedWriter writer, String leftText, String rightText, String label) throws IOException {
        writer.write(StringEscapeUtils.escapeCsv(leftText) + "," + StringEscapeUtils.escapeCsv(rightText) + "," + label + NEWLINE);
    }

    /**
     * Joins all extracted texts of {@code resource}, each trimmed, separated by single spaces.
     *
     * @param resource resource whose texts are extracted
     * @return the joined text, trimmed (empty if there are no texts)
     */
    private String getTextFromResource(Resource resource) {
        StringBuilder sb = new StringBuilder();
        for (String text : this.extractor.extract(resource)) {
            sb.append(text.trim()).append(' ');
        }
        return sb.toString().trim();
    }

    /**
     * Fine-tunes the model on the given training file via the {@link PythonServer}.
     *
     * @param trainingFile file with training examples
     * @return {@link #getResultingModelLocation()}
     * @throws Exception if the python server call fails
     */
    public File finetuneModel(File trainingFile) throws Exception {
        PythonServer.getInstance().transformersFineTuning(this, trainingFile);
        return this.resultingModelLocation;
    }

    /**
     * Fine-tunes the model on the internal training file filled by {@link #match}, and deletes
     * that file after successful fine-tuning.
     *
     * @return {@link #getResultingModelLocation()}
     * @throws IllegalArgumentException if the training file is missing or empty
     * @throws Exception if the python server call fails
     */
    public File finetuneModel() throws Exception {
        if (this.trainingFile == null || !this.trainingFile.exists() || this.trainingFile.length() == 0) {
            throw new IllegalArgumentException("Cannot finetune a model because no training file was generated. Did you call the match method before (e.g. in a pipeline)?");
        }
        File model = finetuneModel(this.trainingFile);
        if (!this.trainingFile.delete()) {
            LOGGER.warn("Could not delete training file {}", this.trainingFile);
        }
        return model;
    }

    /** Adds the {@code fp16} training argument (set to {@code true}). */
    public void addTrainingParameterToMakeTrainingFaster() {
        this.trainingArguments.addParameter("fp16", true);
    }

    public File getResultingModelLocation() {
        return this.resultingModelLocation;
    }

    public void setResultingModelLocation(File resultingModelLocation) {
        this.resultingModelLocation = resultingModelLocation;
    }
}
