package ai.idylnlp.nlp.recognizer;

import ai.idylnlp.model.exceptions.EntityFinderException;
import ai.idylnlp.model.exceptions.ModelLoaderException;
import ai.idylnlp.model.manifest.SecondGenModelManifest;
import ai.idylnlp.model.nlp.AbstractEntityRecognizer;
import ai.idylnlp.model.nlp.SentenceSanitizer;
import ai.idylnlp.model.nlp.ner.EntityExtractionRequest;
import ai.idylnlp.model.nlp.ner.EntityExtractionResponse;
import ai.idylnlp.model.nlp.ner.EntityRecognizer;
import ai.idylnlp.nlp.recognizer.configuration.DeepLearningEntityRecognizerConfiguration;
import ai.idylnlp.nlp.recognizer.deep.DeepLearningTokenNameFinder;
import ai.idylnlp.nlp.sentence.sanitizers.DefaultSentenceSanitizer;
import com.neovisionaries.i18n.LanguageCode;
import java.io.File;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;
import opennlp.tools.namefind.TokenNameFinder;
import org.apache.commons.collections4.CollectionUtils;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.deeplearning4j.models.embeddings.loader.WordVectorSerializer;
import org.deeplearning4j.models.embeddings.wordvectors.WordVectors;
import org.deeplearning4j.nn.multilayer.MultiLayerNetwork;
import org.deeplearning4j.util.ModelSerializer;

/* loaded from: input_file:ai/idylnlp/nlp/recognizer/DeepLearningEntityRecognizer.class */
/**
 * An {@link EntityRecognizer} backed by deep learning ({@link MultiLayerNetwork}) name-finder models.
 *
 * <p>At construction time every non-blacklisted model manifest in the configuration is restored
 * from the configured entity model directory together with its word vectors. A model that fails
 * to load is added to the configuration's blacklist and skipped. Restored models are held in
 * memory keyed by language and entity class. Each one is wrapped in a
 * {@link DeepLearningTokenNameFinder} the first time a request needs it, and that wrapper is
 * cached in {@code nameFinders}.
 */
public class DeepLearningEntityRecognizer extends AbstractEntityRecognizer<DeepLearningEntityRecognizerConfiguration> implements EntityRecognizer {

    private static final Logger LOGGER = LogManager.getLogger(DeepLearningEntityRecognizer.class);

    /** Restored networks and their word vectors, keyed by language and then by entity class. */
    private final Map<LanguageCode, Map<String, ImmutablePair<MultiLayerNetwork, WordVectors>>> loadedModels = new HashMap<>();

    /**
     * Creates the recognizer and eagerly loads every non-blacklisted model in the configuration.
     *
     * @param deepLearningEntityRecognizerConfiguration the recognizer configuration. Models that
     *     fail to load are added to its blacklisted model IDs.
     */
    public DeepLearningEntityRecognizer(DeepLearningEntityRecognizerConfiguration deepLearningEntityRecognizerConfiguration) {
        super(deepLearningEntityRecognizerConfiguration);
        final DeepLearningEntityRecognizerConfiguration configuration = deepLearningEntityRecognizerConfiguration;
        for (String entityClass : configuration.getEntityModels().keySet()) {
            Map<LanguageCode, Set<SecondGenModelManifest>> modelsByLanguage = configuration.getEntityModels().get(entityClass);
            if (modelsByLanguage == null) {
                continue;
            }
            for (Map.Entry<LanguageCode, Set<SecondGenModelManifest>> languageEntry : modelsByLanguage.entrySet()) {
                if (languageEntry.getValue() == null) {
                    continue;
                }
                for (SecondGenModelManifest manifest : languageEntry.getValue()) {
                    if (configuration.getBlacklistedModelIDs().contains(manifest.getModelId())) {
                        LOGGER.info("Model {} is blacklisted. Loading will not be attempted until restart.", manifest.getModelFileName());
                        continue;
                    }
                    loadModel(configuration, entityClass, languageEntry.getKey(), manifest);
                }
            }
        }
    }

    /**
     * Restores one network and its word vectors and registers them under
     * (language, entity class). On any failure the model is blacklisted instead.
     */
    private void loadModel(DeepLearningEntityRecognizerConfiguration configuration, String entityClass,
            LanguageCode languageCode, SecondGenModelManifest manifest) {
        try {
            File modelFile = new File(configuration.getEntityModelDirectory(), manifest.getModelFileName());
            LOGGER.info("Loading {} {} model from file: {}", languageCode.getAlpha3().toString(), entityClass, modelFile.getAbsolutePath());
            MultiLayerNetwork network = ModelSerializer.restoreMultiLayerNetwork(modelFile.getAbsolutePath());

            File vectorsFile = new File(configuration.getEntityModelDirectory(), manifest.getVectorsFileName());
            LOGGER.info("Loading vectors from file: {}", vectorsFile.getAbsolutePath());
            WordVectors vectors = WordVectorSerializer.loadStaticModel(vectorsFile);

            // Merge into the per-language map; replacing it would drop other entity classes
            // already loaded for the same language.
            loadedModels.computeIfAbsent(languageCode, k -> new HashMap<>()).put(entityClass, ImmutablePair.of(network, vectors));
        } catch (Exception e) {
            // Deliberate best-effort: one bad model must not prevent the recognizer from starting.
            LOGGER.error("Unable to load model: " + manifest.getModelFileName(), e);
            configuration.getBlacklistedModelIDs().add(manifest.getModelId());
            LOGGER.warn("Model {} is blacklisted. Loading will not be attempted until restart.", manifest.getModelFileName());
        }
    }

