package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import com.googlecode.cqengine.index.support.CloseableIterator;
import de.uni_mannheim.informatik.dws.melt.matching_base.FileUtil;
import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServerException;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Alignment;
import de.uni_mannheim.informatik.dws.melt.yet_another_alignment_api.Correspondence;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.OutputStreamWriter;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Properties;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.text.StringEscapeUtils;
import org.apache.jena.ontology.OntModel;
import org.apache.jena.rdf.model.Resource;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/TransformersFilter.class */
/**
 * Filter that scores every correspondence of an input alignment with a transformers model (configured via
 * {@link TransformersBase}) served by {@link PythonServer}. The score is attached to each correspondence as an
 * additional confidence keyed by this class.
 *
 * <p>The texts of both entities are obtained with the configured {@link TextExtractor}. They are written as
 * CSV rows ({@code text1,text2}) to a temporary file, which is sent to the Python server for prediction.
 */
public class TransformersFilter extends TransformersBase implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFilter.class);
    private static final String NEWLINE = System.getProperty("line.separator");

    /** Batch size returned when probing fails for a reason other than running out of memory. */
    private static final int DEFAULT_EVAL_BATCH_SIZE = 8;
    /** First batch size tried while probing for the maximum evaluation batch size. */
    private static final int MIN_PROBE_BATCH_SIZE = 4;
    /** Largest batch size that is probed; the search stops after this size succeeds. */
    private static final int MAX_PROBE_BATCH_SIZE = 8192;

    /** Flag exposed via getter/setter; it is not read within this class (presumably used elsewhere - TODO confirm). */
    private boolean changeClass;
    /** If true, {@link #predictConfidences(File)} first searches for the largest usable eval batch size. */
    private boolean optimizeBatchSize;

    /**
     * Creates a new filter.
     *
     * @param textExtractor extractor used to obtain the texts of the resources in a correspondence
     * @param modelName model identifier forwarded to {@link TransformersBase} (presumably a model name or path)
     */
    public TransformersFilter(TextExtractor textExtractor, String modelName) {
        super(textExtractor, modelName);
        this.changeClass = false;
        this.optimizeBatchSize = false;
    }

    /**
     * Adds a transformers-based confidence (keyed by this class) to every correspondence of the input alignment.
     *
     * <p>A correspondence with several text pairs (see {@code multipleTextsToMultipleExamples}) receives the
     * maximum predicted confidence over its pairs. A correspondence without any usable text keeps the
     * additional confidence of 0.0 set by {@link #createPredictionFile}.
     *
     * @param source ontology containing the first entity of each correspondence
     * @param target ontology containing the second entity of each correspondence
     * @param inputAlignment alignment to score; it is modified in place and returned
     * @param properties unused
     * @return the (same) input alignment
     * @throws IllegalArgumentException if the prediction does not deliver a confidence for some written line
     * @throws Exception if the prediction fails with a non-IO error
     */
    @Override // de.uni_mannheim.informatik.dws.melt.matching_jena.MatcherYAAAJena, de.uni_mannheim.informatik.dws.melt.matching_base.IMatcher
    public Alignment match(OntModel source, OntModel target, Alignment inputAlignment, Properties properties) throws Exception {
        File predictionFile = FileUtil.createFileWithRandomNumber(this.tmpDir, "alignment_transformers_predict", ".txt");
        try {
            Map<Correspondence, List<Integer>> correspondenceToLines =
                    createPredictionFile(source, target, inputAlignment, predictionFile, false);
            if (correspondenceToLines.isEmpty()) {
                LOGGER.warn("No correspondences have enough text to be processed (the input alignment has {} correspondences) - the input alignment is returned unchanged.", inputAlignment.size());
                return inputAlignment;
            }
            LOGGER.info("Run prediction");
            List<Double> confidences = predictConfidences(predictionFile);
            LOGGER.info("Finished prediction");
            for (Map.Entry<Correspondence, List<Integer>> entry : correspondenceToLines.entrySet()) {
                entry.getKey().addAdditionalConfidence(getClass(), maxConfidence(confidences, entry.getValue()));
            }
            return inputAlignment;
        } catch (IOException e) {
            // Covers IO errors while writing the file as well as IO errors raised during prediction.
            LOGGER.warn("Could not write text to prediction file. Return unmodified input alignment.", e);
            return inputAlignment;
        } finally {
            predictionFile.delete();
        }
    }

    /**
     * Returns the maximum confidence found at the given line indices (at least 0.0).
     *
     * @throws IllegalArgumentException if an index has no (or a null) confidence in {@code confidences}
     */
    private static double maxConfidence(List<Double> confidences, List<Integer> lineIndices) {
        double max = 0.0d;
        for (int index : lineIndices) {
            // Bounds check: a short prediction list must produce the intended error, not an IndexOutOfBoundsException.
            Double confidence = index < confidences.size() ? confidences.get(index) : null;
            if (confidence == null) {
                throw new IllegalArgumentException("Could not find a confidence for a given correspondence.");
            }
            if (confidence > max) {
                max = confidence;
            }
        }
        return max;
    }

    /**
     * Writes one CSV row ({@code text1,text2}) per example to {@code file}. It also sets an additional
     * confidence of 0.0 (keyed by this class) on every correspondence of the alignment.
     *
     * <p>If {@code multipleTextsToMultipleExamples} is set, every combination of non-blank texts of the two
     * entities becomes its own row. Otherwise all texts of an entity are concatenated into one string, and one
     * row is written when both strings are non-blank.
     *
     * @param source ontology containing the first entity of each correspondence
     * @param target ontology containing the second entity of each correspondence
     * @param alignment correspondences to write
     * @param file destination file (UTF-8)
     * @param append whether to append to an existing file instead of overwriting it
     * @return map from correspondence to the zero-based row indices written for it (only correspondences with
     *     at least one row are contained)
     * @throws IOException if writing the file fails
     */
    public Map<Correspondence, List<Integer>> createPredictionFile(OntModel source, OntModel target, Alignment alignment, File file, boolean append) throws IOException {
        Map<Correspondence, List<Integer>> correspondenceToLines = new HashMap<>();
        int lineCount = 0;
        final boolean oneRowPerTextPair = this.multipleTextsToMultipleExamples;
        // The cqengine iterator is closeable and must be released as well as the writer.
        try (BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file, append), StandardCharsets.UTF_8));
             CloseableIterator<Correspondence> iterator = alignment.iterator()) {
            while (iterator.hasNext()) {
                Correspondence correspondence = iterator.next();
                correspondence.addAdditionalConfidence(getClass(), 0.0d);
                if (oneRowPerTextPair) {
                    List<String> sourceTexts = extractNonBlankTexts(source.getResource(correspondence.getEntityOne()));
                    if (sourceTexts.isEmpty()) {
                        continue;
                    }
                    // Extract the target texts once per correspondence instead of once per source text.
                    List<String> targetTexts = extractNonBlankTexts(target.getResource(correspondence.getEntityTwo()));
                    for (String sourceText : sourceTexts) {
                        for (String targetText : targetTexts) {
                            writeExample(writer, sourceText, targetText);
                            correspondenceToLines.computeIfAbsent(correspondence, key -> new ArrayList<>()).add(lineCount);
                            lineCount++;
                        }
                    }
                } else {
                    String sourceText = getTextFromResource(source.getResource(correspondence.getEntityOne()));
                    String targetText = getTextFromResource(target.getResource(correspondence.getEntityTwo()));
                    if (!StringUtils.isBlank(sourceText) && !StringUtils.isBlank(targetText)) {
                        writeExample(writer, sourceText, targetText);
                        correspondenceToLines.computeIfAbsent(correspondence, key -> new ArrayList<>()).add(lineCount);
                        lineCount++;
                    }
                }
            }
        }
        LOGGER.info("Wrote {} examples to prediction file {}", lineCount, file);
        return correspondenceToLines;
    }

    /** Writes a single CSV row consisting of the two escaped texts. */
    private static void writeExample(BufferedWriter writer, String left, String right) throws IOException {
        writer.write(StringEscapeUtils.escapeCsv(left) + "," + StringEscapeUtils.escapeCsv(right) + NEWLINE);
    }

    /** Returns the extracted texts of {@code resource} that are not blank, in extraction order. */
    private List<String> extractNonBlankTexts(Resource resource) {
        List<String> texts = new ArrayList<>();
        for (String text : this.extractor.extract(resource)) {
            if (!StringUtils.isBlank(text)) {
                texts.add(text);
            }
        }
        return texts;
    }

    /** Concatenates all extracted (trimmed) texts of {@code resource}, separated by a single space. */
    private String getTextFromResource(Resource resource) {
        StringBuilder sb = new StringBuilder();
        for (String text : this.extractor.extract(resource)) {
            sb.append(text.trim()).append(" ");
        }
        return sb.toString().trim();
    }

    /**
     * Runs the prediction for the given CSV file on the Python server.
     *
     * <p>If {@link #isOptimizeBatchSize()} is set, it first searches for the largest usable eval batch size and
     * stores it in the training arguments ({@code per_device_eval_batch_size}).
     *
     * @param file CSV file as written by {@link #createPredictionFile}
     * @return one confidence per row of {@code file}, as returned by the Python server
     * @throws Exception if the prediction fails
     */
    public List<Double> predictConfidences(File file) throws Exception {
        if (this.optimizeBatchSize) {
            this.trainingArguments.addParameter("per_device_eval_batch_size", Integer.valueOf(getMaximumPerDeviceEvalBatchSize(file)));
        }
        return PythonServer.getInstance().transformersPrediction(this, file);
    }

    /**
     * Searches for the largest eval batch size (a power of two between {@value #MIN_PROBE_BATCH_SIZE} and
     * {@value #MAX_PROBE_BATCH_SIZE}). Each step predicts a sample of {@code batchSize} lines of the file and
     * doubles the size until the server reports a memory error.
     *
     * <p>{@code this.trainingArguments} is always restored before returning.
     *
     * @return the largest verified batch size. If the file has too few lines or memory runs out, this is
     *     half the failing size. On any other error it is {@value #DEFAULT_EVAL_BATCH_SIZE}.
     */
    private int getMaximumPerDeviceEvalBatchSize(File predictionFile) {
        final TransformersTrainerArguments originalArguments = this.trainingArguments;
        try {
            for (int batchSize = MIN_PROBE_BATCH_SIZE; batchSize <= MAX_PROBE_BATCH_SIZE; batchSize *= 2) {
                LOGGER.info("Try out batch size of {}", batchSize);
                File probeFile = FileUtil.createFileWithRandomNumber(this.tmpDir, "alignment_transformers_predict_find_max_batch_size", ".txt");
                try {
                    if (!copyCSVLines(predictionFile, probeFile, batchSize)) {
                        // Larger sizes cannot be probed meaningfully; stop with the last verified size, as logged.
                        LOGGER.info("File contains too few lines to further increase batch size. Thus use now {}", batchSize / 2);
                        return batchSize / 2;
                    }
                    this.trainingArguments = new TransformersTrainerArguments(originalArguments);
                    this.trainingArguments.addParameter("per_device_eval_batch_size", Integer.valueOf(batchSize));
                    PythonServer.getInstance().transformersPrediction(this, probeFile);
                } catch (PythonServerException e) {
                    if (isMemoryError(e)) {
                        LOGGER.info("Found memory error, thus returning batchsize of {}", batchSize / 2);
                        return batchSize / 2;
                    }
                    LOGGER.warn("Something went wrong in python server during getMaximumPerDeviceEvalBatchSize. Return default of 8", e);
                    return DEFAULT_EVAL_BATCH_SIZE;
                } catch (IOException e) {
                    LOGGER.warn("Something went wrong with io during getMaximumPerDeviceEvalBatchSize. Return default of 8", e);
                    return DEFAULT_EVAL_BATCH_SIZE;
                } catch (Exception e) {
                    LOGGER.warn("Something went wrong during getMaximumPerDeviceEvalBatchSize. Return default of 8", e);
                    return DEFAULT_EVAL_BATCH_SIZE;
                } finally {
                    probeFile.delete();
                }
            }
            LOGGER.info("It looks like that batch sizes up to 8192 works out which is unusual. If greater batch sizes are possible the code to search max batch size needs to be changed.");
            // Return the largest size that was actually tested (previously the untested doubled value leaked out).
            return MAX_PROBE_BATCH_SIZE;
        } finally {
            this.trainingArguments = originalArguments;
        }
    }

    /** Returns true if the exception message indicates that the server ran out of (GPU) memory; null-safe. */
    private static boolean isMemoryError(Exception e) {
        String message = e.getMessage();
        return message != null && (message.contains("not enough memory") || message.contains("out of memory"));
    }

    public boolean isChangeClass() {
        return this.changeClass;
    }

    public void setChangeClass(boolean changeClass) {
        this.changeClass = changeClass;
    }

    /** Returns whether the eval batch size is optimized before prediction. */
    public boolean isOptimizeBatchSize() {
        return this.optimizeBatchSize;
    }

    /** Enables or disables searching for the largest eval batch size before prediction. */
    public void setOptimizeBatchSize(boolean optimizeBatchSize) {
        this.optimizeBatchSize = optimizeBatchSize;
    }

    /** Sets both the batch-size optimization and the mixed-precision optimization (from the base class). */
    public void setOptimizeAll(boolean optimizeAll) {
        setOptimizeBatchSize(optimizeAll);
        setOptimizeForMixedPrecisionTraining(optimizeAll);
    }

    /** Returns true only if both the batch-size and the mixed-precision optimizations are enabled. */
    public boolean isOptimizeAll() {
        return isOptimizeBatchSize() && isOptimizeForMixedPrecisionTraining();
    }
}
