package io.dropwizard.metrics5.jvm;

import io.dropwizard.metrics5.Gauge;
import io.dropwizard.metrics5.Metric;
import io.dropwizard.metrics5.MetricName;
import io.dropwizard.metrics5.MetricRegistry;
import io.dropwizard.metrics5.MetricSet;

import java.lang.management.ManagementFactory;
import java.lang.management.ThreadInfo;
import java.lang.management.ThreadMXBean;
import java.util.Collections;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;

/**
 * A set of gauges for the number of threads in their various states and deadlock detection.
 * <p>
 * The set contains one {@code <state>.count} gauge per {@link Thread.State} constant, plus
 * {@code count}, {@code daemon.count}, {@code peak.count}, {@code total_started.count},
 * {@code deadlock.count} and {@code deadlocks}.
 */
public class ThreadStatesGaugeSet implements MetricSet {

    // do not compute stack traces.
    private static final int STACK_TRACE_DEPTH = 0;

    private final ThreadMXBean threads;
    private final ThreadDeadlockDetector deadlockDetector;

    /**
     * Creates a new set of gauges using the default MXBeans.
     */
    public ThreadStatesGaugeSet() {
        this(ManagementFactory.getThreadMXBean(), new ThreadDeadlockDetector());
    }

    /**
     * Creates a new set of gauges using the given MXBean and detector.
     *
     * @param threads          a thread MXBean
     * @param deadlockDetector a deadlock detector
     */
    public ThreadStatesGaugeSet(ThreadMXBean threads,
                                ThreadDeadlockDetector deadlockDetector) {
        this.threads = threads;
        this.deadlockDetector = deadlockDetector;
    }

    /**
     * Builds the gauges exposed by this set.
     * <p>
     * Every gauge is lazy: nothing is queried from the thread MXBean or the deadlock detector
     * until the gauge's value is read, and each read performs a fresh query.
     *
     * @return an unmodifiable map from metric name to gauge
     */
    @Override
    public Map<MetricName, Metric> getMetrics() {
        final Map<MetricName, Metric> gauges = new HashMap<>();

        for (final Thread.State state : Thread.State.values()) {
            // Locale.ROOT keeps the metric name stable regardless of the JVM's default locale;
            // a plain toLowerCase() would turn 'I' into a dotless 'ı' under e.g. a Turkish
            // locale and produce names such as "waıtıng.count".
            final String stateName = state.toString().toLowerCase(Locale.ROOT);
            gauges.put(MetricRegistry.name(stateName, "count"),
                    (Gauge<Integer>) () -> getThreadCount(state));
        }

        gauges.put(MetricName.build("count"), (Gauge<Integer>) threads::getThreadCount);
        gauges.put(MetricName.build("daemon.count"), (Gauge<Integer>) threads::getDaemonThreadCount);
        gauges.put(MetricName.build("peak.count"), (Gauge<Integer>) threads::getPeakThreadCount);
        gauges.put(MetricName.build("total_started.count"), (Gauge<Long>) threads::getTotalStartedThreadCount);
        gauges.put(MetricName.build("deadlock.count"), (Gauge<Integer>) () -> deadlockDetector.getDeadlockedThreads().size());
        gauges.put(MetricName.build("deadlocks"), (Gauge<Set<String>>) deadlockDetector::getDeadlockedThreads);

        return Collections.unmodifiableMap(gauges);
    }

    /**
     * Counts the threads reported by {@link #getThreadInfo()} that are in the given state.
     *
     * @param state the thread state to count
     * @return the number of non-null thread infos whose state is {@code state}
     */
    private int getThreadCount(Thread.State state) {
        final ThreadInfo[] allThreads = getThreadInfo();
        int count = 0;
        for (ThreadInfo info : allThreads) {
            // Entries may be null (e.g. a thread that ended between the two MXBean calls);
            // such entries are skipped rather than counted.
            if (info != null && info.getThreadState() == state) {
                count++;
            }
        }
        return count;
    }

    /**
     * Fetches a {@link ThreadInfo} for every thread ID currently reported by the MXBean,
     * without stack traces (depth {@value #STACK_TRACE_DEPTH}).
     * <p>
     * NOTE(review): package-private rather than private, presumably so tests or subclasses
     * can substitute the thread snapshot; confirm before narrowing visibility.
     *
     * @return the thread infos, in the order of the IDs returned by the MXBean
     */
    ThreadInfo[] getThreadInfo() {
        return threads.getThreadInfo(threads.getAllThreadIds(), STACK_TRACE_DEPTH);
    }

}
