package ai.djl.aws.sagemaker;

import ai.djl.Model;
import ai.djl.util.RandomUtils;
import ai.djl.util.Utils;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.text.SimpleDateFormat;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Date;
import java.util.Objects;
import java.util.Optional;
import org.apache.commons.compress.archivers.tar.TarArchiveEntry;
import org.apache.commons.compress.archivers.tar.TarArchiveOutputStream;
import org.apache.commons.compress.compressors.gzip.GzipCompressorOutputStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import software.amazon.awssdk.core.SdkBytes;
import software.amazon.awssdk.core.exception.SdkException;
import software.amazon.awssdk.core.sync.RequestBody;
import software.amazon.awssdk.regions.Region;
import software.amazon.awssdk.regions.providers.DefaultAwsRegionProviderChain;
import software.amazon.awssdk.services.iam.IamClient;
import software.amazon.awssdk.services.iam.model.AttachRolePolicyRequest;
import software.amazon.awssdk.services.iam.model.CreatePolicyRequest;
import software.amazon.awssdk.services.iam.model.CreateRoleRequest;
import software.amazon.awssdk.services.iam.model.GetPolicyRequest;
import software.amazon.awssdk.services.s3.S3Client;
import software.amazon.awssdk.services.s3.model.CreateBucketRequest;
import software.amazon.awssdk.services.s3.model.HeadBucketRequest;
import software.amazon.awssdk.services.s3.model.ListBucketsRequest;
import software.amazon.awssdk.services.s3.model.PutObjectRequest;
import software.amazon.awssdk.services.s3.model.S3Exception;
import software.amazon.awssdk.services.sagemaker.SageMakerClient;
import software.amazon.awssdk.services.sagemaker.model.ContainerDefinition;
import software.amazon.awssdk.services.sagemaker.model.CreateEndpointConfigRequest;
import software.amazon.awssdk.services.sagemaker.model.CreateEndpointRequest;
import software.amazon.awssdk.services.sagemaker.model.CreateModelRequest;
import software.amazon.awssdk.services.sagemaker.model.DeleteEndpointConfigRequest;
import software.amazon.awssdk.services.sagemaker.model.DeleteEndpointRequest;
import software.amazon.awssdk.services.sagemaker.model.DeleteModelRequest;
import software.amazon.awssdk.services.sagemaker.model.DescribeEndpointConfigRequest;
import software.amazon.awssdk.services.sagemaker.model.DescribeEndpointRequest;
import software.amazon.awssdk.services.sagemaker.model.DescribeModelRequest;
import software.amazon.awssdk.services.sagemaker.model.ProductionVariant;
import software.amazon.awssdk.services.sagemaker.model.SageMakerException;
import software.amazon.awssdk.services.sagemaker.waiters.SageMakerWaiter;
import software.amazon.awssdk.services.sagemakerruntime.SageMakerRuntimeClient;
import software.amazon.awssdk.services.sagemakerruntime.model.InvokeEndpointRequest;
import software.amazon.awssdk.services.sts.StsClient;

/**
 * Deploys a DJL {@link Model} as an Amazon SageMaker endpoint and invokes it.
 *
 * <p>{@link #deploy()} packages the model directory into a {@code .tar.gz} archive and uploads it
 * to S3, creating the bucket if needed. When no execution role or container image was supplied,
 * it creates an IAM execution role and resolves the container image. It then creates the
 * SageMaker model, the endpoint configuration and the endpoint. Instances are obtained from
 * {@link #builder()}.
 */
public final class SageMaker {

    private static final Logger logger = LoggerFactory.getLogger(SageMaker.class);

    /**
     * Alphabet for the random suffix of generated bucket names.
     *
     * <p>Only lowercase letters and digits are used. S3 bucket names must end with a letter or
     * digit and must not contain two adjacent periods. A random suffix drawn from an alphabet that
     * also contains '.' and '-' could violate both rules.
     */
    private static final char[] CHARS = "abcdefghijklmnopqrstuvwxyz1234567890".toCharArray();

    /** Timestamp pattern appended to generated IAM role and policy names. */
    private static final DateTimeFormatter NAME_SUFFIX_FORMAT =
            DateTimeFormatter.ofPattern("yyyyMMdd'T'HHmmss");

