package com.amazon.corretto.crypto.provider;

import java.security.AlgorithmParameters;
import java.security.InvalidKeyException;
import java.security.Key;
import java.security.NoSuchAlgorithmException;
import java.security.SecureRandom;
import java.security.spec.AlgorithmParameterSpec;
import java.security.spec.InvalidKeySpecException;
import java.util.Arrays;
import javax.crypto.Cipher;
import javax.crypto.CipherSpi;
import javax.crypto.IllegalBlockSizeException;
import javax.crypto.NoSuchPaddingException;
import javax.crypto.SecretKey;
import javax.crypto.ShortBufferException;

/* loaded from: input_file:com/amazon/corretto/crypto/provider/AesKeyWrapPaddingSpi.class */
/**
 * {@link CipherSpi} for AES key wrapping in "KWP" mode with "NoPadding".
 *
 * <p>All input passed through {@code update} is accumulated in an in-memory buffer; the actual
 * transformation is performed in one shot by the native {@link #wrapKey} / {@link #unwrapKey}
 * methods when {@code doFinal}, {@code wrap} or {@code unwrap} is invoked.
 *
 * <p>NOTE(review): the mode name and the 8-byte arithmetic below suggest this implements AES Key
 * Wrap with Padding (RFC 5649 / NIST SP 800-38F KWP) — confirm against the native side.
 *
 * <p>Instances hold mutable state (key bytes, op mode, input buffer) without synchronization.
 */
final class AesKeyWrapPaddingSpi extends CipherSpi {
    /** Value reported by {@link #engineGetBlockSize()}; also the cut-off in {@link #estimateUnwrappedLen}. */
    private static final int BLOCK_SIZE = 16;

    /** Granularity (in bytes) to which wrapped input is rounded up, and size of the added header. */
    private static final int SEMIBLOCK_SIZE = 8;

    static {
        // Presumably triggers loading of the native library backing wrapKey/unwrapKey.
        Loader.load();
    }

    private final AmazonCorrettoCryptoProvider provider;
    /** The key object last accepted by {@link #implInit}; used to skip re-encoding on re-init. */
    private SecretKey jceKey;
    /** Raw encoding of {@link #jceKey}; zeroed before being replaced. */
    private byte[] keyBytes;
    /** One of the {@link Cipher} {@code *_MODE} constants, or -1 before the first init. */
    private int opmode = -1;
    /** Accumulates input until finalization. */
    private final AccessibleByteArrayOutputStream buffer;

    /**
     * Creates an uninitialized cipher bound to {@code provider}.
     *
     * @param provider the owning provider, later handed to {@code Utils} when encoding/decoding
     *     wrapped keys
     */
    AesKeyWrapPaddingSpi(AmazonCorrettoCryptoProvider provider) {
        // Presumably fails fast if the native library could not be loaded — confirm in Loader.
        Loader.checkNativeLibraryAvailability();
        this.provider = provider;
        this.buffer = new AccessibleByteArrayOutputStream();
    }

    /**
     * Native wrap operation.
     *
     * @param key raw AES key bytes
     * @param input buffer holding the data to wrap
     * @param inputLength number of valid bytes in {@code input}
     * @param output destination array
     * @param outputOffset offset into {@code output} at which to start writing
     * @return the number of bytes written (as consumed by the callers in this class)
     */
    private static native int wrapKey(
            byte[] key, byte[] input, int inputLength, byte[] output, int outputOffset);

    /**
     * Native unwrap operation; parameters mirror {@link #wrapKey}.
     *
     * @return the number of bytes written (as consumed by the callers in this class)
     */
    private static native int unwrapKey(
            byte[] key, byte[] input, int inputLength, byte[] output, int outputOffset);

    /**
     * Accepts only {@code null} or "KWP" (case-insensitive, matching how padding names are
     * compared in {@link #engineSetPadding}).
     *
     * @throws NoSuchAlgorithmException for any other mode
     */
    @Override
    protected void engineSetMode(String mode) throws NoSuchAlgorithmException {
        if (mode != null && !"KWP".equalsIgnoreCase(mode)) {
            throw new NoSuchAlgorithmException(mode + " cannot be used");
        }
    }

