package se.litsec.swedisheid.opensaml.saml2.signservice;

import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory;
import com.nimbusds.jose.proc.JWSVerifierFactory;
import com.nimbusds.jwt.SignedJWT;
import java.io.IOException;
import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.security.cert.X509Certificate;
import java.text.ParseException;
import java.time.Instant;
import java.util.Arrays;
import java.util.Base64;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import net.shibboleth.utilities.java.support.resolver.ResolverException;
import org.opensaml.core.xml.io.MarshallingException;
import org.opensaml.saml.saml2.core.Assertion;
import org.opensaml.saml.saml2.core.Attribute;
import org.opensaml.saml.saml2.core.AttributeStatement;
import org.opensaml.saml.saml2.core.AuthnRequest;
import org.opensaml.saml.saml2.core.AuthnStatement;
import org.opensaml.saml.saml2.metadata.EntityDescriptor;
import org.opensaml.security.credential.UsageType;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import se.litsec.opensaml.saml2.attribute.AttributeUtils;
import se.litsec.opensaml.saml2.metadata.MetadataUtils;
import se.litsec.opensaml.saml2.metadata.provider.MetadataProvider;
import se.litsec.opensaml.saml2.metadata.provider.StaticMetadataProvider;
import se.litsec.swedisheid.opensaml.saml2.attribute.AttributeConstants;
import se.litsec.swedisheid.opensaml.saml2.signservice.SADValidationException;
import se.litsec.swedisheid.opensaml.saml2.signservice.sap.SAD;
import se.litsec.swedisheid.opensaml.saml2.signservice.sap.SADRequest;

/* loaded from: input_file:se/litsec/swedisheid/opensaml/saml2/signservice/SADParser.class */
public class SADParser {

    /* loaded from: input_file:se/litsec/swedisheid/opensaml/saml2/signservice/SADParser$SADValidator.class */
    public static class SADValidator {
        private Logger logger = LoggerFactory.getLogger(SADValidator.class);
        private List<X509Certificate> validationCertificates;
        private MetadataProvider metadataProvider;
        private static final JWSVerifierFactory verifierFactory = new DefaultJWSVerifierFactory();

        public SADValidator(X509Certificate... x509CertificateArr) {
            this.validationCertificates = Arrays.asList(x509CertificateArr);
        }

        public SADValidator(MetadataProvider metadataProvider) {
            this.metadataProvider = metadataProvider;
        }

        public SADValidator(EntityDescriptor entityDescriptor) {
            try {
                this.metadataProvider = new StaticMetadataProvider(entityDescriptor);
            } catch (MarshallingException e) {
                throw new SecurityException("Invalid IdP metadata", e);
            }
        }

        public SAD validate(AuthnRequest authnRequest, Assertion assertion) throws SADValidationException, IllegalArgumentException {
            long currentTimeMillis = System.currentTimeMillis() / 1000;
            SADRequest sADRequest = null;
            if (authnRequest.getExtensions() != null) {
                Stream stream = authnRequest.getExtensions().getUnknownXMLObjects().stream();
                Class<SADRequest> cls = SADRequest.class;
                Objects.requireNonNull(SADRequest.class);
                Stream filter = stream.filter((v1) -> {
                    return r1.isInstance(v1);
                });
                Class<SADRequest> cls2 = SADRequest.class;
                Objects.requireNonNull(SADRequest.class);
                sADRequest = (SADRequest) filter.map((v1) -> {
                    return r1.cast(v1);
                }).findFirst().orElse(null);
            }
            if (sADRequest == null) {
                String format = String.format("AuthnRequest '%s' does not contain a SADRequest", authnRequest.getID());
                this.logger.info(format);
                throw new IllegalArgumentException(format);
            }
            if (assertion.getAttributeStatements().isEmpty()) {
                String format2 = String.format("Assertion '%s' does not contain any attributes (and thus no SAD)", assertion.getID());
                this.logger.info(format2);
                throw new SADValidationException(SADValidationException.ErrorCode.NO_SAD_ATTRIBUTE, format2);
            }
            List attributes = ((AttributeStatement) assertion.getAttributeStatements().get(0)).getAttributes();
            Attribute attribute = (Attribute) AttributeUtils.getAttribute(AttributeConstants.ATTRIBUTE_NAME_SAD, attributes).orElse(null);
            if (attribute == null) {
                String format3 = String.format("Assertion '%s' does not contain a SAD attribute", assertion.getID());
                this.logger.info(format3);
                throw new SADValidationException(SADValidationException.ErrorCode.NO_SAD_ATTRIBUTE, format3);
            }
            try {
                SignedJWT parse = SignedJWT.parse(AttributeUtils.getAttributeStringValue(attribute));
                SAD fromJson = SAD.fromJson(new String(Base64.getUrlDecoder().decode(parse.getPayload().toBase64URL().toString()), Charset.forName("UTF-8")));
                if (fromJson.getSeElnSadext() == null) {
                    this.logger.info("seElnSadext extension claims are missing from SAD");
                    throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, "seElnSadext extension claims are missing from SAD");
                }
                if (fromJson.getSeElnSadext().getAttributeName() == null) {
                    this.logger.info("SAD does not contain the attribute name (attr) for the subject");
                    throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, "SAD does not contain the attribute name (attr) for the subject");
                }
                Attribute attribute2 = (Attribute) AttributeUtils.getAttribute(fromJson.getSeElnSadext().getAttributeName(), attributes).orElse(null);
                if (attribute2 == null) {
                    String format4 = String.format("Assertion '%s' does not contain a '%s' attribute - this is listed as the subject attribute in the SAD", assertion.getID(), fromJson.getSeElnSadext().getAttributeName());
                    this.logger.info(format4);
                    throw new SADValidationException(SADValidationException.ErrorCode.MISSING_SUBJECT_ATTRIBUTE, format4);
                }
                String loa = getLoa(assertion);
                if (loa == null) {
                    String format5 = String.format("Assertion '%s' does not contain a LoA URI", assertion.getID());
                    this.logger.error(format5);
                    throw new IllegalArgumentException(format5);
                }
                if (sADRequest.getDocCount() == null) {
                    throw new IllegalArgumentException("Bad SADRequest - missing DocCount");
                }
                return validate(parse, fromJson, currentTimeMillis, assertion.getIssuer().getValue(), sADRequest.getRequesterID(), AttributeUtils.getAttributeStringValue(attribute2), loa, sADRequest.getID(), sADRequest.getDocCount().intValue(), sADRequest.getSignRequestID());
            } catch (IOException | ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
        }

