package edu.emory.clir.clearnlp.vector;

import edu.emory.clir.clearnlp.collection.map.ObjectIntHashMap;
import edu.emory.clir.clearnlp.collection.pair.ObjectIntPair;
import edu.emory.clir.clearnlp.util.DSUtils;
import edu.emory.clir.clearnlp.util.Joiner;
import edu.emory.clir.clearnlp.util.MathUtils;
import edu.emory.clir.clearnlp.util.constant.StringConst;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.stream.Collectors;

/* loaded from: input_file:edu/emory/clir/clearnlp/vector/VectorSpaceModel.class */
/**
 * Vector space model over n-gram terms.
 *
 * <p>Maintains three pieces of state:
 * <ul>
 *   <li>a vocabulary that maps each term to a dense integer ID, assigned in first-seen order;</li>
 *   <li>per-term document frequencies;</li>
 *   <li>a stop-word set that is removed from documents before n-grams are generated.</li>
 * </ul>
 * It also provides TF-IDF / WF-IDF weighting, Euclidean distance and cosine similarity between
 * sparse term vectors.
 *
 * <p>Sparse vectors are {@code List<Term>} ordered by ascending term ID, which is the order the
 * distance/similarity methods rely on. Instances are mutable and perform no synchronization.
 */
public class VectorSpaceModel implements Serializable {
    private static final long serialVersionUID = 4172483442205081702L;

    // NOTE: field names are part of the default serialized form; renaming them breaks loading of
    // previously serialized models.

    /** Number of documents whose terms were counted into the document frequencies. */
    private int DOCUMENT_SIZE;
    /** Term -> (ID + 1); the +1 shift lets {@link #getID(String)} report an unknown term as -1. */
    private final ObjectIntHashMap<String> term_to_id = new ObjectIntHashMap<>();
    /** ID -> (term, document frequency); the list index is the term ID. */
    private final List<ObjectIntPair<String>> id_to_term = new ArrayList<>();
    /** Tokens removed before n-grams are generated. */
    private final Set<String> stop_words = new HashSet<>();

    /** Adds every word in {@code set} to the stop-word list. */
    public void addStopWords(Set<String> set) {
        stop_words.addAll(set);
    }

    /** Adds a single word to the stop-word list. */
    public void addStopWord(String str) {
        stop_words.add(str);
    }

    /**
     * Converts a tokenized document into a bag of n-gram terms.
     *
     * <p>Any n-gram not yet in the vocabulary is registered with the next free ID.
     *
     * @param tokens the tokens of one document; stop words are removed first.
     * @param ngram the maximum n-gram length to generate.
     * @param updateDF if {@code true}, increments the document frequency of every distinct n-gram
     *     of this document.
     * @return the document's terms with their term frequencies, sorted by their natural order.
     */
    public List<Term> toBagOfWords(List<String> tokens, int ngram, boolean updateDF) {
        ObjectIntHashMap<String> bag = getBagOfWords(tokens, stop_words, ngram);
        List<Term> terms = new ArrayList<>(bag.size());
        Iterator<ObjectIntPair<String>> it = bag.iterator();

        while (it.hasNext()) {
            ObjectIntPair<String> entry = it.next();
            int id = getID(entry.o);

            if (id < 0) {
                // Unseen term: its ID must equal its index in id_to_term (see getTerm/getDocumentFrequency).
                id = id_to_term.size();
                term_to_id.put(entry.o, id + 1);
                id_to_term.add(new ObjectIntPair<>(entry.o, 0));
            }

            terms.add(new Term(id, entry.i));
            if (updateDF) {
                id_to_term.get(id).i++;
            }
        }

        // The distance/similarity methods merge-join on ascending ID; assumes Term's natural order
        // is by ID — TODO confirm against Term.compareTo.
        Collections.sort(terms);
        return terms;
    }

    /**
     * Builds scored bag-of-words vectors for a collection of documents.
     *
     * <p>The steps are:
     * <ol>
     *   <li>Convert every document with {@link #toBagOfWords}, registering new terms and updating
     *       document frequencies.</li>
     *   <li>Add the number of documents to the document count.</li>
     *   <li>Score every term with {@code scorer}.</li>
     * </ol>
     *
     * @param documents the tokenized documents.
     * @param ngram the maximum n-gram length to generate.
     * @param scorer receives a term (with its document frequency set) and the total document count,
     *     and returns the term's score (e.g. {@code VectorSpaceModel::getTFIDF}).
     * @return one sorted term vector per input document, in input order.
     */
    public List<List<Term>> toTFIDFs(List<List<String>> documents, int ngram, BiFunction<Term, Integer, Double> scorer) {
        // Build every bag first so the document frequencies cover the whole collection before scoring.
        List<List<Term>> bags = new ArrayList<>(documents.size());
        for (List<String> document : documents) {
            bags.add(toBagOfWords(document, ngram, true));
        }

        // Accumulate rather than overwrite: document frequencies also accumulate across calls until
        // resetDocumentFrequency(), which clears both counters together. Overwriting let df exceed N.
        DOCUMENT_SIZE += documents.size();

        for (List<Term> bag : bags) {
            for (Term term : bag) {
                term.setDocumentFrequency(getDocumentFrequency(term.getID()));
                term.setScore(scorer.apply(term, DOCUMENT_SIZE).doubleValue());
            }
        }
        return bags;
    }