    /**
     * Accepts only {@code null} or "NoPadding" (case-insensitive).
     *
     * @throws NoSuchPaddingException for any other padding
     */
    @Override
    protected void engineSetPadding(String padding) throws NoSuchPaddingException {
        if (padding != null && !"NoPadding".equalsIgnoreCase(padding)) {
            throw new NoSuchPaddingException("Unsupported padding " + padding);
        }
    }

    @Override
    protected int engineGetBlockSize() {
        return BLOCK_SIZE;
    }

    /**
     * Returns the size of {@code key} in bits, based on the length of its encoding.
     *
     * @throws InvalidKeyException if {@code key} is null or cannot be encoded
     * @throws ArithmeticException if the bit length overflows an {@code int}
     */
    @Override
    protected int engineGetKeySize(Key key) throws InvalidKeyException {
        if (key == null) {
            throw new InvalidKeyException("Null key");
        }
        final byte[] encoded = key.getEncoded();
        if (encoded == null) {
            throw new InvalidKeyException("Can't encode key to obtain length");
        }
        try {
            // Use the copy we already hold: calling getEncoded() again would create a second
            // copy of the key material that is never wiped.
            return Math.multiplyExact(encoded.length, 8);
        } finally {
            Arrays.fill(encoded, (byte) 0);
        }
    }

    /**
     * Returns the output buffer size needed if {@code inputLen} more bytes are supplied on top of
     * what is already buffered.
     *
     * <p>In encrypt/wrap mode this is the exact wrapped length; in every other mode (including
     * "not yet initialized") it is the unwrapped-length estimate.
     */
    @Override
    protected int engineGetOutputSize(int inputLen) {
        final int totalInput = Math.addExact(buffer.size(), inputLen);
        switch (opmode) {
            case Cipher.ENCRYPT_MODE:
            case Cipher.WRAP_MODE:
                return getWrappedLen(totalInput);
            case Cipher.DECRYPT_MODE:
            case Cipher.UNWRAP_MODE:
            default:
                return estimateUnwrappedLen(totalInput);
        }
    }

    /** Rounds {@code inputLen} up to a multiple of 8 and adds 8. */
    private static int getWrappedLen(int inputLen) {
        final int remainder = inputLen % SEMIBLOCK_SIZE;
        final int padding = remainder == 0 ? 0 : SEMIBLOCK_SIZE - remainder;
        return Math.addExact(Math.addExact(inputLen, padding), SEMIBLOCK_SIZE);
    }

    /** Returns 8 for inputs shorter than 16 bytes, otherwise {@code inputLen - 8}. */
    private static int estimateUnwrappedLen(int inputLen) {
        if (inputLen < BLOCK_SIZE) {
            return SEMIBLOCK_SIZE;
        }
        return Math.subtractExact(inputLen, SEMIBLOCK_SIZE);
    }

    /** This cipher exposes no IV; always {@code null}. */
    @Override
    protected byte[] engineGetIV() {
        return null;
    }

    /** This cipher exposes no parameters; always {@code null}. */
    @Override
    protected AlgorithmParameters engineGetParameters() {
        return null;
    }

    /** Delegates to {@link #implInit}; {@code random} is ignored. */
    @Override
    protected void engineInit(int opmode, Key key, SecureRandom random) throws InvalidKeyException {
        implInit(opmode, key);
    }