    private final SageMakerClient sageMaker;
    private final SageMakerRuntimeClient smRuntime;
    private final S3Client s3;
    private final IamClient iam;
    private final Region region;
    private final Model model;
    private final String modelName;
    private final String bucketName;
    private final String bucketPath;
    private String executionRole; // filled in by createRoleIfNeeded() when not supplied
    private String containerImage; // filled in by getContainerImageArn() when not supplied
    private final String endpointConfigName;
    private final String endpointName;
    private final String instanceType;
    private final int instanceCount;

    /** Builder for {@link SageMaker}; only {@link #setModel(Model)} is mandatory. */
    public static final class Builder {
        Model model;
        String bucketName;
        String executionRole;
        String containerImage;
        String endpointConfigName;
        String endpointName;
        String modelName;
        SageMakerClient sageMaker;
        SageMakerRuntimeClient smRuntime;
        S3Client s3;
        IamClient iam;
        String bucketPath = "";
        String instanceType = "ml.m4.xlarge";
        int instanceCount = 1;

        Builder() {}

        /**
         * Sets the model to deploy (required).
         *
         * @param model the DJL model
         * @return this builder
         */
        public Builder setModel(Model model) {
            this.model = model;
            return this;
        }

        /**
         * Sets the S3 bucket the model archive is uploaded to; a random name is generated if unset.
         *
         * @param bucketName the bucket name
         * @return this builder
         */
        public Builder optBucketName(String bucketName) {
            this.bucketName = bucketName;
            return this;
        }

        /**
         * Sets the key prefix inside the bucket; {@code null} is treated as no prefix.
         *
         * @param bucketPath the key prefix
         * @return this builder
         */
        public Builder optBucketPath(String bucketPath) {
            this.bucketPath = bucketPath;
            return this;
        }

        /**
         * Sets the IAM execution role ARN; a role is created during deployment if unset.
         *
         * @param executionRole the role ARN
         * @return this builder
         */
        public Builder optExecutionRole(String executionRole) {
            this.executionRole = executionRole;
            return this;
        }

        /**
         * Sets the container image; if unset it is derived from the ECS container metadata.
         *
         * @param containerImage the image URI
         * @return this builder
         */
        public Builder optContainerImage(String containerImage) {
            this.containerImage = containerImage;
            return this;
        }

        /**
         * Sets the endpoint configuration name; defaults to the model name.
         *
         * @param endpointConfigName the endpoint configuration name
         * @return this builder
         */
        public Builder optEndpointConfigName(String endpointConfigName) {
            this.endpointConfigName = endpointConfigName;
            return this;
        }

        /**
         * Sets the endpoint name; defaults to the model name.
         *
         * @param endpointName the endpoint name
         * @return this builder
         */
        public Builder optEndpointName(String endpointName) {
            this.endpointName = endpointName;
            return this;
        }

        /**
         * Sets the SageMaker model name; defaults to {@link Model#getName()}.
         *
         * @param modelName the SageMaker model name
         * @return this builder
         */
        public Builder optModelName(String modelName) {
            this.modelName = modelName;
            return this;
        }

        /**
         * Sets the instance type; defaults to {@code ml.m4.xlarge}.
         *
         * @param instanceType the SageMaker instance type
         * @return this builder
         */
        public Builder optInstanceType(String instanceType) {
            this.instanceType = instanceType;
            return this;
        }

        /**
         * Sets the initial instance count; defaults to 1.
         *
         * @param instanceCount the number of instances
         * @return this builder
         */
        public Builder optInstanceCount(int instanceCount) {
            this.instanceCount = instanceCount;
            return this;
        }

        /**
         * Sets the SageMaker client; a default client is created if unset.
         *
         * @param sageMakerClient the client
         * @return this builder
         */
        public Builder optSageMakerClient(SageMakerClient sageMakerClient) {
            this.sageMaker = sageMakerClient;
            return this;
        }

        /**
         * Sets the SageMaker runtime client; a default client is created if unset.
         *
         * @param sageMakerRuntimeClient the client
         * @return this builder
         */
        public Builder optSageMakerRuntimeClient(SageMakerRuntimeClient sageMakerRuntimeClient) {
            this.smRuntime = sageMakerRuntimeClient;
            return this;
        }

        /**
         * Sets the S3 client; a default client is created if unset.
         *
         * @param s3Client the client
         * @return this builder
         */
        public Builder optS3Client(S3Client s3Client) {
            this.s3 = s3Client;
            return this;
        }