    /**
     * Scores a document against the current vocabulary without modifying the model.
     *
     * <p>N-grams that are not in the vocabulary are dropped.
     *
     * @param tokens the tokens of one document; stop words are removed first.
     * @param ngram the maximum n-gram length to generate.
     * @param scorer receives a term and the current document count, and returns the term's score.
     * @return the in-vocabulary terms with their scores, sorted like {@link #toBagOfWords}, so the
     *     result can be passed to {@link #getCosineSimilarity} / {@link #getEuclideanDistance}.
     */
    public List<Term> getTFIDFs(List<String> tokens, int ngram, BiFunction<Term, Integer, Double> scorer) {
        ObjectIntHashMap<String> bag = getBagOfWords(tokens, stop_words, ngram);
        List<Term> terms = new ArrayList<>(bag.size());
        Iterator<ObjectIntPair<String>> it = bag.iterator();

        while (it.hasNext()) {
            ObjectIntPair<String> entry = it.next();
            int id = getID(entry.o);
            if (id < 0) {
                continue;  // out-of-vocabulary: the vocabulary is not extended here
            }

            Term term = new Term(id, entry.i, getDocumentFrequency(id));
            term.setScore(scorer.apply(term, DOCUMENT_SIZE).doubleValue());
            terms.add(term);
        }

        // Previously unsorted, which broke the ID merge-join in the similarity/distance methods.
        Collections.sort(terms);
        return terms;
    }

    /** @return the term registered under {@code id}, or {@code null} if {@code id} is out of range. */
    public String getTerm(int id) {
        return DSUtils.isRange(id_to_term, id) ? id_to_term.get(id).o : null;
    }

    /**
     * Looks up the ID of a term.
     *
     * <p>IDs are stored shifted by +1. Presumably {@code ObjectIntHashMap.get} returns 0 for absent
     * keys, so an unknown term yields -1 — verify against ObjectIntHashMap.
     *
     * @return the ID of {@code term}, or a negative value if it is not in the vocabulary.
     */
    public int getID(String term) {
        return term_to_id.get(term) - 1;
    }

    /** @return the number of terms in the vocabulary. */
    public int getTermSize() {
        return id_to_term.size();
    }

    /** @return the document frequency of the term with {@code id}, or 0 if {@code id} is out of range. */
    public int getDocumentFrequency(int id) {
        return DSUtils.isRange(id_to_term, id) ? id_to_term.get(id).i : 0;
    }

    /** Sets every document frequency and the document count back to 0; the vocabulary is kept. */
    public void resetDocumentFrequency() {
        for (ObjectIntPair<String> entry : id_to_term) {
            entry.i = 0;
        }
        DOCUMENT_SIZE = 0;
    }

    /**
     * Computes {@code tf * log(documentSize / df)}.
     *
     * <p>How {@code MathUtils.divide} handles {@code df == 0} is not visible here.
     *
     * @param tf the term frequency (or any weighted variant of it).
     * @param df the term's document frequency.
     * @param documentSize the total number of documents.
     * @return the TF-IDF score.
     */
    public static double getTFIDF(double tf, int df, int documentSize) {
        return Math.log(MathUtils.divide(documentSize, df)) * tf;
    }

    /** @return the TF-IDF score of {@code term} given {@code documentSize} documents. */
    public static double getTFIDF(Term term, int documentSize) {
        return getTFIDF(term.getTermFrequency(), term.getDocumentFrequency(), documentSize);
    }

    /**
     * Computes the sublinear TF-IDF score of {@code term}.
     *
     * <p>The TF component is {@code 1 + log(tf)}, or 0 when {@code tf <= 0}.
     *
     * @return the WF-IDF score of {@code term} given {@code documentSize} documents.
     */
    public static double getWFIDF(Term term, int documentSize) {
        int tf = term.getTermFrequency();
        double wf = tf > 0 ? 1.0d + Math.log(tf) : 0.0d;
        return getTFIDF(wf, term.getDocumentFrequency(), documentSize);
    }

