package de.uni_mannheim.informatik.dws.melt.matching_ml.python.nlptransformers;

import de.uni_mannheim.informatik.dws.melt.matching_base.FileUtil;
import de.uni_mannheim.informatik.dws.melt.matching_base.Filter;
import de.uni_mannheim.informatik.dws.melt.matching_jena.TextExtractor;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServer;
import de.uni_mannheim.informatik.dws.melt.matching_ml.python.PythonServerException;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import org.apache.jena.ext.xerces.impl.xs.SchemaSymbols;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:de/uni_mannheim/informatik/dws/melt/matching_ml/python/nlptransformers/TransformersFineTunerHpSearch.class */
/**
 * A {@link TransformersFineTuner} that runs a hyperparameter search instead of a single training run.
 *
 * <p>The search itself is delegated to the python server via
 * {@code PythonServer.getInstance().transformersFineTuningHpSearch(this, trainFile)}. This class holds the
 * search configuration: number of trials, test split size, optimizing metric, search space and mutations.
 * It can optionally narrow the {@code per_device_train_batch_size} choices to values that fit into the
 * memory of a single GPU (see {@link #setAdjustMaxBatchSize(boolean)}).
 */
public class TransformersFineTunerHpSearch extends TransformersFineTuner implements Filter {
    private static final Logger LOGGER = LoggerFactory.getLogger(TransformersFineTunerHpSearch.class);

    /** First batch size tried by {@link #getMaximumPerDeviceTrainBatchSize(File)}. */
    private static final int MIN_PROBED_BATCH_SIZE = 4;

    /** Largest batch size tried by {@link #getMaximumPerDeviceTrainBatchSize(File)}. */
    private static final int MAX_PROBED_BATCH_SIZE = 8192;

    /** Batch size returned when probing fails for a reason other than a memory error. */
    private static final int FALLBACK_BATCH_SIZE = 8;

    private int numberOfTrials;
    private float testSize;
    private TransformersOptimizingMetric optimizingMetric;
    private TransformersHpSearchSpace hpSpace;
    private TransformersHpSearchSpace hpMutations;
    private boolean adjustMaxBatchSize;

    /**
     * Creates a hyperparameter-search fine-tuner with the following defaults:
     * <ul>
     *   <li>10 trials</li>
     *   <li>a test size of 0.33</li>
     *   <li>{@code AUC} as the optimizing metric</li>
     *   <li>the default search space and mutations of {@link TransformersHpSearchSpace}</li>
     *   <li>batch-size adjustment disabled</li>
     * </ul>
     *
     * @param textExtractor text extractor, passed unchanged to the parent fine-tuner
     * @param initialModelName passed unchanged to the parent fine-tuner; presumably the name of the model to
     *     start fine-tuning from (confirm in {@link TransformersFineTuner})
     * @param resultingModelLocation passed unchanged to the parent fine-tuner; presumably where the fine-tuned
     *     model is stored (confirm in {@link TransformersFineTuner})
     */
    public TransformersFineTunerHpSearch(TextExtractor textExtractor, String initialModelName, File resultingModelLocation) {
        super(textExtractor, initialModelName, resultingModelLocation);
        this.numberOfTrials = 10;
        this.testSize = 0.33f;
        this.optimizingMetric = TransformersOptimizingMetric.AUC;
        this.hpSpace = TransformersHpSearchSpace.getDefaultHpSpace();
        this.hpMutations = TransformersHpSearchSpace.getDefaultHpSpaceMutations();
    }

    /**
     * Runs the hyperparameter search on the given training file.
     *
     * <p>If {@link #isAdjustMaxBatchSize()} is enabled, the largest per-device batch size that fits into
     * memory is probed first. The {@code per_device_train_batch_size} choice of both the search space and the
     * mutations is then set to powers of two up to that maximum.
     *
     * @param trainFile the training data to run the search on
     * @return the {@code resultingModelLocation} of this fine-tuner
     * @throws Exception if the python server call fails
     */
    @Override
    public File finetuneModel(File trainFile) throws Exception {
        if (this.adjustMaxBatchSize) {
            int maxBatchSize = getMaximumPerDeviceTrainBatchSize(trainFile);
            List<Integer> batchSizes = getBatchSizeCandidates(maxBatchSize);
            if (batchSizes.isEmpty()) {
                // An empty choice would leave the search without any valid batch size.
                LOGGER.warn("Maximum per device train batch size is {} which is too small. Keep the configured hyper parameter search space for \"per_device_train_batch_size\".", maxBatchSize);
            } else {
                LOGGER.info("Set the hyper parameter search space for \"per_device_train_batch_size\" to: {}", batchSizes);
                this.hpSpace.choice("per_device_train_batch_size", batchSizes);
                this.hpMutations.choice("per_device_train_batch_size", batchSizes);
            }
        }
        PythonServer.getInstance().transformersFineTuningHpSearch(this, trainFile);
        return this.resultingModelLocation;
    }