        /**
         * Sets the IAM client; a client for {@link Region#AWS_GLOBAL} is created if unset.
         *
         * @param iamClient the client
         * @return this builder
         */
        public Builder optIamClient(IamClient iamClient) {
            this.iam = iamClient;
            return this;
        }

        /**
         * Validates the settings, fills in defaults and creates the {@link SageMaker} instance.
         *
         * <p>Applies these defaults:
         *
         * <ul>
         *   <li>Missing AWS clients are created with the SDK defaults; the IAM client uses
         *       {@link Region#AWS_GLOBAL}.
         *   <li>A missing bucket name becomes {@code djl-sm-} followed by 8 random characters.
         *   <li>One trailing and one leading {@code /} are stripped from the bucket path.
         *   <li>Endpoint and endpoint-config names default to the model name.
         * </ul>
         *
         * @return a new {@link SageMaker}
         * @throws IllegalArgumentException if no model was set
         */
        public SageMaker build() {
            if (model == null) {
                throw new IllegalArgumentException("Model is required.");
            }
            if (sageMaker == null) {
                sageMaker = SageMakerClient.create();
            }
            if (smRuntime == null) {
                smRuntime = SageMakerRuntimeClient.create();
            }
            if (s3 == null) {
                s3 = S3Client.create();
            }
            if (iam == null) {
                iam = IamClient.builder().region(Region.AWS_GLOBAL).build();
            }
            if (bucketName == null) {
                bucketName = randomBucketName();
            }
            bucketPath = normalizeBucketPath(bucketPath);
            if (endpointConfigName == null) {
                endpointConfigName = modelName == null ? model.getName() : modelName;
            }
            if (endpointName == null) {
                endpointName = modelName == null ? model.getName() : modelName;
            }
            return new SageMaker(this);
        }

        /** Generates {@code djl-sm-} followed by 8 random lowercase letters/digits. */
        private static String randomBucketName() {
            StringBuilder sb = new StringBuilder("djl-sm-");
            for (int i = 0; i < 8; i++) {
                sb.append(CHARS[RandomUtils.nextInt(CHARS.length)]);
            }
            return sb.toString();
        }

        /** Maps {@code null} to "" and strips one trailing, then one leading, slash. */
        private static String normalizeBucketPath(String path) {
            if (path == null) {
                return "";
            }
            String result = path;
            if (result.endsWith("/")) {
                result = result.substring(0, result.length() - 1);
            }
            if (result.startsWith("/")) {
                result = result.substring(1);
            }
            return result;
        }
    }

    private SageMaker(Builder builder) {
        this.sageMaker = builder.sageMaker;
        this.smRuntime = builder.smRuntime;
        this.s3 = builder.s3;
        this.iam = builder.iam;
        this.model = builder.model;
        this.modelName = builder.modelName != null ? builder.modelName : builder.model.getName();
        this.bucketName = builder.bucketName;
        this.bucketPath = builder.bucketPath;
        this.executionRole = builder.executionRole;
        this.containerImage = builder.containerImage;
        this.endpointConfigName = builder.endpointConfigName;
        this.endpointName = builder.endpointName;
        this.instanceType = builder.instanceType;
        this.instanceCount = builder.instanceCount;
        this.region = DefaultAwsRegionProviderChain.builder().build().getRegion();
    }

    /**
     * Creates the SageMaker model, endpoint configuration and endpoint, then blocks until the
     * endpoint is in service.
     *
     * @throws IOException if packaging the model directory fails
     * @throws IllegalStateException if the endpoint, endpoint configuration or model already
     *     exists
     */
    public void deploy() throws IOException {
        if (doesEndpointExist()) {
            throw new IllegalStateException("Endpoint already exists: " + endpointName);
        }
        createEndpointConfig();

        logger.info("Creating endpoint {} ...", endpointName);
        CreateEndpointRequest request =
                CreateEndpointRequest.builder()
                        .endpointName(endpointName)
                        .endpointConfigName(endpointConfigName)
                        .build();
        String endpointArn = sageMaker.createEndpoint(request).endpointArn();
        SageMakerWaiter waiter = sageMaker.waiter();
        waiter.waitUntilEndpointInService(describeEndpointRequest());
        logger.info("SageMaker endpoint {} created: {}", endpointName, endpointArn);
    }

