/*
 * Decompiled with CFR 0.152.
 */
package net.automatalib.util.automata.ads;

import com.google.common.collect.Maps;
import java.util.ArrayDeque;
import java.util.BitSet;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.Map;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.automatalib.automata.concepts.StateIDs;
import net.automatalib.automata.transout.MealyMachine;
import net.automatalib.commons.util.Pair;
import net.automatalib.graphs.ads.ADSNode;
import net.automatalib.graphs.ads.RecursiveADSNode;
import net.automatalib.graphs.ads.impl.ADSLeafNode;
import net.automatalib.graphs.ads.impl.ADSSymbolNode;
import net.automatalib.util.automata.ads.ADS;
import net.automatalib.util.automata.ads.ADSUtil;
import net.automatalib.util.automata.ads.SplitTree;
import net.automatalib.words.Alphabet;
import net.automatalib.words.Word;

/**
 * Backtracking-based searches for adaptive distinguishing sequences (ADSs) of Mealy machines.
 * <p>
 * {@link #compute(MealyMachine, Alphabet, Set)} returns the first ADS found. It enumerates
 * candidate splitting words breadth-first and recursively solves the resulting sub-partitions.
 * {@link #computeOptimal(MealyMachine, Alphabet, Set, CostAggregator)} performs a cost-bounded,
 * memoized exhaustive search and returns an ADS that is minimal with respect to the given
 * {@link CostAggregator}.
 */
public final class BacktrackingSearch {

    private BacktrackingSearch() {
        // static utility class
    }

    /**
     * Computes an ADS for the given set of states.
     *
     * @param automaton
     *         the Mealy machine whose states should be distinguished
     * @param input
     *         the input symbols that may be used
     * @param states
     *         the states to distinguish
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return an ADS for {@code states}, or {@link Optional#empty()} if the search did not find one
     */
    public static <S, I, O> Optional<ADSNode<S, I, O>> compute(final MealyMachine<S, I, ?, O> automaton,
                                                               final Alphabet<I> input,
                                                               final Set<S> states) {
        if (states.size() == 1) {
            return ADS.compute(automaton, input, states);
        }

        final SplitTree<S, I, O> node = new SplitTree<>(states);
        // initially, every state of the partition corresponds to itself
        node.getMapping().putAll(states.stream().collect(Collectors.toMap(Function.identity(), Function.identity())));

        return compute(automaton, input, node);
    }

    /**
     * Computes an ADS for the partition of the given split-tree node.
     * <p>
     * The size of the node's partition is used as the original partition size when bounding
     * splitting-word lengths.
     *
     * @param automaton
     *         the Mealy machine whose states should be distinguished
     * @param input
     *         the input symbols that may be used
     * @param node
     *         the node whose partition should be distinguished; its mapping relates the partition's
     *         states to their initial states
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return an ADS for the node's partition, or {@link Optional#empty()} if the search did not find one
     */
    static <S, I, O> Optional<ADSNode<S, I, O>> compute(final MealyMachine<S, I, ?, O> automaton,
                                                        final Alphabet<I> input,
                                                        final SplitTree<S, I, O> node) {
        return compute(automaton, input, node, node.getPartition().size());
    }

