package ai.djl.serving.wlm;

import ai.djl.Device;
import ai.djl.ModelException;
import ai.djl.serving.wlm.ModelInfo;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.serving.wlm.util.WorkerJob;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager.class */
public class WorkLoadManager implements AutoCloseable {
    /** Class-wide SLF4J logger; the nested worker pool classes log through it as well. */
    private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);
    /** Executor that every worker thread is submitted to; shut down via {@code shutdownNow()} in {@link #close()}. */
    private ExecutorService threadPool = Executors.newCachedThreadPool();
    /** Worker pools created by {@code registerModel}, keyed by the model they serve. */
    private ConcurrentHashMap<ModelInfo<?, ?>, WorkerPool<?, ?>> workerPools = new ConcurrentHashMap<>();

    /* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager$WorkerPool.class */
    public final class WorkerPool<I, O> implements AutoCloseable {
        /** The model served by this pool; also used as the lock object in {@code scaleWorkers}. */
        private final ModelInfo<I, O> model;
        /** Per-device worker bookkeeping, keyed by the device the workers were created for. */
        private Map<Device, WorkerPool<I, O>.WorkerPoolDevice> devices = new ConcurrentHashMap();
        /** Pending jobs; created with the capacity reported by the model and handed to every worker thread builder. */
        private LinkedBlockingDeque<WorkerJob<I, O>> jobQueue;

        /* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager$WorkerPool$WorkerPoolDevice.class */
        /**
         * Holds the worker threads created for one {@link Device} of the enclosing pool,
         * together with the minimum and maximum worker counts configured for that device.
         */
        public final class WorkerPoolDevice {
            private final Device device;
            private int minWorkers;
            private int maxWorkers;
            private final List<WorkerThread<I, O>> workers = new CopyOnWriteArrayList<>();

            /**
             * Creates the bookkeeping entry for a device with an initially empty worker list.
             *
             * @param device the device the workers are bound to
             * @param i the minimum number of workers for the device
             * @param i2 the maximum number of workers for the device
             */
            private WorkerPoolDevice(Device device, int i, int i2) {
                this.device = device;
                minWorkers = i;
                maxWorkers = i2;
            }

            /**
             * Returns the minimum number of workers configured for this device.
             *
             * @return the minimum worker count
             */
            public int getMinWorkers() {
                return minWorkers;
            }

            /**
             * Returns the maximum number of workers configured for this device.
             *
             * @return the maximum worker count
             */
            public int getMaxWorkers() {
                return maxWorkers;
            }

            /**
             * Builds the requested number of worker threads for this device, records them and
             * submits each one to the manager's executor.
             *
             * @param i how many worker threads to create; nothing happens for values below one
             * @param z value passed to the builder's {@code optFixPoolThread}
             */
            public void addThreads(int i, boolean z) {
                for (int remaining = i; remaining > 0; remaining--) {
                    WorkerThread<I, O> worker =
                            WorkerThread.builder(
                                            WorkerPool.this.model.getInputClass(),
                                            WorkerPool.this.model.getOutputClass())
                                    .setModel(WorkerPool.this.model)
                                    .setDevice(device)
                                    .setJobQueue(WorkerPool.this.jobQueue)
                                    .optFixPoolThread(z)
                                    .build();
                    workers.add(worker);
                    WorkLoadManager.this.threadPool.submit(worker);
                }
            }
        }

        /**
         * Creates a pool for the given model with a job queue bounded by the model's queue size.
         *
         * @param modelInfo the model this pool serves
         */
        public WorkerPool(ModelInfo<I, O> modelInfo) {
            int capacity = modelInfo.getQueueSize();
            model = modelInfo;
            jobQueue = new LinkedBlockingDeque<>(capacity);
        }

        /**
         * Collects the worker threads of every device into a new list.
         *
         * @return a fresh list with the workers of all devices; changes to it do not affect the pool
         */
        public List<WorkerThread<I, O>> getWorkers() {
            List<WorkerThread<I, O>> all = new ArrayList<>();
            for (WorkerPoolDevice perDevice : devices.values()) {
                all.addAll(perDevice.workers);
            }
            return all;
        }

        /**
         * Returns the queue holding the jobs waiting for a worker of this pool.
         *
         * @return the pool's job queue (the live instance, not a copy)
         */
        public LinkedBlockingDeque<WorkerJob<I, O>> getJobQueue() {
            return jobQueue;
        }

        /**
         * Sums the minimum worker counts configured for all devices.
         *
         * @return the total minimum number of workers, or 0 when no device is configured
         */
        public int getMinWorkers() {
            int total = 0;
            for (WorkerPoolDevice perDevice : devices.values()) {
                total += perDevice.minWorkers;
            }
            return total;
        }

        /**
         * Sums the maximum worker counts configured for all devices.
         *
         * @return the total maximum number of workers, or 0 when no device is configured
         */
        public int getMaxWorkers() {
            int total = 0;
            for (WorkerPoolDevice perDevice : devices.values()) {
                total += perDevice.maxWorkers;
            }
            return total;
        }

        /**
         * Calls {@code model.load(device)} and, if the model is READY afterwards, applies new
         * worker limits for the resolved device and adjusts the fixed-pool workers.
         *
         * <p>The whole operation runs while holding the monitor of the model. Limits are resolved
         * through {@link WlmConfigManager}. An existing bookkeeping entry for the device is
         * updated in place rather than replaced, so workers that are already running stay
         * tracked and can still be scaled down or shut down later.
         *
         * <p>NOTE(review): the number of fixed workers is counted across all devices of the pool
         * (as before) and compared against the minimum resolved for this device; confirm whether
         * a per-device comparison is intended.
         *
         * @param device the device to scale; resolved through {@code model.withDefaultDevice}
         * @param i the requested minimum number of workers
         * @param i2 the requested maximum number of workers
         * @return this pool
         * @throws CompletionException wrapping a {@link ModelException} or {@link IOException}
         */
        public WorkerPool<I, O> scaleWorkers(Device device, int i, int i2) {
            synchronized (this.model) {
                try {
                    this.model.load(device);
                    if (this.model.getStatus() != ModelInfo.Status.READY) {
                        WorkLoadManager.logger.warn("Cannot scale workers while model is not READY: {}", this.model);
                        return this;
                    }
                    Device actualDevice = this.model.withDefaultDevice(device);
                    WlmConfigManager configManager = WlmConfigManager.getInstance();
                    int newMaxWorkers = configManager.getDefaultMaxWorkers(this.model, actualDevice, i2);
                    int newMinWorkers =
                            configManager.getDefaultMinWorkers(this.model, actualDevice, i, newMaxWorkers);

                    // Reuse the existing entry: replacing it would drop the references to the
                    // workers already running on this device, leaving them impossible to stop.
                    WorkerPoolDevice poolDevice = this.devices.get(actualDevice);
                    if (poolDevice == null) {
                        poolDevice = new WorkerPoolDevice(actualDevice, newMinWorkers, newMaxWorkers);
                        this.devices.put(actualDevice, poolDevice);
                    } else {
                        poolDevice.minWorkers = newMinWorkers;
                        poolDevice.maxWorkers = newMaxWorkers;
                    }
                    cleanup();

                    List<WorkerThread<I, O>> fixedWorkers = new ArrayList<>();
                    for (WorkerThread<I, O> worker : getWorkers()) {
                        if (worker.isFixPoolThread()) {
                            fixedWorkers.add(worker);
                        }
                    }
                    int currentFixed = fixedWorkers.size();
                    if (currentFixed < newMinWorkers) {
                        poolDevice.addThreads(newMinWorkers - currentFixed, true);
                    } else {
                        for (WorkerThread<I, O> surplus : fixedWorkers.subList(newMinWorkers, currentFixed)) {
                            // getWorkers() returns a copy, so the worker has to be removed from
                            // the list of the device that actually owns it.
                            for (WorkerPoolDevice owner : this.devices.values()) {
                                if (owner.workers.remove(surplus)) {
                                    break;
                                }
                            }
                            surplus.shutdown(WorkerState.WORKER_SCALED_DOWN);
                        }
                    }
                    log();
                    return this;
                } catch (ModelException | IOException e) {
                    throw new CompletionException(e);
                }
            }
        }

        /**
         * Looks up the bookkeeping entry of a device.
         *
         * @param device the device to look up
         * @return the entry stored for the device, or {@code null} when the map has none
         */
        public WorkerPool<I, O>.WorkerPoolDevice forDevice(Device device) {
            WorkerPool<I, O>.WorkerPoolDevice entry = devices.get(device);
            return entry;
        }

        /**
         * Writes one debug log entry listing every worker of the pool with its id and whether
         * it is a fixed-pool or a temporary worker. Does nothing unless debug logging is enabled.
         */
        public void log() {
            if (!WorkLoadManager.logger.isDebugEnabled()) {
                return;
            }
            StringBuilder summary = new StringBuilder();
            for (WorkerThread<I, O> worker : getWorkers()) {
                summary.append(worker.getWorkerId());
                summary.append(worker.isFixPoolThread() ? "-fixedPool\n" : "-tmpPool\n");
            }
            WorkLoadManager.logger.debug("worker pool for model {}:\n {}", model, summary);
        }

        /**
         * Removes from every device list the workers whose state is {@code WORKER_STOPPED} or
         * {@code WORKER_ERROR}.
         */
        public void cleanup() {
            for (WorkerPoolDevice perDevice : devices.values()) {
                perDevice.workers.removeIf(
                        worker ->
                                worker.getState() == WorkerState.WORKER_STOPPED
                                        || worker.getState() == WorkerState.WORKER_ERROR);
            }
        }

        /**
         * Closes the model, shuts down every worker of every device with state
         * {@code WORKER_STOPPED}, and cancels the futures of all jobs still in the queue.
         * The cancelled jobs are not removed from the queue.
         */
        @Override
        public void close() {
            model.close();
            for (WorkerPoolDevice perDevice : devices.values()) {
                for (WorkerThread<I, O> worker : perDevice.workers) {
                    worker.shutdown(WorkerState.WORKER_STOPPED);
                }
            }
            for (WorkerJob<I, O> pending : jobQueue) {
                pending.getFuture().cancel(true);
            }
        }

        /**
         * Adds up to {@code i} temporary (non fixed-pool) workers, visiting the devices in
         * random order and never exceeding a device's configured maximum.
         *
         * <p>If the devices together have less free capacity than requested, fewer workers are
         * added. A device that already holds more workers than its maximum contributes nothing.
         *
         * @param i the number of workers to add; values below one add nothing
         */
        public void addThreads(int i) {
            // Shuffle a copy and iterate that copy, so the random order is actually used.
            List<WorkerPoolDevice> shuffled = new ArrayList<>(this.devices.values());
            Collections.shuffle(shuffled);

            int remaining = i;
            for (WorkerPoolDevice candidate : shuffled) {
                if (remaining <= 0) {
                    return;
                }
                // Clamp at zero: an over-provisioned device must not increase the remaining count.
                int capacity = Math.max(0, candidate.getMaxWorkers() - candidate.workers.size());
                int toAdd = Math.min(remaining, capacity);
                candidate.addThreads(toAdd, false);
                remaining -= toAdd;
            }
        }
    }

    /**
     * Returns the workers of the pool registered for a model.
     *
     * @param modelInfo the model whose workers are requested
     * @param <I> the model input type
     * @param <O> the model output type
     * @return the workers of the model's pool, or an empty list when no pool exists
     */
    public <I, O> List<WorkerThread<I, O>> getWorkers(ModelInfo<I, O> modelInfo) {
        WorkerPool<I, O> pool = getWorkerPoolForModel(modelInfo);
        if (pool == null) {
            return Collections.emptyList();
        }
        List<WorkerThread<I, O>> found = pool.getWorkers();
        if (found == null) {
            return Collections.emptyList();
        }
        return found;
    }

    /**
     * Returns the worker pool of a model, creating and storing a new pool when none exists yet.
     *
     * @param modelInfo the model to register
     * @param <I> the model input type
     * @param <O> the model output type
     * @return the existing or newly created pool for the model
     */
    @SuppressWarnings("unchecked")
    public <I, O> WorkerPool<I, O> registerModel(ModelInfo<I, O> modelInfo) {
        WorkerPool<?, ?> pool =
                workerPools.computeIfAbsent(modelInfo, key -> new WorkerPool<I, O>(modelInfo));
        // The map only ever stores a pool under the model it was created for.
        return (WorkerPool<I, O>) pool;
    }

    /**
     * Scales the model's pool with a requested minimum and maximum of zero workers and removes
     * the pool from this manager.
     *
     * <p>A model without a registered pool is ignored with a warning instead of failing with a
     * {@link NullPointerException}.
     *
     * @param modelInfo the model to unregister
     */
    public void unregisterModel(ModelInfo<?, ?> modelInfo) {
        WorkerPool<?, ?> pool = this.workerPools.get(modelInfo);
        if (pool == null) {
            logger.warn("Cannot unregister model without a worker pool: {}", modelInfo);
            return;
        }
        pool.scaleWorkers(null, 0, 0);
        this.workerPools.remove(modelInfo);
    }

    /**
     * Queues a job for the pool of the job's model and, if needed, starts one more worker.
     *
     * <p>The returned future is completed exceptionally, without queuing the job, when:
     * <ul>
     *   <li>the model is not READY ({@link WlmException});
     *   <li>the model has no registered pool ({@link WlmException});
     *   <li>the pool's maximum worker count is zero ({@link WlmShutdownException});
     *   <li>the job queue rejects the job ({@link WlmCapacityException}).
     * </ul>
     *
     * <p>After queuing, one worker is added when none is running, or when fewer than the
     * maximum are running and the queue holds more than twice the model's batch size.
     *
     * @param job the job to run
     * @param <I> the job input type
     * @param <O> the job output type
     * @return a future for the job's result
     */
    public <I, O> CompletableFuture<O> runJob(Job<I, O> job) {
        CompletableFuture<O> result = new CompletableFuture<>();
        ModelInfo<I, O> model = job.getModel();
        if (model.getStatus() != ModelInfo.Status.READY) {
            result.completeExceptionally(new WlmException("Model is not ready: " + model.getStatus()));
            return result;
        }
        WorkerPool<I, O> pool = getWorkerPoolForModel(model);
        if (pool == null) {
            // A READY model can lack a pool, e.g. after unregisterModel; fail the future
            // instead of throwing a NullPointerException at the caller.
            result.completeExceptionally(new WlmException("Model is not registered: " + model));
            return result;
        }
        int maxWorkers = pool.getMaxWorkers();
        if (maxWorkers == 0) {
            result.completeExceptionally(new WlmShutdownException("All model workers has been shutdown: " + model));
            return result;
        }
        LinkedBlockingDeque<WorkerJob<I, O>> jobQueue = pool.getJobQueue();
        if (!jobQueue.offer(new WorkerJob<>(job, result))) {
            result.completeExceptionally(new WlmCapacityException("Worker queue capacity exceeded for model: " + model));
            return result;
        }
        int running = getNumRunningWorkers(model);
        boolean backlog = running < maxWorkers && jobQueue.size() > model.getBatchSize() * 2;
        if (running == 0 || backlog) {
            synchronized (pool) {
                // Re-check under the lock: another caller may have added a worker meanwhile.
                int runningNow = getNumRunningWorkers(model);
                if (runningNow < maxWorkers) {
                    logger.info("Scaling up workers for model {} to {} ", model, runningNow + 1);
                    pool.addThreads(1);
                }
            }
        }
        return result;
    }

    /**
     * Counts the workers of a model whose state is none of {@code WORKER_STOPPED},
     * {@code WORKER_ERROR} or {@code WORKER_SCALED_DOWN}. Calls {@code cleanup()} on the pool
     * before counting.
     *
     * @param modelInfo the model whose workers are counted
     * @return the number of such workers, or 0 when the model has no pool
     */
    public int getNumRunningWorkers(ModelInfo<?, ?> modelInfo) {
        WorkerPool<?, ?> pool = workerPools.get(modelInfo);
        if (pool == null) {
            return 0;
        }
        pool.cleanup();
        int running = 0;
        for (WorkerThread<?, ?> worker : pool.getWorkers()) {
            boolean inactive =
                    worker.getState() == WorkerState.WORKER_STOPPED
                            || worker.getState() == WorkerState.WORKER_ERROR
                            || worker.getState() == WorkerState.WORKER_SCALED_DOWN;
            if (!inactive) {
                running++;
            }
        }
        return running;
    }

    /**
     * Returns the number of jobs currently waiting in the queue of a model's pool.
     *
     * <p>A model without a pool reports 0, matching how {@code getWorkers} and
     * {@code getNumRunningWorkers} treat a missing pool, instead of throwing a
     * {@link NullPointerException}.
     *
     * @param modelInfo the model whose queue is inspected
     * @return the current queue size, or 0 when the model has no pool
     */
    public int getQueueLength(ModelInfo<?, ?> modelInfo) {
        WorkerPool<?, ?> pool = this.workerPools.get(modelInfo);
        return pool == null ? 0 : pool.getJobQueue().size();
    }

    /**
     * Looks up the worker pool stored for a model.
     *
     * @param modelInfo the model to look up
     * @param <I> the model input type
     * @param <O> the model output type
     * @return the pool stored for the model, or {@code null} when there is none
     */
    @SuppressWarnings("unchecked")
    public <I, O> WorkerPool<I, O> getWorkerPoolForModel(ModelInfo<I, O> modelInfo) {
        WorkerPool<?, ?> pool = workerPools.get(modelInfo);
        return (WorkerPool<I, O>) pool;
    }

    /**
     * Shuts the executor down with {@code shutdownNow()} and then closes every registered
     * worker pool.
     */
    @Override
    public void close() {
        threadPool.shutdownNow();
        for (WorkerPool<?, ?> pool : workerPools.values()) {
            pool.close();
        }
    }
}
