package net.maizegenetics.pangenome.hapCalling;

import com.google.common.collect.HashMultiset;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multiset;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NavigableMap;
import java.util.TreeMap;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.maizegenetics.analysis.imputation.BackwardForwardVariableStateNumber;
import net.maizegenetics.analysis.imputation.ViterbiAlgorithmVariableStateNumber;
import net.maizegenetics.dna.map.Chromosome;
import net.maizegenetics.pangenome.api.CreateGraphUtils;
import net.maizegenetics.pangenome.api.FilterGraphPlugin;
import net.maizegenetics.pangenome.api.HaplotypeEmissionProbability;
import net.maizegenetics.pangenome.api.HaplotypeGraph;
import net.maizegenetics.pangenome.api.HaplotypeNode;
import net.maizegenetics.pangenome.api.ReferenceRange;
import net.maizegenetics.pangenome.api.ReferenceRangeTransitionProbability;
import net.maizegenetics.taxa.TaxaList;
import org.apache.log4j.Logger;

/* loaded from: input_file:net/maizegenetics/pangenome/hapCalling/ConvertReadsToPathUsingHMM.class */
/**
 * Converts read mappings (counts of reads that hit sets of haplotype ids, per reference range) into a path through
 * a {@link HaplotypeGraph} using a hidden Markov model.
 *
 * <p>Typical use: configure the instance with the fluent setters (at least {@link #readMap}), call one of the
 * {@code filterHaplotypeGraph} overloads to build the filtered graph, then either
 * <ul>
 *   <li>{@link #haplotypeCountsToPath()} to get the single most probable (Viterbi) path, or</li>
 *   <li>{@link #haplotypeCountsToPathProbability()} (forward-backward) followed by
 *       {@link #nodeListFromProbabilities(double)} to pick nodes whose state probability passes a threshold.</li>
 * </ul>
 *
 * <p>Instances hold mutable state and are not synchronized.
 */
public class ConvertReadsToPathUsingHMM {
    private static final Logger myLogger = Logger.getLogger(ConvertReadsToPathUsingHMM.class);

    /** Graph produced by {@code filterHaplotypeGraph}; every path-finding method works on this graph. */
    private HaplotypeGraph myGraph;
    /** NOTE(review): never read or written by any method of this class. */
    private Multiset<Integer> myHapidCounts = null;
    /** Set by {@link #hapidCountMap(Map)}; NOTE(review): not read by any method of this class. */
    private Map<Integer, Integer> myHapidCountMap = null;
    /** Per-range state probabilities (gamma) from the last {@link #haplotypeCountsToPathProbability()} call. */
    private List<double[]> pathGammas = null;
    /** Read mappings per reference range; must be set via {@link #readMap} before filtering. */
    private Multimap<ReferenceRange, HapIdSetCount> myReadMap = null;
    /** Ranges with fewer total reads than this are removed from the graph. */
    private int minReadsPerRefRange = 0;
    /** Ranges with more reads per kilobase than this are removed from the graph. */
    private int maxReadsPerRefRangeKB = 10000;
    private String myTaxaListString = null;
    private TaxaList myTaxaList = null;
    private double minTransitionProb = 0.001d;
    private double probReadMappedCorrectly = 0.99d;
    /** Set by {@link #transitionProbabilitySameTaxon(double)}; NOTE(review): not read by any method of this class. */
    private double transitionProbSameTaxon = 0.99d;
    /** Taxon name flagged in the {@code hasTarget} column of the probability output file. */
    private String targetTaxon = null;
    private boolean removeRangesWithEqualCounts = true;