    /**
     * Breadth-first enumeration of splitting-word prefixes with backtracking over the sub-partitions
     * induced by the first splitting symbol.
     *
     * @param automaton
     *         the Mealy machine whose states should be distinguished
     * @param input
     *         the input symbols that may be used
     * @param node
     *         the node whose partition should be distinguished
     * @param originalPartitionSize
     *         the size of the partition of the top-level call; passed on unchanged to
     *         {@link ADSUtil#computeMaximumSplittingWordLength(int, int, int)}
     *
     * @return an ADS for the node's partition, or {@link Optional#empty()} if no candidate up to the
     * maximum splitting-word length led to one
     */
    private static <S, I, O> Optional<ADSNode<S, I, O>> compute(final MealyMachine<S, I, ?, O> automaton,
                                                                 final Alphabet<I> input,
                                                                 final SplitTree<S, I, O> node,
                                                                 final int originalPartitionSize) {

        final long maximumSplittingWordLength = ADSUtil.computeMaximumSplittingWordLength(automaton.size(),
                                                                                          node.getPartition().size(),
                                                                                          originalPartitionSize);
        final Queue<Word<I>> splittingWordCandidates = new ArrayDeque<>();
        final StateIDs<S> stateIds = automaton.stateIDs();
        // state sets (encoded via their state ids) that are known not to lead to an ADS
        final Set<BitSet> cache = new HashSet<>();

        splittingWordCandidates.add(Word.epsilon());

        while (!splittingWordCandidates.isEmpty()) {
            final Word<I> prefix = splittingWordCandidates.poll();
            // maps the state reached via 'prefix' to the partition state it was reached from
            final Map<S, S> currentToInitialMapping = node.getPartition()
                                                          .stream()
                                                          .collect(Collectors.toMap(x -> automaton.getSuccessor(x,
                                                                                                                  prefix),
                                                                                    Function.identity()));

            final BitSet currentSetAsBitSet = new BitSet();
            for (final S s : currentToInitialMapping.keySet()) {
                currentSetAsBitSet.set(stateIds.getStateId(s));
            }

            if (cache.contains(currentSetAsBitSet)) {
                continue;
            }

            symbolLoop:
            for (final I i : input) {
                // group the successors of the current states by the output they produce on 'i'
                final Map<O, SplitTree<S, I, O>> successors = new HashMap<>();

                for (final Map.Entry<S, S> entry : currentToInitialMapping.entrySet()) {
                    final S current = entry.getKey();
                    final S nextState = automaton.getSuccessor(current, i);
                    final O nextOutput = automaton.getOutput(current, i);

                    final SplitTree<S, I, O> child =
                            successors.computeIfAbsent(nextOutput, o -> new SplitTree<S, I, O>(new HashSet<>()));

                    // two states with equal output converge on 'i': the symbol cannot be used here
                    if (!child.getPartition().add(nextState)) {
                        continue symbolLoop;
                    }

                    child.getMapping().put(nextState, node.getMapping().get(entry.getValue()));
                }

                if (successors.size() > 1) {
                    // 'prefix + i' splits the partition: try to solve every resulting block
                    final Map<O, ADSNode<S, I, O>> results = new HashMap<>();

                    for (final Map.Entry<O, SplitTree<S, I, O>> entry : successors.entrySet()) {
                        final SplitTree<S, I, O> currentNode = entry.getValue();

                        final BitSet currentNodeAsBitSet = new BitSet();
                        for (final S s : currentNode.getPartition()) {
                            currentNodeAsBitSet.set(stateIds.getStateId(s));
                        }

                        if (cache.contains(currentNodeAsBitSet)) {
                            continue symbolLoop;
                        }

                        final Optional<ADSNode<S, I, O>> succ;
                        if (currentNode.getPartition().size() > 2) {
                            succ = compute(automaton, input, currentNode, originalPartitionSize);
                        } else {
                            succ = ADS.compute(automaton, input, currentNode);
                        }

                        if (!succ.isPresent()) {
                            cache.add(currentNodeAsBitSet);
                            continue symbolLoop;
                        }

                        results.put(entry.getKey(), succ.get());
                    }

                    // all blocks solved: prepend the chain of symbol nodes for 'prefix + i'. Any
                    // partition state may serve as representative because all states agree on the
                    // outputs of 'prefix' (otherwise the prefix would not have been enqueued)
                    final Pair<ADSNode<S, I, O>, ADSNode<S, I, O>> ads =
                            ADSUtil.buildFromTrace(automaton, prefix.append(i), node.getPartition().iterator().next());
                    final ADSNode<S, I, O> head = ads.getFirst();
                    final ADSNode<S, I, O> tail = ads.getSecond();

                    for (final Map.Entry<O, ADSNode<S, I, O>> result : results.entrySet()) {
                        result.getValue().setParent(tail);
                        tail.getChildren().put(result.getKey(), result.getValue());
                    }

                    return Optional.of(head);
                }

                // no split yet: extend the prefix as long as the length bound permits
                if (prefix.length() < maximumSplittingWordLength) {
                    splittingWordCandidates.add(prefix.append(i));
                }
            }

            cache.add(currentSetAsBitSet);
        }

        return Optional.empty();
    }

    /**
     * Computes an ADS for the given set of states that is minimal with respect to the given cost
     * aggregator.
     *
     * @param automaton
     *         the Mealy machine whose states should be distinguished
     * @param input
     *         the input symbols that may be used
     * @param states
     *         the states to distinguish
     * @param costAggregator
     *         how the costs of sub-ADSs are combined (see {@link CostAggregator})
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return a cost-minimal ADS for {@code states}, or {@link Optional#empty()} if none was found
     */
    public static <S, I, O> Optional<ADSNode<S, I, O>> computeOptimal(final MealyMachine<S, I, ?, O> automaton,
                                                                      final Alphabet<I> input,
                                                                      final Set<S> states,
                                                                      final CostAggregator costAggregator) {
        if (states.size() == 1) {
            return ADS.compute(automaton, input, states);
        }

        final Map<Set<S>, Optional<SearchState<S, I, O>>> searchStateCache = new HashMap<>();
        final Set<Set<S>> traceCache = new HashSet<>();
        final Optional<SearchState<S, I, O>> searchState = exploreSearchSpace(automaton,
                                                                              input,
                                                                              states,
                                                                              costAggregator,
                                                                              searchStateCache,
                                                                              traceCache,
                                                                              Integer.MAX_VALUE);

        if (!searchState.isPresent()) {
            return Optional.empty();
        }

        final Map<S, S> initialMapping =
                states.stream().collect(Collectors.toMap(Function.identity(), Function.identity()));
        return Optional.of(constructADS(automaton, initialMapping, searchState.get()));
    }