    /**
     * Computes the Euclidean distance between two sparse vectors.
     *
     * <p>Both vectors must be sorted by ascending term ID with unique IDs. A term missing from one
     * vector counts as score 0 there.
     *
     * @return the Euclidean distance over the term scores.
     */
    public static double getEuclideanDistance(List<Term> list, List<Term> list2) {
        int i = 0, j = 0;
        int size1 = list.size(), size2 = list2.size();
        double sum = 0.0d;

        while (i < size1 && j < size2) {
            Term t1 = list.get(i);
            Term t2 = list2.get(j);

            if (t1.getID() < t2.getID()) {
                sum += MathUtils.sq(t1.getScore());
                i++;
            } else if (t1.getID() > t2.getID()) {
                sum += MathUtils.sq(t2.getScore());
                j++;
            } else {
                sum += MathUtils.sq(t1.getScore() - t2.getScore());
                i++;
                j++;
            }
        }

        // Remaining terms exist in only one vector.
        for (; i < size1; i++) {
            sum += MathUtils.sq(list.get(i).getScore());
        }
        for (; j < size2; j++) {
            sum += MathUtils.sq(list2.get(j).getScore());
        }
        return Math.sqrt(sum);
    }

    /**
     * Computes the cosine similarity between two sparse vectors.
     *
     * <p>Both vectors must be sorted by ascending term ID with unique IDs.
     *
     * @return the cosine of the angle between the score vectors, or 0 if either vector has zero
     *     norm (including an empty list) instead of NaN.
     */
    public static double getCosineSimilarity(List<Term> list, List<Term> list2) {
        // Norms are computed independently: the previous version added both current scores on every
        // merge step, re-counting a term each time only the other index advanced.
        double norm1 = getSquaredNorm(list);
        double norm2 = getSquaredNorm(list2);
        if (norm1 == 0.0d || norm2 == 0.0d) {
            return 0.0d;
        }

        int i = 0, j = 0;
        int size1 = list.size(), size2 = list2.size();
        double dot = 0.0d;

        // Merge-join on ID: only terms present in both vectors contribute to the dot product.
        while (i < size1 && j < size2) {
            Term t1 = list.get(i);
            Term t2 = list2.get(j);

            if (t1.getID() < t2.getID()) {
                i++;
            } else if (t1.getID() > t2.getID()) {
                j++;
            } else {
                dot += t1.getScore() * t2.getScore();
                i++;
                j++;
            }
        }
        return dot / (Math.sqrt(norm1) * Math.sqrt(norm2));
    }

    /** @return the sum of squared scores of {@code terms}. */
    private static double getSquaredNorm(List<Term> terms) {
        double sum = 0.0d;
        for (Term term : terms) {
            sum += MathUtils.sq(term.getScore());
        }
        return sum;
    }

    /**
     * Counts the n-grams of a document after stop-word removal.
     *
     * <p>For each position {@code end}, the n-grams of length 1..{@code ngram} ending at {@code end}
     * are added (fewer near the start of the document). Each n-gram is joined with
     * {@code StringConst.UNDERSCORE}. This assumes {@code Joiner.join(list, delim, begin, end)} joins
     * {@code [begin, end)} and {@code DSUtils.removeAll} returns the tokens not in {@code stopWords}
     * — confirm.
     *
     * @param tokens the document's tokens.
     * @param stopWords tokens to remove before n-grams are formed.
     * @param ngram the maximum n-gram length.
     * @return n-gram -> count within this document.
     */
    public static ObjectIntHashMap<String> getBagOfWords(List<String> tokens, Set<String> stopWords, int ngram) {
        ObjectIntHashMap<String> bag = new ObjectIntHashMap<>();
        List<String> filtered = DSUtils.removeAll(tokens, stopWords);
        int size = filtered.size();

        for (int end = 0; end < size; end++) {
            for (int begin = end, length = 0; length < ngram && begin >= 0; begin--, length++) {
                bag.add(Joiner.join(filtered, StringConst.UNDERSCORE, begin, end + 1));
            }
        }
        return bag;
    }

    /**
     * Collects stop-word candidates from a document collection.
     *
     * <p>Counts, for each token, the number of documents containing it.
     *
     * <p>NOTE(review): {@code cutoff} is currently ignored — every token that appears in any document
     * is returned, so the descending sort by frequency has no effect on the result. The sort suggests
     * a top-k or minimum-document-frequency cutoff was intended; confirm the intended semantics before
     * applying one, since doing so changes the returned set.
     *
     * @param documents the tokenized documents.
     * @param cutoff currently unused.
     * @return the set of all distinct tokens in {@code documents}.
     */
    public static Set<String> generateStopWords(List<List<String>> documents, int cutoff) {
        ObjectIntHashMap<String> documentFrequencies = new ObjectIntHashMap<>();
        for (List<String> document : documents) {
            // Each token is counted at most once per document.
            for (String token : new HashSet<>(document)) {
                documentFrequencies.add(token);
            }
        }

        List<ObjectIntPair<String>> pairs = documentFrequencies.toList();
        Collections.sort(pairs, Collections.reverseOrder());

        Set<String> stopWords = new HashSet<>();
        for (ObjectIntPair<String> pair : pairs) {
            stopWords.add(pair.o);
        }
        return stopWords;
    }
}