    /** Delegates to {@link #implInit}; {@code params} and {@code random} are ignored. */
    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameters params, SecureRandom random)
            throws InvalidKeyException {
        implInit(opmode, key);
    }

    /** Delegates to {@link #implInit}; {@code params} and {@code random} are ignored. */
    @Override
    protected void engineInit(int opmode, Key key, AlgorithmParameterSpec params, SecureRandom random)
            throws InvalidKeyException {
        implInit(opmode, key);
    }

    /**
     * Validates the mode and key, caches the key's raw encoding and clears any buffered input.
     *
     * <p>If {@code key} is the very same object as the previously accepted key, the cached
     * encoding is reused. If validation or encoding of a new key fails, the previously cached key
     * state is left untouched.
     *
     * @throws UnsupportedOperationException if {@code opmode} is not one of the four
     *     {@link Cipher} modes
     * @throws InvalidKeyException if {@code key} is null, not a {@link SecretKey}, not RAW
     *     format, not an AES key, or cannot be encoded
     */
    private void implInit(int opmode, Key key) throws InvalidKeyException {
        if (opmode != Cipher.UNWRAP_MODE
                && opmode != Cipher.WRAP_MODE
                && opmode != Cipher.ENCRYPT_MODE
                && opmode != Cipher.DECRYPT_MODE) {
            throw new UnsupportedOperationException("Unsupported mode");
        }
        if (key == null) {
            throw new InvalidKeyException("Null key");
        }
        if (key != jceKey) {
            if (!(key instanceof SecretKey)) {
                throw new InvalidKeyException("Need a SecretKey");
            }
            if (!"RAW".equalsIgnoreCase(key.getFormat())) {
                throw new InvalidKeyException("Need a raw format key");
            }
            if (!"AES".equalsIgnoreCase(key.getAlgorithm())) {
                throw new InvalidKeyException("Expected an AES key");
            }
            // Encode BEFORE touching existing state. Otherwise a key that fails to encode would
            // leave keyBytes == null while jceKey still points at the old key, and a later
            // re-init with that old key would skip re-encoding and run with a null key.
            final byte[] encoded = key.getEncoded();
            if (encoded == null) {
                throw new InvalidKeyException("Key doesn't support encoding");
            }
            if (keyBytes != null) {
                Arrays.fill(keyBytes, (byte) 0);
            }
            keyBytes = encoded;
            jceKey = (SecretKey) key;
        }
        this.opmode = opmode;
        buffer.reset();
    }

    /** True when initialized for ENCRYPT or DECRYPT, the only modes that accept update/doFinal. */
    private boolean isDataMode() {
        return opmode == Cipher.ENCRYPT_MODE || opmode == Cipher.DECRYPT_MODE;
    }

    /**
     * Buffers input; never produces output before finalization.
     *
     * @return always {@code null}
     * @throws IllegalStateException if not initialized for ENCRYPT or DECRYPT
     */
    @Override
    protected byte[] engineUpdate(byte[] input, int inputOffset, int inputLen) {
        if (!isDataMode()) {
            throw new IllegalStateException("Cipher not initialized for update");
        }
        implUpdate(input, inputOffset, inputLen);
        return null;
    }

    /**
     * Buffers input; never produces output before finalization, so {@code output} is untouched.
     *
     * @return always 0
     * @throws IllegalStateException if not initialized for ENCRYPT or DECRYPT
     */
    @Override
    protected int engineUpdate(
            byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset)
            throws ShortBufferException {
        if (!isDataMode()) {
            throw new IllegalStateException("Cipher not initialized for update");
        }
        implUpdate(input, inputOffset, inputLen);
        return 0;
    }

    /**
     * Appends {@code input[inputOffset .. inputOffset + inputLen)} to the internal buffer.
     *
     * <p>A null or empty array, or a range extending past the end of the array, is silently
     * ignored. NOTE(review): the null/empty case looks intentional (finalization without extra
     * input); the out-of-range case is presumably unreachable because {@code javax.crypto.Cipher}
     * validates ranges first — confirm.
     */
    private void implUpdate(byte[] input, int inputOffset, int inputLen) {
        if (input == null
                || input.length <= 0
                || Math.addExact(inputOffset, inputLen) > input.length) {
            return;
        }
        buffer.write(input, inputOffset, inputLen);
    }

    /**
     * Processes all buffered input plus the supplied range and returns the result.
     *
     * @throws IllegalStateException if not initialized for ENCRYPT or DECRYPT
     */
    @Override
    protected byte[] engineDoFinal(byte[] input, int inputOffset, int inputLen) {
        if (!isDataMode()) {
            throw new IllegalStateException("Cipher not initialized for finalization");
        }
        return implDoFinal(input, inputOffset, inputLen);
    }

    /**
     * Processes all buffered input plus the supplied range, writing the result into
     * {@code output} starting at {@code outputOffset}.
     *
     * @return the number of bytes reported written by the native call
     * @throws IllegalStateException if not initialized for ENCRYPT or DECRYPT
     * @throws ShortBufferException if {@code output} has fewer than
     *     {@link #engineGetOutputSize(int)} bytes available after {@code outputOffset}
     */
    @Override
    protected int engineDoFinal(
            byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset)
            throws ShortBufferException {
        if (!isDataMode()) {
            throw new IllegalStateException("Cipher not initialized for finalization");
        }
        final int required = engineGetOutputSize(inputLen);
        if (output.length - outputOffset < required) {
            throw new ShortBufferException("Output buffer needs size of at least " + required);
        }
        return implDoFinal(input, inputOffset, inputLen, output, outputOffset);
    }

    /**
     * Finalizes into a freshly allocated array sized by {@link #engineGetOutputSize(int)}, then
     * trims it to the actual length. The oversized scratch array is zeroed when trimming occurs.
     */
    private byte[] implDoFinal(byte[] input, int inputOffset, int inputLen) {
        final int maxLen = engineGetOutputSize(inputLen);
        final byte[] scratch = new byte[maxLen];
        final int actualLen = implDoFinal(input, inputOffset, inputLen, scratch, 0);
        if (actualLen < maxLen) {
            final byte[] trimmed = Arrays.copyOf(scratch, actualLen);
            Arrays.fill(scratch, (byte) 0);
            return trimmed;
        }
        return scratch;
    }

    /**
     * Buffers the final input range, runs the native wrap or unwrap according to the current
     * mode, and always resets the input buffer afterwards (even on failure).
     *
     * @return the value returned by the native call
     * @throws IllegalStateException if the cipher has not been initialized
     */
    private int implDoFinal(
            byte[] input, int inputOffset, int inputLen, byte[] output, int outputOffset) {
        implUpdate(input, inputOffset, inputLen);
        try {
            switch (opmode) {
                case Cipher.ENCRYPT_MODE:
                case Cipher.WRAP_MODE:
                    return wrapKey(
                            keyBytes, buffer.getDataBuffer(), buffer.size(), output, outputOffset);
                case Cipher.DECRYPT_MODE:
                case Cipher.UNWRAP_MODE:
                    return unwrapKey(
                            keyBytes, buffer.getDataBuffer(), buffer.size(), output, outputOffset);
                default:
                    throw new IllegalStateException("Cipher not initialized for finalization");
            }
        } finally {
            buffer.reset();
        }
    }

    /**
     * Wraps {@code key}. The intermediate encoding of the key is zeroed before returning,
     * whether or not wrapping succeeds.
     *
     * @throws IllegalStateException if not initialized in {@link Cipher#WRAP_MODE}
     * @throws InvalidKeyException if the native layer reports a {@code RuntimeCryptoException}
     */
    @Override
    protected byte[] engineWrap(Key key) throws IllegalBlockSizeException, InvalidKeyException {
        if (opmode != Cipher.WRAP_MODE) {
            throw new IllegalStateException("Cipher must be init'd in WRAP_MODE");
        }
        byte[] encoded = null;
        try {
            encoded = Utils.encodeForWrapping(provider, key);
            return implDoFinal(encoded, 0, encoded.length);
        } catch (final RuntimeCryptoException e) {
            throw new InvalidKeyException("Wrapping failed", e);
        } finally {
            if (encoded != null) {
                Arrays.fill(encoded, (byte) 0);
            }
        }
    }

    /**
     * Unwraps {@code wrappedKey} and rebuilds a {@link Key} of the requested algorithm and type.
     * The intermediate unwrapped bytes are zeroed before returning, whether or not key
     * construction succeeds.
     *
     * @throws IllegalStateException if not initialized in {@link Cipher#UNWRAP_MODE}
     * @throws InvalidKeyException if unwrapping or key construction fails
     */
    @Override
    protected Key engineUnwrap(byte[] wrappedKey, String wrappedKeyAlgorithm, int wrappedKeyType)
            throws InvalidKeyException, NoSuchAlgorithmException {
        if (opmode != Cipher.UNWRAP_MODE) {
            throw new IllegalStateException("Cipher must be init'd in UNWRAP_MODE");
        }
        byte[] unwrapped = null;
        try {
            unwrapped = implDoFinal(wrappedKey, 0, wrappedKey.length);
            return Utils.buildUnwrappedKey(provider, unwrapped, wrappedKeyAlgorithm, wrappedKeyType);
        } catch (final RuntimeCryptoException | InvalidKeySpecException e) {
            throw new InvalidKeyException("Unwrapping failed", e);
        } finally {
            if (unwrapped != null) {
                Arrays.fill(unwrapped, (byte) 0);
            }
        }
    }
}