    /**
     * Recursively searches for the cheapest input symbol that (eventually) distinguishes all
     * {@code targets}.
     * <p>
     * Results (including the negative result for sets in which every symbol merges two states) are
     * memoized in {@code stateCache}; a cached value is returned regardless of {@code costsBound}.
     * Negative results caused only by the bound or by the trace cache are not memoized.
     * <p>
     * NOTE(review): an empty {@code targets} set reaches
     * {@code successors.entrySet().iterator().next()} on an empty map and throws
     * {@link java.util.NoSuchElementException}; callers presumably never pass one — confirm.
     *
     * @param currentTraceCache
     *         state sets on the current non-splitting path; revisiting one of them would form a cycle
     * @param costsBound
     *         exploration stops for results whose costs would exceed this bound
     *
     * @return the best search state found, or {@link Optional#empty()}
     */
    private static <S, I, O> Optional<SearchState<S, I, O>> exploreSearchSpace(final MealyMachine<S, I, ?, O> automaton,
                                                                               final Alphabet<I> alphabet,
                                                                               final Set<S> targets,
                                                                               final CostAggregator costAggregator,
                                                                               final Map<Set<S>, Optional<SearchState<S, I, O>>> stateCache,
                                                                               final Set<Set<S>> currentTraceCache,
                                                                               final int costsBound) {

        final Optional<SearchState<S, I, O>> cachedValue = stateCache.get(targets);
        if (cachedValue != null) {
            return cachedValue;
        }

        if (currentTraceCache.contains(targets)) {
            return Optional.empty();
        }

        if (targets.size() == 1) {
            // a single state is trivially identified: leaf with zero costs
            final SearchState<S, I, O> leaf = new SearchState<>(null, null, 0);
            final Optional<SearchState<S, I, O>> result = Optional.of(leaf);
            stateCache.put(targets, result);
            return result;
        }

        // any further exploration would surpass the cost bound
        if (costsBound == 0) {
            return Optional.empty();
        }

        boolean foundValidSuccessor = false;
        boolean convergingStates = true;
        int bestCosts = costsBound;
        Map<O, SearchState<S, I, O>> bestSuccessor = null;
        I bestInputSymbol = null;

        symbolLoop:
        for (final I i : alphabet) {
            // group the successors of the targets by the output they produce on 'i'
            final Map<O, Set<S>> successors = new HashMap<>();

            for (final S s : targets) {
                final S nextState = automaton.getSuccessor(s, i);
                final O nextOutput = automaton.getOutput(s, i);

                // two targets with equal output converge on 'i': the symbol cannot be used here
                if (!successors.computeIfAbsent(nextOutput, o -> new HashSet<>()).add(nextState)) {
                    continue symbolLoop;
                }
            }

            convergingStates = false;

            final int costsForInputSymbol;
            final Map<O, SearchState<S, I, O>> successorsForInputSymbol;

            if (successors.size() > 1) {
                // 'i' splits the targets: every block has to be solvable within the current best costs
                successorsForInputSymbol = Maps.newHashMapWithExpectedSize(successors.size());
                int partitionCosts = 0;

                for (final Map.Entry<O, Set<S>> entry : successors.entrySet()) {
                    final Optional<SearchState<S, I, O>> potentialResult = exploreSearchSpace(automaton,
                                                                                              alphabet,
                                                                                              entry.getValue(),
                                                                                              costAggregator,
                                                                                              stateCache,
                                                                                              new HashSet<>(),
                                                                                              bestCosts);
                    if (!potentialResult.isPresent()) {
                        continue symbolLoop;
                    }

                    final SearchState<S, I, O> subResult = potentialResult.get();
                    successorsForInputSymbol.put(entry.getKey(), subResult);
                    partitionCosts = costAggregator.apply(partitionCosts, subResult.costs);

                    // this symbol cannot improve on the best result found so far
                    if (partitionCosts >= bestCosts) {
                        continue symbolLoop;
                    }
                }

                costsForInputSymbol = partitionCosts;
            } else {
                // 'i' does not split the targets: continue with the successor set, remembering the
                // current set to detect cycles along this non-splitting path
                final Map.Entry<O, Set<S>> entry = successors.entrySet().iterator().next();
                final Set<S> nextTargets = entry.getValue();

                final Set<Set<S>> nextTraceCache = new HashSet<>(currentTraceCache);
                nextTraceCache.add(targets);

                final Optional<SearchState<S, I, O>> potentialResult = exploreSearchSpace(automaton,
                                                                                          alphabet,
                                                                                          nextTargets,
                                                                                          costAggregator,
                                                                                          stateCache,
                                                                                          nextTraceCache,
                                                                                          bestCosts);
                if (!potentialResult.isPresent()) {
                    continue;
                }

                final SearchState<S, I, O> subResult = potentialResult.get();
                costsForInputSymbol = subResult.costs;
                successorsForInputSymbol = Collections.singletonMap(entry.getKey(), subResult);
            }

            if (costsForInputSymbol < bestCosts) {
                foundValidSuccessor = true;
                bestCosts = costsForInputSymbol;
                bestSuccessor = successorsForInputSymbol;
                bestInputSymbol = i;
            }
        }

        if (convergingStates) {
            // every symbol merges two targets with equal output; this does not depend on the cost
            // bound, so the negative result can be memoized
            stateCache.put(targets, Optional.empty());
            return Optional.empty();
        }

        if (!foundValidSuccessor) {
            return Optional.empty();
        }

        // the chosen symbol adds one unit of costs on top of its (aggregated) successors
        final SearchState<S, I, O> best = new SearchState<>(bestInputSymbol, bestSuccessor, bestCosts + 1);
        final Optional<SearchState<S, I, O>> result = Optional.of(best);
        stateCache.put(targets, result);
        return result;
    }