    /**
     * Lists the powers of two that are at most {@code maxBatchSize}.
     *
     * <p>The list starts at 4 when the maximum is at least 8, at 2 when it is at least 4, and at 1 otherwise.
     * It is empty when {@code maxBatchSize < 1}.
     */
    private static List<Integer> getBatchSizeCandidates(int maxBatchSize) {
        long start;
        if (maxBatchSize >= 8) {
            start = 4;
        } else if (maxBatchSize >= 4) {
            start = 2;
        } else {
            start = 1;
        }
        List<Integer> candidates = new ArrayList<>();
        // long prevents the doubling from overflowing into an endless loop for very large maxima.
        for (long size = start; size <= maxBatchSize; size *= 2) {
            candidates.add((int) size);
        }
        return candidates;
    }

    /**
     * Determines the maximum per-device train batch size using the training file generated by a previous
     * match call.
     *
     * @return see {@link #getMaximumPerDeviceTrainBatchSize(File)}
     * @throws IllegalArgumentException if no (non-empty) training file is available
     */
    public int getMaximumPerDeviceTrainBatchSize() {
        if (this.trainingFile == null || !this.trainingFile.exists() || this.trainingFile.length() == 0) {
            throw new IllegalArgumentException("Cannot get maximum per device train batch size because no training file is generated. Did you call the match method before (e.g. in a pipeline)?");
        }
        return getMaximumPerDeviceTrainBatchSize(this.trainingFile);
    }

    /**
     * Determines the largest per-device train batch size that trains on a single GPU without a memory error.
     *
     * <p>Batch sizes 4, 8, 16, ... 8192 are tried in turn. Each try runs one training step ({@code max_steps=1},
     * {@code save_at_end=false}) on a temporary file filled by {@code copyCSVLines} with that many lines of
     * {@code trainFile}. The search stops at the first of these events:
     * <ul>
     *   <li>a python error whose message contains "not enough memory" or "out of memory" returns half the
     *       failing size;</li>
     *   <li>a training file with too few lines for the next size returns half that size;</li>
     *   <li>any other failure returns 8;</li>
     *   <li>if every size up to 8192 succeeds, 8192 is returned.</li>
     * </ul>
     * The {@code trainingArguments} and {@code cudaVisibleDevices} settings are restored before returning.
     *
     * @param trainFile the training data to sample lines from
     * @return the largest working batch size (a power of two between 2 and 8192), or 8 on unexpected failures
     */
    public int getMaximumPerDeviceTrainBatchSize(File trainFile) {
        TransformersTrainerArguments originalArguments = this.trainingArguments;
        String originalCudaVisibleDevices = this.cudaVisibleDevices;
        // Probe on one GPU only so that the measured limit is per device.
        this.cudaVisibleDevices = getCudaVisibleDevicesButOnlyOneGPU();
        try {
            for (int batchSize = MIN_PROBED_BATCH_SIZE; batchSize <= MAX_PROBED_BATCH_SIZE; batchSize *= 2) {
                LOGGER.info("Try out batch size of {}", batchSize);
                File probeFile = FileUtil.createFileWithRandomNumber(this.tmpDir, "alignment_transformers_find_max_batch_size", ".txt");
                try {
                    if (!copyCSVLines(trainFile, probeFile, batchSize)) {
                        // A partial batch would not test the memory limit, so the last full size is the answer.
                        LOGGER.info("File contains too few lines to further increase batch size. Thus use now {}", batchSize / 2);
                        return batchSize / 2;
                    }
                    TransformersTrainerArguments probeArguments = new TransformersTrainerArguments(originalArguments);
                    probeArguments.addParameter("per_device_train_batch_size", batchSize);
                    probeArguments.addParameter("save_at_end", false);
                    probeArguments.addParameter("max_steps", 1);
                    this.trainingArguments = probeArguments;
                    super.finetuneModel(probeFile);
                } catch (PythonServerException e) {
                    if (isMemoryError(e)) {
                        int lastWorkingBatchSize = batchSize / 2;
                        LOGGER.info("Found memory error, thus returning batchsize of {}", lastWorkingBatchSize);
                        return lastWorkingBatchSize;
                    }
                    LOGGER.warn("Something went wrong in python server during getMaximumPerDeviceTrainBatchSize. Return default of 8", e);
                    return FALLBACK_BATCH_SIZE;
                } catch (IOException e) {
                    LOGGER.warn("Something went wrong with io during getMaximumPerDeviceTrainBatchSize. Return default of 8", e);
                    return FALLBACK_BATCH_SIZE;
                } catch (Exception e) {
                    LOGGER.warn("Something went wrong during getMaximumPerDeviceTrainBatchSize. Return default of 8", e);
                    return FALLBACK_BATCH_SIZE;
                } finally {
                    probeFile.delete();
                }
            }
            LOGGER.info("It looks like that batch sizes up to 8192 works out which is unusual. If greater batch sizes are possible the code to search max batch size needs to be changed.");
            // Return the largest size that was actually verified, not the next untested one.
            return MAX_PROBED_BATCH_SIZE;
        } finally {
            this.trainingArguments = originalArguments;
            this.cudaVisibleDevices = originalCudaVisibleDevices;
        }
    }