    /**
     * Filters {@code haplotypeGraph} on taxa (see {@link #filterOnTaxa}) and then removes reference ranges whose
     * read mappings fail the configured filters. The result is available from {@link #filteredGraph()}.
     *
     * @param haplotypeGraph the graph to filter
     * @return this instance, for chaining
     * @throws IllegalStateException if no read map has been set
     * @throws IllegalArgumentException if the read map contains no reads or the filtered graph has no nodes
     */
    public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph haplotypeGraph) {
        requireReadMap();
        int totalReads = this.myReadMap.values().stream().mapToInt(HapIdSetCount::getCount).sum();
        myLogger.info("Filtering graph based on read mappings for " + totalReads + " reads.");
        if (totalReads == 0) {
            throw new IllegalArgumentException("myReadMap has no reads.");
        }
        HaplotypeGraph taxaFiltered = filterOnTaxa(haplotypeGraph);
        applyRangeFilter(taxaFiltered, rangesToRemove(taxaFiltered, null));
        return this;
    }

    /**
     * Like {@link #filterHaplotypeGraph(HaplotypeGraph)}, but additionally removes every range whose id is not in
     * {@code list} (when {@code list} is non-null). It then adds missing-sequence nodes via
     * {@link CreateGraphUtils#addMissingSequenceNodes}. Unlike the one-argument overload, it does not reject a read
     * map with zero reads up front.
     *
     * @param haplotypeGraph the graph to filter
     * @param list reference range ids to keep (presumably {@link ReferenceRange#id()} values), or {@code null} to
     *             keep all ranges that pass the read filters
     * @return this instance, for chaining
     * @throws IllegalStateException if no read map has been set
     * @throws IllegalArgumentException if the filtered graph has no nodes
     */
    public ConvertReadsToPathUsingHMM filterHaplotypeGraph(HaplotypeGraph haplotypeGraph, List<Integer> list) {
        requireReadMap();
        HaplotypeGraph taxaFiltered = filterOnTaxa(haplotypeGraph);
        applyRangeFilter(taxaFiltered, rangesToRemove(taxaFiltered, list));
        this.myGraph = CreateGraphUtils.addMissingSequenceNodes(this.myGraph);
        return this;
    }

    /** Fails fast with a clear message instead of a NullPointerException when no read map was supplied. */
    private void requireReadMap() {
        if (this.myReadMap == null) {
            throw new IllegalStateException("readMap must be set before filtering the haplotype graph.");
        }
    }

    /**
     * Selects the reference ranges of {@code graph} to remove and logs a summary. A range is selected when:
     * its total read count is below {@code minReadsPerRefRange}; or its reads per kb exceed
     * {@code maxReadsPerRefRangeKB}; or {@code keepRangeIds} is non-null and lacks the range id; or (when
     * {@code removeRangesWithEqualCounts} is set) every haplotype in the range received the same read count.
     */
    private List<ReferenceRange> rangesToRemove(HaplotypeGraph graph, List<Integer> keepRangeIds) {
        List<ReferenceRange> toRemove = new ArrayList<>();
        int lowCountRanges = 0;
        int highDensityRanges = 0;
        int equalCountRanges = 0;
        for (ReferenceRange range : graph.referenceRanges()) {
            Collection<HapIdSetCount> readSets = this.myReadMap.get(range);
            int readCount = readSets == null ? 0 : readSets.stream().mapToInt(HapIdSetCount::getCount).sum();
            // Multiply in floating point first so short ranges are not truncated by integer division.
            double readsPerKb = (readCount * 1000.0d) / (range.end() - range.start() + 1);
            boolean tooFewReads = readCount < this.minReadsPerRefRange;
            boolean tooDense = readsPerKb > this.maxReadsPerRefRangeKB;
            if (tooFewReads) {
                lowCountRanges++;
            }
            if (tooDense) {
                highDensityRanges++;
            }
            if (tooFewReads || tooDense) {
                toRemove.add(range);
            } else if (keepRangeIds != null && !keepRangeIds.contains(range.id())) {
                toRemove.add(range);
            } else if (this.removeRangesWithEqualCounts && readCount > 0
                    && allHaplotypeCountsEqual(readSets, graph.nodes(range).size())) {
                // Reads that hit every haplotype equally carry no information for choosing a path.
                toRemove.add(range);
                equalCountRanges++;
            }
        }
        myLogger.info(String.format("total ranges = %d, number of ranges removed = %d",
                graph.numberOfRanges(), toRemove.size()));
        myLogger.info(String.format("number of ranges with low read counts = %d, high count per kb = %d, counts all equal = %d",
                lowCountRanges, highDensityRanges, equalCountRanges));
        return toRemove;
    }

    /**
     * Returns true when the reads in {@code readSets} hit at least {@code numberOfNodes} distinct haplotype ids and
     * every hit haplotype id accumulated the same total count. Returns false when no haplotype id was hit.
     */
    private static boolean allHaplotypeCountsEqual(Collection<HapIdSetCount> readSets, int numberOfNodes) {
        Multiset<Integer> hapidCounts = HashMultiset.create();
        for (HapIdSetCount readSet : readSets) {
            for (Integer hapid : readSet.getHapIdSet()) {
                hapidCounts.add(hapid, readSet.getCount());
            }
        }
        if (hapidCounts.isEmpty() || hapidCounts.elementSet().size() < numberOfNodes) {
            return false;
        }
        return hapidCounts.entrySet().stream().mapToInt(Multiset.Entry::getCount).distinct().count() == 1;
    }

    /**
     * Removes {@code rangesToRemove} from {@code graph} and stores the result in {@code myGraph}.
     *
     * @throws IllegalArgumentException if the filtered graph has no nodes (myGraph is still updated first)
     */
    private void applyRangeFilter(HaplotypeGraph graph, List<ReferenceRange> rangesToRemove) {
        FilterGraphPlugin filterGraphPlugin = new FilterGraphPlugin(null, false);
        filterGraphPlugin.refRanges(rangesToRemove);
        myLogger.debug(String.format("before filtering hapgraph: %d nodes.%n", graph.numberOfNodes()));
        this.myGraph = filterGraphPlugin.filter(graph);
        myLogger.debug(String.format("after filtering hapgraph: %d nodes.%n", this.myGraph.numberOfNodes()));
        if (this.myGraph.numberOfNodes() < 1) {
            myLogger.info("Method names for read mapping ids: " + ((String) ReadMappingUtils.getHaplotypeMethodsForReadMappings(this.myReadMap, 1000).stream().collect(Collectors.joining(","))));
            throw new IllegalArgumentException("The filtered graph has no nodes.");
        }
    }

    /**
     * Restricts {@code haplotypeGraph} to the configured taxa. A non-empty {@link TaxaList} takes precedence over
     * a taxa-list string. When neither is configured the input graph is returned unchanged.
     *
     * @param haplotypeGraph the graph to filter
     * @return the taxa-filtered graph, or {@code haplotypeGraph} itself if no taxa filter is configured
     */
    public HaplotypeGraph filterOnTaxa(HaplotypeGraph haplotypeGraph) {
        int nodesBefore = haplotypeGraph.numberOfNodes();
        int taxaBefore = haplotypeGraph.totalNumberTaxa();
        int rangesBefore = haplotypeGraph.numberOfRanges();
        HaplotypeGraph filtered;
        if (this.myTaxaList != null && this.myTaxaList.size() > 0) {
            filtered = new FilterGraphPlugin(null, false).taxaList(this.myTaxaList).filter(haplotypeGraph);
        } else if (this.myTaxaListString != null) {
            filtered = new FilterGraphPlugin(null, false).taxaList(this.myTaxaListString).filter(haplotypeGraph);
        } else {
            return haplotypeGraph;
        }
        myLogger.debug(String.format("Numbers before filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                nodesBefore, rangesBefore, taxaBefore));
        // Argument order matches the placeholders (nodes, ranges, taxa).
        myLogger.debug(String.format("Numbers after filtering on taxa: nodes = %d, ranges = %d, taxa = %d",
                filtered.numberOfNodes(), filtered.numberOfRanges(), filtered.totalNumberTaxa()));
        return filtered;
    }

    /** Prints the taxa of the filtered graph to standard output, one per line, after a header line. */
    public void listTaxa() {
        System.out.println("taxa in graph:");
        this.myGraph.taxaInGraph().forEach(System.out::println);
    }

    /**
     * Computes the most probable path through the filtered graph, one chromosome at a time, with the Viterbi
     * algorithm. It uses {@link HaplotypeEmissionProbability} on the read map and
     * {@link ReferenceRangeTransitionProbability} between consecutive ranges.
     *
     * @return the chosen node for every reference range, in chromosome then range order
     */
    public List<HaplotypeNode> haplotypeCountsToPath() {
        List<HaplotypeNode> path = new ArrayList<>();
        for (Chromosome chromosome : this.myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            NavigableMap<ReferenceRange, List<HaplotypeNode>> tree = this.myGraph.tree(chromosome);
            List<List<HaplotypeNode>> nodesPerRange = new ArrayList<>(tree.values());
            if (nodesPerRange.isEmpty()) {
                // Nothing to decode; avoids NoSuchElementException when sizing the start probabilities.
                continue;
            }
            HaplotypeEmissionProbability emission =
                    new HaplotypeEmissionProbability(tree, this.myReadMap, this.probReadMappedCorrectly);
            ReferenceRangeTransitionProbability transition =
                    new ReferenceRangeTransitionProbability(nodesPerRange, this.myGraph, this.minTransitionProb);
            int rangeCount = nodesPerRange.size();
            ViterbiAlgorithmVariableStateNumber viterbi = new ViterbiAlgorithmVariableStateNumber(
                    new byte[rangeCount], transition, emission, startProbabilities(nodesPerRange.get(0).size()));
            viterbi.initialize();
            viterbi.calculate();
            byte[] states = viterbi.getMostProbableStateSequence();
            for (int i = 0; i < rangeCount; i++) {
                // States are returned as bytes; mask so indices 128-255 are not sign-extended to negatives.
                path.add(nodesPerRange.get(i).get(states[i] & 0xFF));
            }
        }
        return path;
    }

    /**
     * Runs the forward-backward algorithm on each chromosome of the filtered graph. It stores the per-range state
     * probabilities (gamma), which {@link #nodeListFromProbabilities(double, String)} uses later.
     *
     * @return one array per reference range (chromosome then range order), with one entry per node of that range
     */
    public List<double[]> haplotypeCountsToPathProbability() {
        this.pathGammas = new ArrayList<>();
        for (Chromosome chromosome : this.myGraph.chromosomes()) {
            myLogger.info("Getting path for chromosome " + chromosome.getName());
            NavigableMap<ReferenceRange, List<HaplotypeNode>> tree = this.myGraph.tree(chromosome);
            myLogger.info("Extracted graph tree for chromosome " + chromosome.getName());
            List<List<HaplotypeNode>> nodesPerRange = new ArrayList<>(tree.values());
            if (nodesPerRange.isEmpty()) {
                // No ranges means no gammas; keeps pathGammas aligned with referenceRangeList().
                continue;
            }

            long startMillis = System.currentTimeMillis();
            HaplotypeEmissionProbability emission =
                    new HaplotypeEmissionProbability(tree, this.myReadMap, this.probReadMappedCorrectly);
            myLogger.info(String.format("emission probability set up in %d ms.", System.currentTimeMillis() - startMillis));
            myLogger.info(emission.toString());

            startMillis = System.currentTimeMillis();
            ReferenceRangeTransitionProbability transition =
                    new ReferenceRangeTransitionProbability(nodesPerRange, this.myGraph, this.minTransitionProb);
            myLogger.info(String.format("transition probability set up in %d ms.", System.currentTimeMillis() - startMillis));

            // The observation array is all zeros; presumably the emission model draws its evidence from the
            // read map passed to its constructor, so these values only size the sequence.
            int[] observations = new int[nodesPerRange.size()];
            double[] startProbs = startProbabilities(nodesPerRange.get(0).size());
            BackwardForwardVariableStateNumber forwardBackward = new BackwardForwardVariableStateNumber();
            forwardBackward.emission(emission)
                    .transition(transition)
                    .initialStateProbability(startProbs)
                    .observations(observations)
                    .calculateAlpha()
                    .calculateBeta();
            this.pathGammas.addAll(forwardBackward.gamma());
        }
        return this.pathGammas;
    }

    /**
     * Selects, for each reference range, the node with the highest state probability if that probability is at
     * least {@code d}. When {@code str} is non-null, it also writes every node's probability as a tab-delimited
     * file to that path. A failure to open the file is logged; the node list is still returned.
     *
     * @param d minimum probability for a range's best node to be included
     * @param str output file path, or {@code null} to skip writing
     * @return the selected nodes, in {@code referenceRangeList()} order
     * @throws IllegalStateException if {@link #haplotypeCountsToPathProbability()} has not been called
     */
    public List<HaplotypeNode> nodeListFromProbabilities(double d, String str) {
        if (this.pathGammas == null) {
            throw new IllegalStateException("haplotypeCountsToPathProbability() must be called before nodeListFromProbabilities().");
        }
        List<HaplotypeNode> selected = new ArrayList<>();
        Iterator<ReferenceRange> rangeIter = this.myGraph.referenceRangeList().iterator();
        for (double[] gammas : this.pathGammas) {
            ReferenceRange range = rangeIter.next();
            if (gammas.length == 0) {
                continue;
            }
            int best = indexOfMax(gammas);
            if (gammas[best] >= d) {
                selected.add(this.myGraph.nodes(range).get(best));
            }
        }
        if (str != null) {
            writeProbabilities(str);
        }
        return selected;
    }

    /** Index of the first maximum of a non-empty array. */
    private static int indexOfMax(double[] values) {
        int best = 0;
        for (int i = 1; i < values.length; i++) {
            if (values[i] > values[best]) {
                best = i;
            }
        }
        return best;
    }

    /**
     * Writes one line per node with columns chr, start, hasTarget, prob, taxa (each followed by a tab). Logs and
     * returns if the file cannot be opened.
     */
    private void writeProbabilities(String fileName) {
        try (PrintWriter writer = new PrintWriter(fileName)) {
            writer.println("chr\tstart\thasTarget\tprob\ttaxa");
            Iterator<ReferenceRange> rangeIter = this.myGraph.referenceRangeList().iterator();
            for (double[] gammas : this.pathGammas) {
                ReferenceRange range = rangeIter.next();
                int nodeIndex = 0;
                for (HaplotypeNode node : this.myGraph.nodes(range)) {
                    boolean hasTarget = node.taxaList().indexOf(this.targetTaxon) >= 0;
                    String taxaNames = node.taxaList().stream()
                            .map(taxon -> taxon.getName())
                            .collect(Collectors.joining(","));
                    writer.println(range.chromosome().getName() + "\t" + range.start() + "\t" + hasTarget + "\t"
                            + gammas[nodeIndex++] + "\t" + taxaNames + "\t");
                }
            }
        } catch (FileNotFoundException e) {
            myLogger.error(String.format("Unable to open %s for output in ConvertReadsToPathUsingHMM.nodeListFromProbabilities", fileName), e);
        }
    }

    /**
     * Same as {@link #nodeListFromProbabilities(double, String)} without writing an output file.
     *
     * @param d minimum probability for a range's best node to be included
     * @return the selected nodes
     */
    public List<HaplotypeNode> nodeListFromProbabilities(double d) {
        return nodeListFromProbabilities(d, null);
    }

    /**
     * Returns a uniform initial-state distribution.
     *
     * @param i number of states
     * @return an array of length {@code i} with every entry equal to {@code 1.0 / i}
     */
    public double[] startProbabilities(int i) {
        double[] probabilities = new double[i];
        Arrays.fill(probabilities, 1.0d / i);
        return probabilities;
    }

    /**
     * For each reference range of {@code haplotypeGraph}, returns the fraction of the range's read counts that
     * belong to its most-supported node.
     *
     * @param haplotypeGraph graph whose ranges and nodes are scored
     * @param multiset read counts keyed by haplotype node id
     * @return one value per range, in {@code referenceRanges()} order; 0 for ranges with no counts
     */
    public double[] probabilityOfBeingCorrect(HaplotypeGraph haplotypeGraph, Multiset<Integer> multiset) {
        return haplotypeGraph.referenceRanges().stream()
                .mapToDouble(range -> nodeCorrectProbability(haplotypeGraph.nodes(range).stream()
                        .mapToInt(node -> multiset.count(node.id()))
                        .toArray()))
                .toArray();
    }

    /**
     * Same as {@link #probabilityOfBeingCorrect(HaplotypeGraph, Multiset)} but over the node lists of
     * {@code treeMap}, in key order.
     *
     * @param multiset read counts keyed by haplotype node id
     * @param treeMap nodes per reference range
     * @return one value per entry of {@code treeMap}
     */
    public double[] probabilityOfBeingCorrect(Multiset<Integer> multiset, TreeMap<ReferenceRange, List<HaplotypeNode>> treeMap) {
        double[] result = new double[treeMap.size()];
        int index = 0;
        for (List<HaplotypeNode> nodes : treeMap.values()) {
            result[index++] = nodeCorrectProbability(nodes.stream().mapToInt(node -> multiset.count(node.id())).toArray());
        }
        return result;
    }

    /**
     * Same as {@link #probabilityOfBeingCorrect(Multiset, TreeMap)} with counts supplied as a map from haplotype
     * node id to read count (missing ids count as 0).
     *
     * @param map read counts keyed by haplotype node id
     * @param treeMap nodes per reference range
     * @return one value per entry of {@code treeMap}
     */
    public double[] probabilityOfBeingCorrect(Map<Integer, Integer> map, TreeMap<ReferenceRange, List<HaplotypeNode>> treeMap) {
        double[] result = new double[treeMap.size()];
        int index = 0;
        for (List<HaplotypeNode> nodes : treeMap.values()) {
            // Look up by node id: the map is keyed by Integer, so looking up the node object itself never matches.
            result[index++] = nodeCorrectProbability(nodes.stream().mapToInt(node -> map.getOrDefault(node.id(), 0)).toArray());
        }
        return result;
    }

    /**
     * Returns max(counts) / sum(counts) computed in floating point, or 0 when the counts sum to 0 (no evidence),
     * instead of dividing by zero.
     */
    private static double nodeCorrectProbability(int[] counts) {
        int total = 0;
        int max = 0;
        for (int count : counts) {
            total += count;
            max = Math.max(max, count);
        }
        return total == 0 ? 0.0d : (double) max / total;
    }

    /** @return the graph produced by the last {@code filterHaplotypeGraph} call, or {@code null} if none */
    public HaplotypeGraph filteredGraph() {
        return this.myGraph;
    }

    /** Stores a haplotype-id count map (not currently used by this class). @return this instance */
    public ConvertReadsToPathUsingHMM hapidCountMap(Map<Integer, Integer> map) {
        this.myHapidCountMap = map;
        return this;
    }

    /** Sets the minimum total reads a range needs to be kept. @return this instance */
    public ConvertReadsToPathUsingHMM minReadsPerRange(int i) {
        this.minReadsPerRefRange = i;
        return this;
    }

    /** Sets the maximum reads per kilobase a range may have to be kept. @return this instance */
    public ConvertReadsToPathUsingHMM maxReadsPerRangeKB(int i) {
        this.maxReadsPerRefRangeKB = i;
        return this;
    }

    /** Sets a taxa-list string passed to {@link FilterGraphPlugin#taxaList}. @return this instance */
    public ConvertReadsToPathUsingHMM taxaFilterList(String str) {
        this.myTaxaListString = str;
        return this;
    }

    /** Sets a taxa list; when non-empty it takes precedence over the string form. @return this instance */
    public ConvertReadsToPathUsingHMM taxaFilterList(TaxaList taxaList) {
        this.myTaxaList = taxaList;
        return this;
    }

    /** Sets the probability passed to {@link HaplotypeEmissionProbability}. @return this instance */
    public ConvertReadsToPathUsingHMM probabilityReadMappingCorrect(double d) {
        this.probReadMappedCorrectly = d;
        return this;
    }

    /** Sets the minimum passed to {@link ReferenceRangeTransitionProbability}. @return this instance */
    public ConvertReadsToPathUsingHMM minTransitionProbability(double d) {
        this.minTransitionProb = d;
        return this;
    }

    /** Stores a same-taxon transition probability (not currently used by this class). @return this instance */
    public ConvertReadsToPathUsingHMM transitionProbabilitySameTaxon(double d) {
        this.transitionProbSameTaxon = d;
        return this;
    }

    /** Sets the taxon flagged in the probability output file. @return this instance */
    public ConvertReadsToPathUsingHMM targetTaxon(String str) {
        this.targetTaxon = str;
        return this;
    }

    /** Sets the read mappings per reference range. @return this instance */
    public ConvertReadsToPathUsingHMM readMap(Multimap<ReferenceRange, HapIdSetCount> multimap) {
        this.myReadMap = multimap;
        return this;
    }

    /** Sets whether ranges whose haplotypes all received equal read counts are removed. @return this instance */
    public ConvertReadsToPathUsingHMM removeRangesWithEqualCounts(boolean z) {
        this.removeRangesWithEqualCounts = z;
        return this;
    }
}