    /**
     * Materializes the ADS described by a search-state tree.
     *
     * @param currentToInitialMapping
     *         maps each current state to the initial state it was reached from
     * @param searchState
     *         the search state describing the symbol to apply for the current states
     *
     * @return the root of the constructed ADS (without parent)
     *
     * @throws IllegalStateException
     *         if the chosen symbol makes two current states with equal output converge
     */
    private static <S, I, O> ADSNode<S, I, O> constructADS(final MealyMachine<S, I, ?, O> automaton,
                                                           final Map<S, S> currentToInitialMapping,
                                                           final SearchState<S, I, O> searchState) {
        if (currentToInitialMapping.size() == 1) {
            return new ADSLeafNode<>(null, currentToInitialMapping.values().iterator().next());
        }

        final I i = searchState.symbol;
        final Map<O, Map<S, S>> successors = new HashMap<>();

        for (final Map.Entry<S, S> entry : currentToInitialMapping.entrySet()) {
            final S current = entry.getKey();
            final S nextState = automaton.getSuccessor(current, i);
            final O nextOutput = automaton.getOutput(current, i);

            final Map<S, S> nextMapping = successors.computeIfAbsent(nextOutput, o -> new HashMap<>());
            if (nextMapping.put(nextState, entry.getValue()) != null) {
                throw new IllegalStateException();
            }
        }

        final ADSNode<S, I, O> result = new ADSSymbolNode<>(null, i);

        for (final Map.Entry<O, Map<S, S>> entry : successors.entrySet()) {
            final O output = entry.getKey();
            final ADSNode<S, I, O> successor =
                    constructADS(automaton, entry.getValue(), searchState.successors.get(output));

            result.getChildren().put(output, successor);
            successor.setParent(result);
        }

        return result;
    }

    /**
     * Immutable node of the search tree explored by
     * {@code exploreSearchSpace}.
     */
    private static final class SearchState<S, I, O> {

        /** The input symbol chosen for this node; {@code null} for leaves. */
        private final I symbol;
        /** The sub-results per output of {@link #symbol}; {@code null} for leaves. */
        private final Map<O, SearchState<S, I, O>> successors;
        /** The aggregated costs of the sub-tree rooted here; {@code 0} for leaves. */
        private final int costs;

        private SearchState(final I symbol, final Map<O, SearchState<S, I, O>> successors, final int costs) {
            this.symbol = symbol;
            this.successors = successors;
            this.costs = costs;
        }
    }

    /**
     * Strategies for combining the costs of the sub-ADSs of a split. Every symbol node contributes
     * one unit of costs and leaves contribute zero.
     */
    public enum CostAggregator implements BiFunction<Integer, Integer, Integer> {

        /**
         * Combines costs by maximum, i.e. minimizes the number of symbols on the longest path of
         * the ADS.
         */
        MIN_LENGTH {
            @Override
            public Integer apply(final Integer oldValue, final Integer newValue) {
                return Math.max(oldValue, newValue);
            }
        },

        /**
         * Combines costs by sum, i.e. minimizes the number of symbol nodes of the ADS.
         */
        MIN_SIZE {
            @Override
            public Integer apply(final Integer oldValue, final Integer newValue) {
                return oldValue + newValue;
            }
        }
    }
}

