/*
 * Decompiled with CFR 0.152.
 */
package net.automatalib.util.automata.ads;

import java.util.Iterator;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import net.automatalib.automata.transout.MealyMachine;
import net.automatalib.commons.util.Pair;
import net.automatalib.graphs.ads.ADSNode;
import net.automatalib.graphs.ads.impl.ADSLeafNode;
import net.automatalib.util.automata.Automata;
import net.automatalib.util.automata.ads.ADSUtil;
import net.automatalib.util.automata.ads.SplitTree;
import net.automatalib.words.Alphabet;
import net.automatalib.words.Word;

/**
 * Builds adaptive distinguishing sequences (ADS) for exactly two states of a Mealy machine.
 * <p>
 * The ADS is derived from a separating word of the two states. Only the prefix up to and including the first input
 * symbol on which the two states produce different outputs is kept. It is followed by one leaf node per state,
 * reached via the respective diverging output symbol.
 */
public final class StateEquivalence {

    private StateEquivalence() {
        // static utility class; not meant to be instantiated
    }

    /**
     * Computes an ADS that distinguishes the two given states.
     *
     * @param automaton
     *         the Mealy machine containing the states
     * @param input
     *         the input alphabet used when searching for a separating word
     * @param states
     *         the states to distinguish; must contain exactly two elements
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return the root of the ADS, or {@link Optional#empty()} if no separating word exists for the two states
     *
     * @throws IllegalArgumentException
     *         if {@code states} does not contain exactly two states
     */
    public static <S, I, O> Optional<ADSNode<S, I, O>> compute(final MealyMachine<S, I, ?, O> automaton,
                                                               final Alphabet<I> input,
                                                               final Set<S> states) throws IllegalArgumentException {
        if (states.size() != 2) {
            throw new IllegalArgumentException("StateEquivalence can only distinguish 2 states");
        }

        // Map every state onto itself, so the resulting leaves are labelled with the given states.
        final SplitTree<S, I, O> node = new SplitTree<>(states);
        node.getMapping().putAll(states.stream().collect(Collectors.toMap(Function.identity(), Function.identity())));

        return compute(automaton, input, node);
    }

    /**
     * Computes an ADS that distinguishes the first two states of the given split-tree node's partition.
     * <p>
     * The leaves of the returned ADS are labelled with the states that {@code node.getMapping()} associates with the
     * two partition states.
     * <p>
     * NOTE(review): assumes {@code node.getPartition()} contains (at least) two states; otherwise
     * {@link Iterator#next()} throws {@link java.util.NoSuchElementException}.
     *
     * @param automaton
     *         the Mealy machine containing the states
     * @param input
     *         the input alphabet used when searching for a separating word
     * @param node
     *         the split-tree node whose partition holds the states to distinguish
     * @param <S>
     *         state type
     * @param <I>
     *         input symbol type
     * @param <O>
     *         output symbol type
     *
     * @return the root of the ADS, or {@link Optional#empty()} if no separating word exists for the two states
     */
    static <S, I, O> Optional<ADSNode<S, I, O>> compute(final MealyMachine<S, I, ?, O> automaton,
                                                        final Alphabet<I> input,
                                                        final SplitTree<S, I, O> node) {
        final Iterator<S> targetStateIterator = node.getPartition().iterator();
        final S s1 = targetStateIterator.next();
        final S s2 = targetStateIterator.next();

        final Word<I> separatingWord = Automata.findSeparatingWord(automaton, s1, s2, input);
        if (separatingWord == null) {
            return Optional.empty();
        }

        final Word<O> s1Output = automaton.computeStateOutput(s1, separatingWord);
        final Word<O> s2Output = automaton.computeStateOutput(s2, separatingWord);
        final Word<O> sharedOutput = s1Output.longestCommonPrefix(s2Output);

        // Index of the first output symbol on which the two states disagree.
        final int divergenceIndex = sharedOutput.length();

        // Keep the separating word only up to and including the input symbol that causes the divergence.
        final Word<I> trace = separatingWord.prefix(divergenceIndex + 1);

        final Pair<ADSNode<S, I, O>, ADSNode<S, I, O>> ads = ADSUtil.buildFromTrace(automaton, trace, s1);
        final ADSNode<S, I, O> head = ads.getFirst();
        final ADSNode<S, I, O> tail = ads.getSecond();

        final ADSNode<S, I, O> s1FinalNode = new ADSLeafNode<>(tail, node.getMapping().get(s1));
        final ADSNode<S, I, O> s2FinalNode = new ADSLeafNode<>(tail, node.getMapping().get(s2));

        // Branch on the diverging output symbol: each state's output leads to its own leaf.
        tail.getChildren().put(s1Output.getSymbol(divergenceIndex), s1FinalNode);
        tail.getChildren().put(s2Output.getSymbol(divergenceIndex), s2FinalNode);

        return Optional.of(head);
    }
}