    /**
     * Deletes the endpoint and blocks until the deletion completes.
     *
     * @param ignoreError if {@code true}, an {@link SdkException} from the delete or wait call is
     *     swallowed; otherwise it is rethrown
     */
    public void deleteEndpoint(boolean ignoreError) {
        try {
            logger.info("Deleting SageMaker endpoint {} ...", endpointName);
            sageMaker.deleteEndpoint(
                    DeleteEndpointRequest.builder().endpointName(endpointName).build());
            // Wait on the endpoint itself (previously this wrongly used the endpoint-config name).
            sageMaker.waiter().waitUntilEndpointDeleted(describeEndpointRequest());
            logger.info("SageMaker endpoint {} deleted.", endpointName);
        } catch (SdkException e) {
            if (!ignoreError) {
                throw e;
            }
        }
    }

    /**
     * Deletes the endpoint configuration.
     *
     * @param ignoreError if {@code true}, an {@link SdkException} is swallowed; otherwise it is
     *     rethrown
     */
    public void deleteEndpointConfig(boolean ignoreError) {
        try {
            sageMaker.deleteEndpointConfig(
                    DeleteEndpointConfigRequest.builder()
                            .endpointConfigName(endpointConfigName)
                            .build());
            logger.info("SageMaker endpoint config {} deleted.", endpointConfigName);
        } catch (SdkException e) {
            if (!ignoreError) {
                throw e;
            }
        }
    }

    /**
     * Deletes the SageMaker model.
     *
     * @param ignoreError if {@code true}, an {@link SdkException} is swallowed; otherwise it is
     *     rethrown
     */
    public void deleteSageMakerModel(boolean ignoreError) {
        try {
            sageMaker.deleteModel(DeleteModelRequest.builder().modelName(modelName).build());
            logger.info("SageMaker model {} deleted.", modelName);
        } catch (SdkException e) {
            if (!ignoreError) {
                throw e;
            }
        }
    }

    /**
     * Sends {@code input} to the endpoint and returns the raw response body.
     *
     * @param input the request payload
     * @return the response payload
     */
    public byte[] invoke(byte[] input) {
        InvokeEndpointRequest request =
                InvokeEndpointRequest.builder()
                        .endpointName(endpointName)
                        .body(SdkBytes.fromByteArray(input))
                        .build();
        return smRuntime.invokeEndpoint(request).body().asByteArray();
    }

    /** Builds the describe request for this instance's endpoint. */
    private DescribeEndpointRequest describeEndpointRequest() {
        return DescribeEndpointRequest.builder().endpointName(endpointName).build();
    }

    /** Returns whether the endpoint can be described; any SageMaker error counts as absent. */
    private boolean doesEndpointExist() {
        try {
            // Describe the endpoint by its own name (previously used the endpoint-config name).
            sageMaker.describeEndpoint(describeEndpointRequest());
            return true;
        } catch (SageMakerException e) {
            return false;
        }
    }

    /** Creates the SageMaker model and a single-variant endpoint configuration that uses it. */
    private void createEndpointConfig() throws IOException {
        if (doesEndpointConfigExist()) {
            throw new IllegalStateException(
                    "Endpoint config already exists: " + endpointConfigName);
        }
        createSageMakerModel();

        logger.info("Creating endpoint config {} ...", endpointConfigName);
        ProductionVariant variant =
                ProductionVariant.builder()
                        .variantName("AllTraffic")
                        .modelName(modelName)
                        .initialInstanceCount(instanceCount)
                        .initialVariantWeight(1.0f)
                        .instanceType(instanceType)
                        .build();
        CreateEndpointConfigRequest request =
                CreateEndpointConfigRequest.builder()
                        .endpointConfigName(endpointConfigName)
                        .productionVariants(variant)
                        .build();
        String configArn = sageMaker.createEndpointConfig(request).endpointConfigArn();
        logger.info(
                "SageMaker endpoint configure {} created: {}", endpointConfigName, configArn);
    }

    /** Returns whether the endpoint configuration can be described. */
    private boolean doesEndpointConfigExist() {
        try {
            sageMaker.describeEndpointConfig(
                    DescribeEndpointConfigRequest.builder()
                            .endpointConfigName(endpointConfigName)
                            .build());
            return true;
        } catch (SageMakerException e) {
            return false;
        }
    }

