package pragma.protoc.plugin.custom
import java.util.Stack

/**
 * Directed graph whose vertices are identified by their value.
 *
 * - Vertices are unique within the graph: adding the same value twice yields a single vertex.
 * - Edges are unique per ordered (from, to) pair of vertex values.
 * - Edges in both directions between the same two vertices (u -> v and v -> u) are allowed.
 * - Reflexive edges (u -> u) are allowed; note that [topologicalSort] reports them as a cycle.
 */
class DirectedUniqueGraph<T> {
    /**
     * A vertex in the graph. Equality and hash code are determined solely by [value], which is what
     * makes both vertices and outgoing edges unique by value.
     */
    private class Vertex<T>(val value: T) {
        // mutableSetOf is insertion-ordered, so outgoing edges are iterated in the order they were added.
        private val outgoing = mutableSetOf<Vertex<T>>()

        /** Vertices this vertex has an edge to, in the order the edges were added. */
        val next: Set<Vertex<T>>
            get() = outgoing

        /** Adds an edge to [vertex]. Returns true if the edge is new, false if it already existed. */
        fun addNext(vertex: Vertex<T>): Boolean = outgoing.add(vertex)

        override fun equals(other: Any?): Boolean {
            if (this === other) return true
            if (javaClass != other?.javaClass) return false
            other as Vertex<*>
            return value == other.value
        }

        override fun hashCode(): Int = value?.hashCode() ?: 0
    }

    // Insertion-ordered (mutableMapOf), so traversal order depends only on the order of insertions.
    private val vertices = mutableMapOf<T, Vertex<T>>()

    /** Number of vertices in the graph. */
    val size: Int
        get() = vertices.size

    /**
     * Adds a vertex for [value] if one does not already exist.
     * [addEdge] also creates any vertices it references, but this is the only way to add a vertex with no edges.
     */
    fun addVertex(value: T) {
        getOrAddVertex(value)
    }

    /**
     * Adds the directed edge [from] -> [to], creating either vertex if it does not exist yet.
     *
     * @return true if the edge was added, false if it was already present.
     */
    fun addEdge(from: T, to: T): Boolean = getOrAddVertex(from).addNext(getOrAddVertex(to))

    /** Returns true if the graph contains a vertex with [value]. */
    fun containsVertex(value: T): Boolean = vertices.containsKey(value)

    /** Returns true if the graph contains the directed edge [from] -> [to]. */
    fun containsEdge(from: T, to: T): Boolean {
        // An edge can only point at a registered vertex, so a missing target means no edge.
        val target = vertices[to] ?: return false
        return vertices[from]?.next?.contains(target) == true
    }

    /**
     * Result of [topologicalSort].
     *
     * @property values If [isCyclical] is false: every vertex value, in topological order.
     *   If [isCyclical] is true: the depth-first path on which the cycle was detected, from the traversal root to
     *   the vertex that closed the cycle. That vertex appears twice so the loop is visible, e.g. `[a, b, c, b]`
     *   for the cycle `b -> c -> b` reached from `a`.
     * @property isCyclical True if the graph contains a cycle (including a reflexive edge) and could not be sorted.
     */
    data class TopologicalSortResult<T>(val values: List<T>, val isCyclical: Boolean)

    /**
     * Runs a topological sort over the whole graph.
     *
     * On success, [TopologicalSortResult.values] is
     * "a linear ordering of its vertices such that for every directed edge uv from vertex u to vertex v, u comes before v in the ordering."
     * (Definition: Wikipedia) https://en.wikipedia.org/wiki/Topological_sorting
     *
     * The relative order of vertices with no path between them is arbitrary. A graph can have several valid
     * orderings, and any one of them may be returned. For a given sequence of insertions the result is deterministic.
     *
     * The sort fails if the graph has any cycle; see [TopologicalSortResult].
     * Runs in O(V + E). The traversal is recursive, so stack depth grows with the longest path in the graph.
     */
    fun topologicalSort(): TopologicalSortResult<T> {
        // Vertices already reached by some traversal; each vertex is expanded at most once.
        val visited = mutableSetOf<Vertex<T>>()
        // Vertex values in depth-first post-order (a vertex is appended after all of its descendants).
        val postOrder = mutableListOf<T>()
        for (root in vertices.values) {
            if (!visited.add(root)) continue
            // The current depth-first path: ordered for reporting a cycle, plus a set for O(1) membership checks.
            val path = Stack<Vertex<T>>()
            val onPath = mutableSetOf<Vertex<T>>()
            if (!topologicalSortRecurse(root, visited, path, onPath, postOrder)) {
                // Found a cycle, abort. Stack iterates bottom-to-top, i.e. from the root along the path.
                return TopologicalSortResult(path.map { it.value }, true)
            }
        }
        // Reverse post-order of a DAG is a topological order.
        return TopologicalSortResult(postOrder.reversed(), false)
    }

    /**
     * Depth-first visit of [vertex], which the caller has already added to [visited].
     * Appends [vertex] and all of its not-yet-visited descendants to [postOrder] in post-order.
     *
     * @return false if a cycle was found. [path] is then left holding the path from the traversal root to the
     *   vertex that closed the cycle, with that vertex pushed a second time so the loop is visible.
     */
    private fun topologicalSortRecurse(
        vertex: Vertex<T>,
        visited: MutableSet<Vertex<T>>,
        path: Stack<Vertex<T>>,
        onPath: MutableSet<Vertex<T>>,
        postOrder: MutableList<T>
    ): Boolean {
        path.push(vertex)
        onPath.add(vertex)
        for (adjacent in vertex.next) {
            if (adjacent in onPath) {
                // Back edge to an ancestor on the current path: cycle. Push it so the loop is clear.
                path.push(adjacent)
                return false
            }
            if (visited.add(adjacent) && !topologicalSortRecurse(adjacent, visited, path, onPath, postOrder)) {
                // Cycle found in a descendant; leave the path intact for reporting.
                return false
            }
        }
        path.pop()
        onPath.remove(vertex)
        postOrder.add(vertex.value)
        return true
    }

    /** Returns the vertex for [value], creating and registering it first if it does not exist. */
    private fun getOrAddVertex(value: T): Vertex<T> = vertices.getOrPut(value) { Vertex(value) }
}