    /**
     * Extracts entities from the request text using every loaded model that matches the request's
     * type filter and language.
     *
     * @param entityExtractionRequest the request. Its type may be empty (all entity classes) or a
     *     comma-separated list of entity classes. Its language may be null (all languages).
     * @return the extracted entities, the elapsed time in milliseconds, and {@code true}
     * @throws IllegalArgumentException if the text is null or empty, or the confidence threshold
     *     is outside 0..100
     * @throws EntityFinderException declared for the entity-finding step
     * @throws ModelLoaderException declared for model loading
     */
    public EntityExtractionResponse extractEntities(EntityExtractionRequest entityExtractionRequest) throws EntityFinderException, ModelLoaderException {
        if (entityExtractionRequest.getText() == null || entityExtractionRequest.getText().length == 0) {
            throw new IllegalArgumentException("Input text cannot be empty.");
        }
        if (entityExtractionRequest.getConfidenceThreshold() < 0 || entityExtractionRequest.getConfidenceThreshold() > 100) {
            throw new IllegalArgumentException("Confidence threshold must be an integer between 0 and 100.");
        }
        SentenceSanitizer sanitizer = new DefaultSentenceSanitizer.Builder().lowerCase().removePunctuation().consolidateSpaces().build();
        LinkedHashSet entities = new LinkedHashSet();
        long startTime = System.currentTimeMillis();

        // An empty filter means "all entity classes". Strip whitespace so "person, place" matches.
        String[] requestedTypes = StringUtils.isEmpty(entityExtractionRequest.getType())
                ? new String[0]
                : StringUtils.stripAll(entityExtractionRequest.getType().split(","));
        LanguageCode language = entityExtractionRequest.getLanguage();

        for (String entityClass : getConfiguration().getEntityModels().keySet()) {
            if (requestedTypes.length > 0 && !ArrayUtils.contains(requestedTypes, entityClass)) {
                continue;
            }
            LOGGER.trace("Processing entity class {}.", entityClass);
            Set<SecondGenModelManifest> manifests = collectManifests(entityClass, language);
            if (CollectionUtils.isEmpty(manifests)) {
                LOGGER.warn("No entity models available for language {}.", language == null ? "(any)" : language.getAlpha3().toString());
                continue;
            }
            LOGGER.debug("{} has {} entity models.", entityClass, manifests.size());
            for (SecondGenModelManifest manifest : manifests) {
                TokenNameFinder nameFinder = getNameFinder(entityClass, manifest);
                if (nameFinder != null) {
                    entities.addAll(findEntities(nameFinder, entityExtractionRequest, manifest, sanitizer));
                }
            }
        }
        return new EntityExtractionResponse(entities, System.currentTimeMillis() - startTime, true);
    }

    /**
     * Collects the configured manifests for an entity class, limited to one language, or across
     * all languages when {@code language} is null. Never returns null.
     */
    private Set<SecondGenModelManifest> collectManifests(String entityClass, LanguageCode language) {
        Set<SecondGenModelManifest> manifests = new HashSet<>();
        Map<LanguageCode, Set<SecondGenModelManifest>> modelsByLanguage = getConfiguration().getEntityModels().get(entityClass);
        if (modelsByLanguage == null) {
            return manifests;
        }
        if (language == null) {
            for (Set<SecondGenModelManifest> languageManifests : modelsByLanguage.values()) {
                if (languageManifests != null) {
                    manifests.addAll(languageManifests);
                }
            }
        } else {
            Set<SecondGenModelManifest> languageManifests = modelsByLanguage.get(language);
            if (languageManifests != null) {
                manifests.addAll(languageManifests);
            }
        }
        return manifests;
    }

    /**
     * Returns the cached name finder for a manifest, creating it from the loaded model on first
     * use. Returns null when the model was never loaded, for example because it was blacklisted.
     */
    private TokenNameFinder getNameFinder(String entityClass, SecondGenModelManifest manifest) {
        TokenNameFinder nameFinder = (TokenNameFinder) this.nameFinders.get(manifest);
        if (nameFinder != null) {
            return nameFinder;
        }
        LOGGER.info("Getting model for type {}, language {}", manifest.getType(), manifest.getLanguageCode().getAlpha3().toString());
        // Look up by entity class, which is the key the constructor stored the model under.
        Map<String, ImmutablePair<MultiLayerNetwork, WordVectors>> modelsForLanguage = loadedModels.get(manifest.getLanguageCode());
        ImmutablePair<MultiLayerNetwork, WordVectors> model = modelsForLanguage == null ? null : modelsForLanguage.get(entityClass);
        if (model == null) {
            LOGGER.warn("Model {} is not loaded (it may be blacklisted); skipping.", manifest.getModelFileName());
            return null;
        }
        // Labels come from the model's own type, not the request filter. The request filter may
        // be empty or list several types, and this finder is cached across requests.
        nameFinder = new DeepLearningTokenNameFinder(model.getLeft(), model.getRight(), manifest.getWindowSize(), getLabels(manifest.getType()));
        this.nameFinders.put(manifest, nameFinder);
        return nameFinder;
    }

    /** Builds the outcome labels for an entity type: {@code <type>-start}, {@code <type>-cont} and {@code other}. */
    private String[] getLabels(String str) {
        return new String[]{str + "-start", str + "-cont", "other"};
    }
}