    /**
     * Uploads the packaged model to S3 and registers it as a SageMaker model. Creates the execution
     * role and resolves the container image if they were not supplied.
     */
    private void createSageMakerModel() throws IOException {
        if (doesSageMakerModelExist()) {
            throw new IllegalStateException("SageMaker model already exists: " + modelName);
        }
        createBucket();

        String key =
                bucketPath.isEmpty()
                        ? modelName + ".tar.gz"
                        : bucketPath + '/' + modelName + ".tar.gz";
        Path archive = tar(model.getModelPath());
        String modelUrl;
        try {
            modelUrl = uploadModel(bucketName, key, archive);
        } finally {
            // Always remove the local temp archive, even when the upload fails.
            Files.deleteIfExists(archive);
        }

        createRoleIfNeeded();
        getContainerImageArn();

        ContainerDefinition container =
                ContainerDefinition.builder().image(containerImage).modelDataUrl(modelUrl).build();
        CreateModelRequest request =
                CreateModelRequest.builder()
                        .modelName(modelName)
                        .primaryContainer(container)
                        .executionRoleArn(executionRole)
                        .build();
        String modelArn = sageMaker.createModel(request).modelArn();
        logger.info("SageMaker model {} created: {}", modelName, modelArn);
    }

    /** Returns whether the SageMaker model can be described. */
    private boolean doesSageMakerModelExist() {
        try {
            sageMaker.describeModel(DescribeModelRequest.builder().modelName(modelName).build());
            return true;
        } catch (SageMakerException e) {
            return false;
        }
    }

    /**
     * Writes the regular files under {@code dir} into a gzip-compressed tar in a temp file.
     *
     * @param dir the model directory
     * @return the temp archive; the caller is responsible for deleting it
     * @throws IOException if reading the model or writing the archive fails; the temp file is
     *     removed in that case
     */
    private Path tar(Path dir) throws IOException {
        Path archive = Files.createTempFile("model", ".tar.gz");
        try (OutputStream os = Files.newOutputStream(archive);
                BufferedOutputStream bos = new BufferedOutputStream(os);
                GzipCompressorOutputStream gzos = new GzipCompressorOutputStream(bos);
                TarArchiveOutputStream tos = new TarArchiveOutputStream(gzos)) {
            // Allow entries larger than 8 GiB (star/GNU big-number extension).
            tos.setBigNumberMode(TarArchiveOutputStream.BIGNUMBER_STAR);
            addToTar(dir, dir, tos);
            tos.finish();
        } catch (IOException | RuntimeException e) {
            try {
                Files.deleteIfExists(archive);
            } catch (IOException suppressed) {
                e.addSuppressed(suppressed);
            }
            throw e;
        }
        return archive;
    }

    /**
     * Recursively adds {@code file} to the archive under {@code <modelName>/<path relative to
     * root>}. Directories are traversed but not written as entries; non-regular files are skipped.
     */
    private void addToTar(Path root, Path file, TarArchiveOutputStream tos) throws IOException {
        if (Files.isDirectory(file)) {
            File[] children = file.toFile().listFiles();
            if (children != null) {
                for (File child : children) {
                    addToTar(root, child.toPath(), tos);
                }
            }
            return;
        }
        if (!Files.isRegularFile(file)) {
            return;
        }
        String entryName = modelName + '/' + root.relativize(file);
        tos.putArchiveEntry(new TarArchiveEntry(file.toFile(), entryName));
        Files.copy(file, tos);
        tos.closeArchiveEntry();
    }

    /** Creates the bucket unless it already exists, and waits until it is reachable. */
    private void createBucket() {
        if (doesBucketExist()) {
            logger.info("S3 bucket: {} already exists.", bucketName);
            return;
        }
        logger.info("Creating S3 bucket: {}", bucketName);
        s3.createBucket(CreateBucketRequest.builder().bucket(bucketName).build());
        HeadBucketRequest head = HeadBucketRequest.builder().bucket(bucketName).build();
        s3.waiter()
                .waitUntilBucketExists(head)
                .matched()
                .response()
                .ifPresent(r -> logger.debug("S3 bucket {} is available: {}", bucketName, r));
    }

