package ai.djl.serving.wlm;

import ai.djl.modality.Output;
import ai.djl.serving.wlm.util.WlmCapacityException;
import ai.djl.serving.wlm.util.WlmConfigManager;
import ai.djl.serving.wlm.util.WlmShutdownException;
import ai.djl.serving.wlm.util.WorkerJob;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/serving/wlm/WorkLoadManager.class */
/**
 * Manages one {@link WorkerPool} per {@link ModelInfo} and dispatches {@link Job}s to it.
 *
 * <p>Each pool owns a bounded job queue and a list of {@link WorkerThread}s. All worker threads
 * are submitted to a single cached thread pool shared by every model.
 */
public class WorkLoadManager {
    private static final Logger logger = LoggerFactory.getLogger(WorkLoadManager.class);
    private final ExecutorService threadPool = Executors.newCachedThreadPool();
    private final ConcurrentHashMap<ModelInfo, WorkerPool> workerPools = new ConcurrentHashMap<>();

    /**
     * The job queue and worker threads serving a single model.
     *
     * <p>Workers are either fixed-pool threads, whose count is managed by {@link #scaleWorkers},
     * or temporary threads added on demand by {@link WorkLoadManager#runJob}.
     */
    public final class WorkerPool {
        private final ModelInfo model;
        private final List<WorkerThread> workers = new CopyOnWriteArrayList<>();
        private final LinkedBlockingDeque<WorkerJob> jobQueue;
        // volatile: written under the model lock in scaleWorkers, but read without any lock
        // from request threads (runJob -> getMaxWorkers).
        private volatile int minWorkers;
        private volatile int maxWorkers;

        /**
         * Creates an empty pool for the given model.
         *
         * @param modelInfo the model served by this pool; its queue size bounds the job queue
         */
        public WorkerPool(ModelInfo modelInfo) {
            this.model = modelInfo;
            this.jobQueue = new LinkedBlockingDeque<>(modelInfo.getQueueSize());
        }

        /**
         * Returns the live (copy-on-write) list of workers in this pool.
         *
         * @return the workers of this pool, never {@code null}
         */
        public List<WorkerThread> getWorkers() {
            return workers;
        }

        /**
         * Returns the queue from which this pool's workers take jobs.
         *
         * @return the job queue
         */
        public LinkedBlockingDeque<WorkerJob> getJobQueue() {
            return jobQueue;
        }

        /**
         * Returns the target number of fixed-pool workers.
         *
         * @return the minimum worker count last set by {@link #scaleWorkers}
         */
        public int getMinWorkers() {
            return minWorkers;
        }

        /**
         * Returns the maximum number of workers; {@code 0} makes {@link WorkLoadManager#runJob}
         * reject new jobs.
         *
         * @return the maximum worker count last set by {@link #scaleWorkers}
         */
        public int getMaxWorkers() {
            return maxWorkers;
        }

        /**
         * Resizes the fixed portion of this pool.
         *
         * <p>The maximum is resolved through {@code WlmConfigManager.getDefaultWorkers} and the
         * minimum is clamped to that maximum. Missing fixed-pool workers are started; surplus
         * fixed-pool workers are removed from the pool and shut down with {@code
         * WORKER_SCALED_DOWN}. Temporary workers are not touched.
         *
         * @param deviceName device name forwarded to the config manager (may be {@code null})
         * @param newMinWorkers requested number of fixed-pool workers
         * @param newMaxWorkers requested maximum number of workers
         * @return this pool
         */
        public WorkerPool scaleWorkers(String deviceName, int newMinWorkers, int newMaxWorkers) {
            synchronized (model) {
                int max =
                        WlmConfigManager.getInstance()
                                .getDefaultWorkers(
                                        model.getModel().getNDManager(), deviceName, newMaxWorkers);
                int min = Math.min(newMinWorkers, max);
                maxWorkers = max;
                minWorkers = min;

                cleanup();
                List<WorkerThread> fixedPool =
                        workers.stream()
                                .filter(WorkerThread::isFixPoolThread)
                                .collect(Collectors.toList());
                int fixedCount = fixedPool.size();
                if (fixedCount < min) {
                    addThreads(model, min - fixedCount, true);
                } else {
                    // Iterate the snapshot, not the live list, while removing from the latter.
                    for (WorkerThread thread : fixedPool.subList(min, fixedCount)) {
                        workers.remove(thread);
                        thread.shutdown(WorkerState.WORKER_SCALED_DOWN);
                    }
                }
                log();
            }
            return this;
        }

        /**
         * Builds {@code count} new workers, adds them to this pool and submits them to the
         * shared thread pool. Intended for use by {@link WorkLoadManager} only.
         *
         * @param modelInfo the model the new workers serve
         * @param count number of workers to start
         * @param fixPoolThread whether the new workers belong to the fixed pool
         */
        public void addThreads(ModelInfo modelInfo, int count, boolean fixPoolThread) {
            for (int i = 0; i < count; i++) {
                WorkerThread thread =
                        WorkerThread.builder()
                                .setModel(modelInfo)
                                .setJobQueue(jobQueue)
                                .optFixPoolThread(fixPoolThread)
                                .build();
                workers.add(thread);
                WorkLoadManager.this.threadPool.submit(thread);
            }
        }

