package ai.djl.serving.wlm;

import ai.djl.Device;
import ai.djl.inference.Predictor;
import ai.djl.serving.wlm.util.WlmException;
import ai.djl.serving.wlm.util.WorkerJob;
import ai.djl.translate.TranslateException;
import java.util.List;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link Runnable} that repeatedly pulls request batches from a {@link BatchAggregator}, runs them
 * through a {@link Predictor} created for a single {@link Device}, and hands the results (or errors)
 * back to the aggregator. It stops when {@link #shutdown(WorkerState)} is called, when it is
 * interrupted, or when the aggregator reports it is finished.
 *
 * <p>Instances are created through {@link #builder(Class, Class)}.
 *
 * @param <I> the model input type
 * @param <O> the model output type
 */
public final class WorkerThread<I, O> implements Runnable {

    private static final Logger logger = LoggerFactory.getLogger(WorkerThread.class);

    /** Default message for errors sent to the aggregator when the worker stops. */
    private static final String SHUTDOWN_MESSAGE = "Worker shutting down";

    /** Maximum number of model-id characters embedded in the worker (thread) name. */
    private static final int MAX_MODEL_ID_LENGTH = 25;

    private final String workerName;
    private final Predictor<I, O> predictor;
    private final AtomicBoolean running = new AtomicBoolean(true);
    private final BatchAggregator<I, O> aggregator;
    private final Device device;
    // The thread currently executing run(); cleared on shutdown so it is interrupted at most once.
    private final AtomicReference<Thread> currentThread = new AtomicReference<>();
    // Written by the worker thread (run) and by callers of shutdown/setState, read via getState.
    private volatile WorkerState state;
    private final int workerId;
    private final long startTime;
    private final boolean fixPoolThread;

    /**
     * Builder for {@link WorkerThread}.
     *
     * <p>A device and a model are mandatory, plus either a job queue or a ready-made aggregator.
     *
     * @param <I> the model input type
     * @param <O> the model output type
     */
    public static class Builder<I, O> {
        private ModelInfo<I, O> model;
        private Device device;
        private BatchAggregator<I, O> aggregator;
        private LinkedBlockingDeque<WorkerJob<I, O>> jobQueue;
        private boolean fixPoolThread = true;

        Builder() {
        }

        /**
         * Returns this builder. This is the hook subclasses use for fluent chaining.
         *
         * @return this builder
         */
        protected Builder<I, O> self() {
            return this;
        }

        /**
         * Creates a default aggregator over the job queue when none was supplied.
         *
         * <p>A permanent aggregator is used for fixed-pool threads and a temporary one otherwise.
         */
        protected void preBuildProcessing() {
            if (aggregator != null) {
                return;
            }
            aggregator =
                    fixPoolThread
                            ? new PermanentBatchAggregator<>(model, jobQueue)
                            : new TemporaryBatchAggregator<>(model, jobQueue);
        }

        /**
         * Checks that the mandatory settings are present.
         *
         * @throws IllegalArgumentException if the device or model is missing, or if neither a job
         *     queue nor an aggregator was set
         */
        protected void validate() {
            if (device == null) {
                throw new IllegalArgumentException("Must set device for worker thread");
            }
            if (model == null) {
                throw new IllegalArgumentException("Must set model for worker thread");
            }
            if (jobQueue == null && aggregator == null) {
                throw new IllegalArgumentException(
                        "one of jobQueue or BatchAggregator have to be set.");
            }
        }

        /**
         * Validates the settings and creates the worker. This also creates the worker's predictor.
         *
         * @return a new worker
         * @throws IllegalArgumentException if the configuration is incomplete
         */
        public WorkerThread<I, O> build() {
            validate();
            preBuildProcessing();
            return new WorkerThread<>(this);
        }

        /**
         * Sets the model the worker serves (required).
         *
         * @param modelInfo the model
         * @return this builder
         */
        public Builder<I, O> setModel(ModelInfo<I, O> modelInfo) {
            this.model = modelInfo;
            return self();
        }

        /**
         * Sets the device the predictor is created on (required).
         *
         * @param device the device
         * @return this builder
         */
        public Builder<I, O> setDevice(Device device) {
            this.device = device;
            return self();
        }

        /**
         * Supplies a pre-built aggregator. This takes precedence over the job queue.
         *
         * @param batchAggregator the aggregator
         * @return this builder
         */
        public Builder<I, O> optAggregator(BatchAggregator<I, O> batchAggregator) {
            this.aggregator = batchAggregator;
            return self();
        }

        /**
         * Sets the job queue a default aggregator will read from.
         *
         * @param linkedBlockingDeque the job queue
         * @return this builder
         */
        public Builder<I, O> setJobQueue(LinkedBlockingDeque<WorkerJob<I, O>> linkedBlockingDeque) {
            this.jobQueue = linkedBlockingDeque;
            return self();
        }

        /**
         * Sets whether the worker belongs to the fixed pool (default {@code true}).
         *
         * <p>This selects the default aggregator type.
         *
         * @param z {@code true} for a fixed-pool worker
         * @return this builder
         */
        public Builder<I, O> optFixPoolThread(boolean z) {
            this.fixPoolThread = z;
            return self();
        }
    }

    private WorkerThread(Builder<I, O> builder) {
        // The id must be generated before the name, which embeds it.
        this.workerId = new WorkerIdGenerator().generate();
        this.workerName = buildWorkerName(builder.model, workerId);
        this.aggregator = builder.aggregator;
        this.startTime = System.currentTimeMillis();
        this.fixPoolThread = builder.fixPoolThread;
        this.device = builder.device;
        this.predictor = builder.model.getModel(device).newPredictor();
    }

    /**
     * Runs the request loop until the worker stops, then cleans up.
     *
     * <p>The loop ends when the worker is shut down, the aggregator is finished, the thread is
     * interrupted, or an unexpected error occurs. Translation failures are reported per batch and do
     * not stop the loop. If a batch was in flight when the loop ended abnormally, an error is sent to
     * the aggregator so callers are not left waiting.
     */
    @Override
    public void run() {
        Thread thread = Thread.currentThread();
        thread.setName(workerName);
        currentThread.set(thread);
        this.state = WorkerState.WORKER_STARTED;

        List<I> pending = null; // non-null while a batch has been taken but not yet answered
        String errorMessage = SHUTDOWN_MESSAGE;
        boolean interrupted = false;
        try {
            while (isRunning() && !aggregator.isFinished()) {
                pending = aggregator.getRequest();
                if (pending != null && !pending.isEmpty()) {
                    try {
                        aggregator.sendResponse(predictor.batchPredict(pending));
                    } catch (TranslateException e) {
                        logger.warn("Failed to predict", e);
                        aggregator.sendError(e);
                    }
                }
                pending = null;
            }
        } catch (InterruptedException e) {
            logger.debug("Shutting down the thread .. Scaling down.");
            interrupted = true;
        } catch (Throwable t) {
            logger.error("Server error", t);
            if (t.getMessage() != null) {
                errorMessage = t.getMessage();
            }
        } finally {
            // Use the local reference: an external shutdown() may already have cleared currentThread.
            logger.debug("Shutting down worker thread .. {}", thread.getName());
            currentThread.set(null);
            shutdown(WorkerState.WORKER_STOPPED);
            if (pending != null) {
                aggregator.sendError(new WlmException(errorMessage));
            }
            if (interrupted) {
                // Restore the interrupt status for the owning executor after cleanup has finished.
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Returns the id generated for this worker.
     *
     * @return the worker id
     */
    public int getWorkerId() {
        return workerId;
    }

    /**
     * Returns whether {@link #shutdown(WorkerState)} has not been called yet.
     *
     * @return {@code true} while the worker has not been shut down
     */
    public boolean isRunning() {
        return running.get();
    }

    /**
     * Returns the device this worker's predictor was created on.
     *
     * @return the device
     */
    public Device getDevice() {
        return device;
    }

    /**
     * Returns the construction time of this worker.
     *
     * @return the construction time in milliseconds since the epoch
     */
    public long getStartTime() {
        return startTime;
    }

    /**
     * Returns the current worker state.
     *
     * @return the worker state, or {@code null} if the worker has never run or been shut down
     */
    public WorkerState getState() {
        return state;
    }

    /**
     * Stops the worker.
     *
     * <p>This marks the worker as no longer running and updates its state. If a thread is running
     * the loop, it is interrupted and the aggregator is sent a shutdown error. The predictor is
     * closed only on the first call, so the external shutdown and the worker's own cleanup do not
     * close it twice.
     *
     * @param workerState the state to record (ignored once the worker is scaled down)
     */
    public void shutdown(WorkerState workerState) {
        boolean wasRunning = running.getAndSet(false);
        setState(workerState);
        Thread thread = currentThread.getAndSet(null);
        if (thread != null) {
            thread.interrupt();
            aggregator.sendError(new WlmException(SHUTDOWN_MESSAGE));
        }
        if (wasRunning) {
            predictor.close();
        }
    }

    /**
     * Builds the worker name as {@code "W-" + <model id, truncated> + "-" + <worker id>}.
     *
     * @param modelInfo the model served by the worker
     * @param id the worker id
     * @return the worker name
     */
    private static String buildWorkerName(ModelInfo<?, ?> modelInfo, int id) {
        String modelId = modelInfo.getModelId();
        if (modelId.length() > MAX_MODEL_ID_LENGTH) {
            modelId = modelId.substring(0, MAX_MODEL_ID_LENGTH);
        }
        return "W-" + modelId + '-' + id;
    }

    /**
     * Updates the worker state, unless the worker is already {@code WORKER_SCALED_DOWN}.
     *
     * <p>{@code WORKER_SCALED_DOWN} is treated as terminal and is never overwritten.
     *
     * @param workerState the new state
     */
    void setState(WorkerState workerState) {
        logger.debug("{} State change {} -> {}", workerName, state, workerState);
        if (state != WorkerState.WORKER_SCALED_DOWN) {
            state = workerState;
        }
    }

    /**
     * Returns whether this worker belongs to the fixed pool.
     *
     * @return {@code true} for a fixed-pool worker
     */
    public boolean isFixPoolThread() {
        return fixPoolThread;
    }

    /**
     * Creates a new builder. The class arguments are used only to infer the type parameters.
     *
     * @param cls the input class
     * @param cls2 the output class
     * @param <I> the model input type
     * @param <O> the model output type
     * @return a new builder
     */
    public static <I, O> Builder<I, O> builder(Class<I> cls, Class<O> cls2) {
        return new Builder<>();
    }
}