    /** Returns true if the exception message reports a (GPU) memory shortage; null-safe. */
    private static boolean isMemoryError(Exception e) {
        String message = e.getMessage();
        return message != null && (message.contains("not enough memory") || message.contains("out of memory"));
    }

    /**
     * Returns only the first device of {@link #getCudaVisibleDevices()}.
     *
     * @return the first comma-separated, trimmed entry, or {@code "0"} if no device (or an empty first entry)
     *     is configured
     */
    public String getCudaVisibleDevicesButOnlyOneGPU() {
        String devices = getCudaVisibleDevices();
        if (devices == null) {
            return "0";
        }
        String firstDevice = devices.trim().split(",")[0].trim();
        return firstDevice.isEmpty() ? "0" : firstDevice;
    }

    /** @return the number of hyperparameter configurations to try */
    public int getNumberOfTrials() {
        return this.numberOfTrials;
    }

    /** @param numberOfTrials the number of hyperparameter configurations to try */
    public void setNumberOfTrials(int numberOfTrials) {
        this.numberOfTrials = numberOfTrials;
    }

    /** @return the fraction of the training data used as test split during the search */
    public float getTestSize() {
        return this.testSize;
    }

    /**
     * Sets the fraction of the training data used as test split during the search.
     *
     * @param testSize a value in [0, 1]
     * @throws IllegalArgumentException if {@code testSize} is outside [0, 1] or NaN
     */
    public void setTestSize(float testSize) {
        // Negated in-range check so that NaN (which fails every comparison) is rejected too.
        if (!(testSize >= 0.0f && testSize <= 1.0f)) {
            throw new IllegalArgumentException("Test size should be between zero and one");
        }
        this.testSize = testSize;
    }

    /** @return the metric the search optimizes */
    public TransformersOptimizingMetric getOptimizingMetric() {
        return this.optimizingMetric;
    }

    /** @param optimizingMetric the metric the search optimizes */
    public void setOptimizingMetric(TransformersOptimizingMetric optimizingMetric) {
        this.optimizingMetric = optimizingMetric;
    }

    /** @return the hyperparameter search space */
    public TransformersHpSearchSpace getHpSpace() {
        return this.hpSpace;
    }

    /**
     * @param hpSpace the hyperparameter search space
     * @throws IllegalArgumentException if {@code hpSpace} is null
     */
    public void setHpSpace(TransformersHpSearchSpace hpSpace) {
        if (hpSpace == null) {
            throw new IllegalArgumentException("HpSpace should not be null.");
        }
        this.hpSpace = hpSpace;
    }

    /** @return the hyperparameter mutations */
    public TransformersHpSearchSpace getHpMutations() {
        return this.hpMutations;
    }

    /**
     * @param hpMutations the hyperparameter mutations
     * @throws IllegalArgumentException if {@code hpMutations} is null
     */
    public void setHpMutations(TransformersHpSearchSpace hpMutations) {
        if (hpMutations == null) {
            throw new IllegalArgumentException("HpMutations should not be null.");
        }
        this.hpMutations = hpMutations;
    }

    /** @return whether {@link #finetuneModel(File)} first probes the maximum per-device batch size */
    public boolean isAdjustMaxBatchSize() {
        return this.adjustMaxBatchSize;
    }

    /**
     * @param adjustMaxBatchSize whether {@link #finetuneModel(File)} should first probe the maximum per-device
     *     batch size and restrict the {@code per_device_train_batch_size} choices accordingly
     */
    public void setAdjustMaxBatchSize(boolean adjustMaxBatchSize) {
        this.adjustMaxBatchSize = adjustMaxBatchSize;
    }
}