        /** Logs at debug level every worker id in this pool, tagged as fixed or temporary. */
        public void log() {
            if (logger.isDebugEnabled()) {
                StringBuilder sb = new StringBuilder();
                for (WorkerThread thread : workers) {
                    sb.append(thread.getWorkerId())
                            .append(thread.isFixPoolThread() ? "-fixedPool\n" : "-tmpPool\n");
                }
                logger.debug("worker pool for model {}:\n {}", model.getModelName(), sb);
            }
        }

        /** Removes workers whose state is {@code WORKER_STOPPED} or {@code WORKER_ERROR}. */
        public void cleanup() {
            workers.removeIf(
                    thread -> {
                        // Read the state once so both comparisons see the same value.
                        WorkerState state = thread.getState();
                        return state == WorkerState.WORKER_STOPPED
                                || state == WorkerState.WORKER_ERROR;
                    });
        }
    }

    /**
     * Returns the workers serving the given model.
     *
     * @param modelInfo the model to look up
     * @return the pool's live worker list, or an empty list if the model has no pool
     */
    public List<WorkerThread> getWorkers(ModelInfo modelInfo) {
        WorkerPool pool = workerPools.get(modelInfo);
        if (pool == null) {
            return Collections.emptyList();
        }
        return pool.getWorkers();
    }

    /**
     * Scales the model's pool down (requested min and max of 0) and forgets the pool.
     *
     * <p>Does nothing if the model has no pool. A pool is not created just to be discarded,
     * which would also require the model to be loaded.
     *
     * @param modelInfo the model to unregister
     */
    public void unregisterModel(ModelInfo modelInfo) {
        WorkerPool pool = workerPools.get(modelInfo);
        if (pool == null) {
            return;
        }
        // NOTE(review): only fixed-pool workers are shut down here; temporary workers stay in
        // the (now unmapped) pool until they stop on their own — confirm this is intended.
        pool.scaleWorkers(null, 0, 0);
        workerPools.remove(modelInfo);
    }

    /**
     * Enqueues a job for its model and, if needed, starts one extra temporary worker.
     *
     * <p>The returned future completes exceptionally with {@link WlmShutdownException} when the
     * pool's maximum worker count is 0, or with {@link WlmCapacityException} when the job queue
     * is full. A temporary worker is added when no worker is running, or when fewer than the
     * maximum are running and the queue holds more than twice the batch size.
     *
     * @param job the job to run
     * @return a future completed with the job's output
     */
    public CompletableFuture<Output> runJob(Job job) {
        CompletableFuture<Output> result = new CompletableFuture<>();
        ModelInfo modelInfo = job.getModel();
        WorkerPool pool = getWorkerPoolForModel(modelInfo);
        int maxWorkers = pool.getMaxWorkers();
        if (maxWorkers == 0) {
            result.completeExceptionally(
                    new WlmShutdownException(
                            "All model workers has been shutdown: " + modelInfo.getModelName()));
            return result;
        }
        LinkedBlockingDeque<WorkerJob> queue = pool.getJobQueue();
        if (!queue.offer(new WorkerJob(job, result))) {
            result.completeExceptionally(
                    new WlmCapacityException(
                            "Worker queue capacity exceeded for model: "
                                    + modelInfo.getModelName()));
            return result;
        }
        int running = getNumRunningWorkers(modelInfo);
        if (running == 0 || (running < maxWorkers && queue.size() > modelInfo.getBatchSize() * 2)) {
            // NOTE(review): this locks modelInfo.getModel(), while WorkerPool.scaleWorkers locks
            // the ModelInfo itself, so scale-up here is not mutually exclusive with scaleWorkers.
            synchronized (modelInfo.getModel()) {
                running = getNumRunningWorkers(modelInfo); // re-check under the lock
                if (running < maxWorkers) {
                    logger.info(
                            "Scaling up workers for model {} to {} ", modelInfo, running + 1);
                    pool.addThreads(modelInfo, 1, false);
                }
            }
        }
        return result;
    }

    /**
     * Counts the model's workers that are not stopped, errored or scaled down.
     *
     * <p>As a side effect, stopped and errored workers are first removed from the pool.
     *
     * @param modelInfo the model to inspect
     * @return the number of running workers, or 0 if the model has no pool
     */
    public int getNumRunningWorkers(ModelInfo modelInfo) {
        WorkerPool pool = workerPools.get(modelInfo);
        if (pool == null) {
            return 0;
        }
        pool.cleanup();
        int running = 0;
        for (WorkerThread thread : pool.getWorkers()) {
            // Read the state once so all three comparisons see the same value.
            WorkerState state = thread.getState();
            if (state != WorkerState.WORKER_STOPPED
                    && state != WorkerState.WORKER_ERROR
                    && state != WorkerState.WORKER_SCALED_DOWN) {
                running++;
            }
        }
        return running;
    }

    /**
     * Returns the number of jobs waiting in the model's queue.
     *
     * <p>This is a pure query: no pool is created for a model that does not have one.
     *
     * @param modelInfo the model to inspect
     * @return the queue length, or 0 if the model has no pool
     */
    public int getQueueLength(ModelInfo modelInfo) {
        WorkerPool pool = workerPools.get(modelInfo);
        return pool == null ? 0 : pool.getJobQueue().size();
    }

    /**
     * Returns the model's pool, atomically creating an empty one if none exists.
     *
     * @param modelInfo the model to look up
     * @return the (possibly new) pool for the model
     */
    public WorkerPool getWorkerPoolForModel(ModelInfo modelInfo) {
        return workerPools.computeIfAbsent(modelInfo, WorkerPool::new);
    }
}