        public SAD validate(String str, String str2, String str3, String str4, String str5, String str6, int i, String str7) throws SADValidationException {
            long currentTimeMillis = System.currentTimeMillis() / 1000;
            try {
                SignedJWT parse = SignedJWT.parse(str);
                return validate(parse, SAD.fromJson(new String(Base64.getUrlDecoder().decode(parse.getPayload().toBase64URL().toString()), Charset.forName("UTF-8"))), currentTimeMillis, str2, str3, str4, str5, str6, i, str7);
            } catch (IOException | ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
        }

        private SAD validate(SignedJWT signedJWT, SAD sad, long j, String str, String str2, String str3, String str4, String str5, int i, String str6) throws SADValidationException {
            verifyJwtSignature(signedJWT, str);
            if (sad.getJwtId() == null || sad.getJwtId().isEmpty()) {
                this.logger.info("Invalid SAD JWT - jti is missing");
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, "Invalid SAD JWT - jti is missing");
            }
            if (!Objects.equals(str, sad.getIssuer())) {
                String format = String.format("SAD contains issuer '%s' - expected '%s'", sad.getIssuer(), str);
                this.logger.info(format);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_ISSUER, format);
            }
            if (!Objects.equals(str2, sad.getAudience())) {
                String format2 = String.format("SAD contains audience '%s' - expected '%s'", sad.getAudience(), str2);
                this.logger.info(format2);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_AUDIENCE, format2);
            }
            if (sad.getExpiry() == null || sad.getIssuedAt() == null) {
                this.logger.info("SAD is missing 'exp' and/or 'iat' - Invalid SAD");
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, "SAD is missing 'exp' and/or 'iat' - Invalid SAD");
            }
            if (sad.getExpiry().intValue() < j) {
                String format3 = String.format("SAD has expired - expiration: '%s', current time: '%s'", sad.getExpiryDateTime(), Instant.ofEpochSecond(j));
                this.logger.info(format3);
                throw new SADValidationException(SADValidationException.ErrorCode.SAD_EXPIRED, format3);
            }
            if (sad.getIssuedAt().intValue() > j) {
                String format4 = String.format("SAD is not yet valid - issue-time: '%s', current time: '%s'", sad.getIssuedAtDateTime(), Instant.ofEpochSecond(j));
                this.logger.info(format4);
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, format4);
            }
            if (!Objects.equals(str3, sad.getSubject())) {
                String format5 = String.format("SAD contains subject '%s' - expected '%s'", sad.getSubject(), str3);
                this.logger.info(format5);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_SUBJECT, format5);
            }
            if (sad.getSeElnSadext() == null) {
                this.logger.info("seElnSadext extension claims are missing from SAD");
                throw new SADValidationException(SADValidationException.ErrorCode.BAD_SAD_FORMAT, "seElnSadext extension claims are missing from SAD");
            }
            if (!Objects.equals(str5, sad.getSeElnSadext().getInResponseTo())) {
                String format6 = String.format("SAD contains in-response-to (irt) '%s' - expected SAD to belong to SADRequest with ID '%s'", sad.getSeElnSadext().getInResponseTo(), str5);
                this.logger.info(format6);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_IRT, format6);
            }
            if (!Objects.equals(str4, sad.getSeElnSadext().getLoa())) {
                String format7 = String.format("SAD contains LoA '%s' - expected '%s'", sad.getSeElnSadext().getLoa(), str4);
                this.logger.info(format7);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_LOA, format7);
            }
            if (!Objects.equals(Integer.valueOf(i), sad.getSeElnSadext().getNumberOfDocuments())) {
                String format8 = String.format("SAD indicated '%s' number of documents - expected '%d'", sad.getSeElnSadext().getNumberOfDocuments(), Integer.valueOf(i));
                this.logger.info(format8);
                throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_DOCS, format8);
            }
            if (Objects.equals(str6, sad.getSeElnSadext().getRequestID())) {
                this.logger.debug("SAD with ID '{}' was successfully validated", sad.getJwtId());
                return sad;
            }
            String format9 = String.format("SAD contains SignRequest ID (reqid) '%s' - expected '%s'", sad.getSeElnSadext().getRequestID(), str6);
            this.logger.info(format9);
            throw new SADValidationException(SADValidationException.ErrorCode.VALIDATION_BAD_SIGNREQUESTID, format9);
        }

        public void verifyJwtSignature(String str, String str2) throws SADValidationException {
            try {
                verifyJwtSignature(SignedJWT.parse(str), str2);
            } catch (ParseException e) {
                throw new SADValidationException(SADValidationException.ErrorCode.JWT_PARSE_ERROR, "Failed to parse SAD JWT", e);
            }
        }

        private void verifyJwtSignature(SignedJWT signedJWT, String str) throws SADValidationException {
            try {
                List<X509Certificate> validationCertificates = getValidationCertificates(str);
                if (validationCertificates.isEmpty()) {
                    throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "No suitable IdP signature certificate was found - can not verify SAD JWT signature");
                }
                this.logger.debug("Verifying SAD JWT signature. Will try {} IdP key(s) ...", Integer.valueOf(validationCertificates.size()));
                boolean z = false;
                Iterator<X509Certificate> it = validationCertificates.iterator();
                while (it.hasNext()) {
                    try {
                    } catch (JOSEException e) {
                        this.logger.debug("Failed to perform signature validation of SAD JWT - {}", e.getMessage());
                        this.logger.trace("", e);
                    }
                    if (verifierFactory.createJWSVerifier(signedJWT.getHeader(), it.next().getPublicKey()).verify(signedJWT.getHeader(), signedJWT.getSigningInput(), signedJWT.getSignature())) {
                        this.logger.debug("SAD JWT signature successfully verified");
                        z = true;
                        break;
                    }
                }
                if (!z) {
                    throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "Signature on SAD JWT could not be validated using any of the IdP certificates found");
                }
            } catch (ResolverException e2) {
                throw new SADValidationException(SADValidationException.ErrorCode.SIGNATURE_VALIDATION_ERROR, "Failed to find validation certificate", e2);
            }
        }

        private List<X509Certificate> getValidationCertificates(String str) throws ResolverException {
            if (this.validationCertificates != null && !this.validationCertificates.isEmpty()) {
                return this.validationCertificates;
            }
            if (this.metadataProvider == null) {
                return Collections.emptyList();
            }
            EntityDescriptor entityDescriptor = this.metadataProvider.getEntityDescriptor(str);
            if (entityDescriptor != null) {
                return (List) MetadataUtils.getMetadataCertificates(entityDescriptor, UsageType.SIGNING).stream().map((v0) -> {
                    return v0.getEntityCertificate();
                }).collect(Collectors.toList());
            }
            this.logger.warn("No metadata found for IdP '{}' - cannot find key to use when verifying SAD JWT signature", str);
            return Collections.emptyList();
        }

        private static String getLoa(Assertion assertion) {
            try {
                return ((AuthnStatement) assertion.getAuthnStatements().get(0)).getAuthnContext().getAuthnContextClassRef().getURI();
            } catch (Exception e) {
                return null;
            }
        }
    }

    private SADParser() {
    }

    public static SAD parse(String str) throws IOException {
        try {
            return SAD.fromJson(new String(Base64.getUrlDecoder().decode(SignedJWT.parse(str).getPayload().toBase64URL().toString()), Charset.forName("UTF-8")));
        } catch (ParseException e) {
            throw new IOException(e);
        }
    }

    public static SADValidator getValidator(X509Certificate... x509CertificateArr) {
        return new SADValidator(x509CertificateArr);
    }

    public static SADValidator getValidator(MetadataProvider metadataProvider) {
        return new SADValidator(metadataProvider);
    }

    public static SADValidator getValidator(EntityDescriptor entityDescriptor) {
        return new SADValidator(entityDescriptor);
    }
}