    /** Returns whether the bucket appears in this account's bucket list; errors count as absent. */
    private boolean doesBucketExist() {
        try {
            return s3.listBuckets(ListBucketsRequest.builder().build()).buckets().stream()
                    .anyMatch(bucket -> bucket.name().equals(bucketName));
        } catch (S3Exception e) {
            logger.warn("Failed to check bucket existence", e);
            return false;
        }
    }

    /**
     * Uploads {@code file} to {@code s3://bucket/key}.
     *
     * @return the {@code s3://} URL of the uploaded object
     */
    private String uploadModel(String bucket, String key, Path file) {
        s3.putObject(PutObjectRequest.builder().bucket(bucket).key(key).build(),
                RequestBody.fromFile(file));
        String url = "s3://" + bucket + '/' + key;
        logger.info("Model uploaded to: {}", url);
        return url;
    }

    /**
     * Creates a timestamped execution role and policy when no execution role was supplied.
     *
     * <p>Both the {@code AmazonSageMakerFullAccess} managed policy and the new policy are attached
     * to the new role.
     */
    private void createRoleIfNeeded() {
        if (executionRole != null) {
            return;
        }
        String suffix = LocalDateTime.now().format(NAME_SUFFIX_FORMAT);
        String roleName = "DJLSageMaker-ExecutionRole-" + suffix;

        CreateRoleRequest roleRequest =
                CreateRoleRequest.builder()
                        .roleName(roleName)
                        .path("/service-role/")
                        .assumeRolePolicyDocument(readPolicyDocument("assume_role_policy.json"))
                        .description("DJL serving execution role for SageMaker.")
                        .build();
        executionRole = iam.createRole(roleRequest).role().arn();

        CreatePolicyRequest policyRequest =
                CreatePolicyRequest.builder()
                        .policyName("DJLSageMaker-ExecutionPolicy-" + suffix)
                        .policyDocument(readPolicyDocument("execution_policy.json"))
                        .build();
        String policyArn = iam.createPolicy(policyRequest).policy().arn();
        iam.waiter().waitUntilPolicyExists(GetPolicyRequest.builder().policyArn(policyArn).build());

        iam.attachRolePolicy(
                AttachRolePolicyRequest.builder()
                        .roleName(roleName)
                        .policyArn("arn:aws:iam::aws:policy/AmazonSageMakerFullAccess")
                        .build());
        iam.attachRolePolicy(
                AttachRolePolicyRequest.builder().roleName(roleName).policyArn(policyArn).build());
    }

    /**
     * When no container image was supplied, builds the ECR image URI from the caller's account, the
     * configured region and the current ECS container's image name.
     */
    private void getContainerImageArn() {
        if (containerImage != null) {
            return;
        }
        String imageName = getContainerImageName();
        String account;
        // Close the STS client once the single call is done; it was previously leaked.
        try (StsClient sts = StsClient.create()) {
            account = sts.getCallerIdentity().account();
        }
        containerImage = account + ".dkr.ecr." + region.id() + ".amazonaws.com/" + imageName;
    }

    /**
     * Reads {@code ImageName} from the JSON file named by {@code ECS_CONTAINER_METADATA_FILE}.
     *
     * @throws AssertionError if the variable is unset, the file cannot be read, or the key is
     *     missing
     */
    private String getContainerImageName() {
        String file = Utils.getenv("ECS_CONTAINER_METADATA_FILE");
        if (file == null) {
            throw new AssertionError("Not in a ECS container.");
        }
        try (BufferedReader reader = Files.newBufferedReader(Paths.get(file))) {
            JsonObject metadata = JsonParser.parseReader(reader).getAsJsonObject();
            if (!metadata.has("ImageName")) {
                throw new AssertionError("ImageName not found in container metadata: " + file);
            }
            return metadata.get("ImageName").getAsString();
        } catch (IOException e) {
            throw new AssertionError("Failed to read container metadata.", e);
        }
    }

    /**
     * Loads a policy document bundled as a class-path resource next to this class.
     *
     * @throws NullPointerException if the resource does not exist
     * @throws AssertionError if the resource cannot be read
     */
    private static String readPolicyDocument(String name) {
        InputStream resource =
                Objects.requireNonNull(
                        SageMaker.class.getResourceAsStream(name),
                        "Policy document not found: " + name);
        try (InputStream in = resource) {
            return Utils.toString(in);
        } catch (IOException e) {
            throw new AssertionError("Failed to read " + name, e);
        }
    }

    /**
     * Creates a new {@link Builder}.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
