package net.trajano.ms.vertx.jaxrs;

import com.google.common.cache.Cache;
import com.google.common.cache.CacheBuilder;
import com.nimbusds.jose.JOSEException;
import com.nimbusds.jose.JOSEObject;
import com.nimbusds.jose.JWEObject;
import com.nimbusds.jose.JWSObject;
import com.nimbusds.jose.crypto.RSADecrypter;
import com.nimbusds.jose.crypto.RSASSAVerifier;
import com.nimbusds.jose.jwk.JWKSet;
import com.nimbusds.jose.jwk.RSAKey;
import com.nimbusds.jwt.JWTClaimsSet;
import java.io.IOException;
import java.net.URI;
import java.text.ParseException;
import java.util.concurrent.ExecutionException;
import javax.annotation.PostConstruct;
import javax.annotation.Priority;
import javax.ws.rs.BadRequestException;
import javax.ws.rs.container.ContainerRequestContext;
import javax.ws.rs.container.ContainerRequestFilter;
import javax.ws.rs.container.ResourceInfo;
import javax.ws.rs.core.Context;
import javax.ws.rs.core.Response;
import javax.ws.rs.ext.Provider;
import net.trajano.ms.common.oauth.OAuthTokenResponse;
import net.trajano.ms.vertx.beans.DefaultAssertionRequiredPredicate;
import net.trajano.ms.vertx.beans.JwksProvider;
import net.trajano.ms.vertx.beans.JwtAssertionRequiredPredicate;
import net.trajano.ms.vertx.beans.JwtClaimsProcessor;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.stereotype.Component;

/**
 * JAX-RS request filter that authenticates a request from the JWT carried in
 * the {@code X-JWT-Assertion} header.
 * <p>
 * For every resource for which the configured
 * {@link JwtAssertionRequiredPredicate} returns {@code true}, the filter:
 * <ol>
 * <li>rejects the request with 401 if the header is missing;</li>
 * <li>decrypts the assertion if it is a JWE, using a key obtained from the
 * {@link JwksProvider};</li>
 * <li>if the {@link JwksUriProvider} yields a JWKS URI for the request,
 * requires the (decrypted) assertion to be a JWS signed by an RSA key published
 * at that URI; when no URI is yielded the signature is not checked;</li>
 * <li>checks the audience and issuer (when configured) and the expiration
 * time (when present);</li>
 * <li>installs a {@link JwtSecurityContext} built from the claims; and</li>
 * <li>runs the optional {@link JwtClaimsProcessor}, answering 403 if it does
 * not accept the claims.</li>
 * </ol>
 * The priority value 2000 equals {@code javax.ws.rs.Priorities.AUTHORIZATION}.
 */
@Provider
@Priority(2000)
@Component
public class JwtAssertionInterceptor implements ContainerRequestFilter {

    private static final Logger LOG = LoggerFactory.getLogger(JwtAssertionInterceptor.class);

    /** Maximum number of signing keys retained in {@link #keyCache}. */
    private static final long MAX_NUMBER_OF_KEYS = 20;

    /** Request header that carries the assertion. */
    private static final String ASSERTION_HEADER = "X-JWT-Assertion";

    /**
     * Decides whether a resource needs an assertion. Optional; {@link #init()}
     * installs a {@link DefaultAssertionRequiredPredicate} when none is injected.
     */
    private JwtAssertionRequiredPredicate assertionRequiredPredicate;

    /** Expected audience. When {@code null} any audience is accepted. */
    @Autowired(required = false)
    @Qualifier("authz.audience")
    private URI audience;

    /** Optional additional claims validation. */
    private JwtClaimsProcessor claimsProcessor;

    /** Expected issuer. When {@code null} any issuer is accepted. */
    @Autowired(required = false)
    @Qualifier("authz.issuer")
    private URI issuer;

    /** Source of the private keys used to decrypt JWE assertions. */
    private JwksProvider jwksProvider;

    /**
     * Supplies the JWKS URI holding the signing keys for a request. A
     * {@code null} URI disables signature verification for that request.
     */
    @Autowired
    private JwksUriProvider jwksUriProvider;

    /** Signing keys indexed by key ID, populated lazily from the JWKS URI. */
    private Cache<String, RSAKey> keyCache;

    @Context
    private ResourceInfo resourceInfo;

    /**
     * Validates the assertion of the request and aborts the request when the
     * assertion is missing or invalid.
     *
     * @param containerRequestContext
     *            request being filtered
     * @throws IOException
     *             declared by the filter contract
     * @throws BadRequestException
     *             if the assertion cannot be parsed, decrypted or verified
     *             because of a JOSE, parsing or key loading error
     */
    @Override
    public void filter(ContainerRequestContext containerRequestContext) throws IOException {
        if (!this.assertionRequiredPredicate.test(this.resourceInfo)) {
            return;
        }
        String assertion = containerRequestContext.getHeaderString(ASSERTION_HEADER);
        if (assertion == null) {
            LOG.warn("Missing assertion on request for {}", containerRequestContext.getUriInfo());
            containerRequestContext.abortWith(unauthorizedJwt("missing assertion"));
            return;
        }
        // The raw assertion is deliberately not logged: it is a credential.
        try {
            JOSEObject token = JOSEObject.parse(assertion);
            if (token instanceof JWEObject) {
                JWEObject encrypted = (JWEObject) token;
                encrypted.decrypt(new RSADecrypter(this.jwksProvider.getDecryptionKey(encrypted.getHeader().getKeyID())));
                token = JOSEObject.parse(encrypted.getPayload().toString());
            }

            URI jwksUri = this.jwksUriProvider.getUri(containerRequestContext);
            if (jwksUri != null && !isSignatureValid(token, jwksUri)) {
                LOG.warn("JWT verification failed for {}", containerRequestContext.getUriInfo());
                containerRequestContext.abortWith(unauthorizedJwt("signature vertification failed"));
                return;
            }

            JWTClaimsSet claims = JWTClaimsSet.parse(token.getPayload().toString());
            if (this.audience != null
                && (claims.getAudience() == null || !claims.getAudience().contains(this.audience.toASCIIString()))) {
                LOG.warn("Audience {} did not match {} for {}", claims.getAudience(), this.audience, containerRequestContext.getUriInfo());
                containerRequestContext.abortWith(OAuthTokenResponse.unauthorized("invalid_audience", "Audience validation failed", "Bearer").getResponse());
                return;
            }
            // Compare from the configured side so that a missing "iss" claim is
            // a mismatch rather than a NullPointerException.
            if (this.issuer != null && !this.issuer.toASCIIString().equals(claims.getIssuer())) {
                LOG.warn("Issuer {} did not match {} for {}", claims.getIssuer(), this.issuer, containerRequestContext.getUriInfo());
                containerRequestContext.abortWith(OAuthTokenResponse.unauthorized("invalid_issuer", "Issuer validation failed", "Bearer").getResponse());
                return;
            }
            // Expiry is judged against the server clock. The request's Date
            // header is optional (null when absent) and is controlled by the
            // caller, so it cannot be trusted for this decision.
            if (claims.getExpirationTime() != null && claims.getExpirationTime().getTime() <= System.currentTimeMillis()) {
                LOG.warn("Claims expired for {}", containerRequestContext.getUriInfo());
                containerRequestContext.abortWith(OAuthTokenResponse.unauthorized("invalid_claims", "Claims expired", "Bearer").getResponse());
                return;
            }

            containerRequestContext.setSecurityContext(new JwtSecurityContext(claims, containerRequestContext.getUriInfo()));
            if (this.claimsProcessor == null) {
                return;
            }
            // A null result from the processor is treated as a rejection.
            boolean accepted = Boolean.TRUE.equals(this.claimsProcessor.apply(claims));
            LOG.debug("{}.validateClaims result={}", this.claimsProcessor, accepted);
            if (!accepted) {
                LOG.warn("Validation of claims failed on request for {}", containerRequestContext.getUriInfo());
                containerRequestContext.abortWith(Response.status(Response.Status.FORBIDDEN).entity("claims validation failed").build());
            }
        } catch (JOSEException | ParseException | ExecutionException e) {
            throw new BadRequestException("unable to parse JWT", e);
        }
    }

    /**
     * Builds the 401 response used when the assertion itself is unacceptable.
     *
     * @param reason
     *            text sent as the response entity
     * @return response with status 401 and a {@code WWW-Authenticate: JWT} header
     */
    private static Response unauthorizedJwt(String reason) {
        return Response.status(Response.Status.UNAUTHORIZED).header("WWW-Authenticate", "JWT").entity(reason).build();
    }

    /**
     * Checks that the token is a JWS whose signature verifies against the RSA
     * key with the matching key ID published at the JWKS URI.
     *
     * @param token
     *            parsed (and, if applicable, decrypted) assertion
     * @param jwksUri
     *            location of the signing keys
     * @return {@code true} only if the token is signed, its signing key is
     *         known and the signature verifies. Unsigned tokens are rejected so
     *         that signature verification cannot be bypassed.
     * @throws JOSEException
     *             if the verifier cannot be created or verification fails
     *             abnormally
     * @throws ExecutionException
     *             if the JWKS could not be loaded
     */
    private boolean isSignatureValid(JOSEObject token, URI jwksUri) throws JOSEException, ExecutionException {
        if (!(token instanceof JWSObject)) {
            return false;
        }
        JWSObject signed = (JWSObject) token;
        RSAKey signingKey = getSigningKey(signed.getHeader().getKeyID(), jwksUri);
        return signingKey != null && signed.verify(new RSASSAVerifier(signingKey));
    }

    /**
     * Looks up a signing key by key ID, loading the JWKS into the cache when
     * the key is not cached yet.
     *
     * @param str
     *            key ID from the JWS header, may be {@code null}
     * @param uri
     *            JWKS location
     * @return the RSA key, or {@code null} if the key ID is absent or no RSA
     *         key with that ID is published at the URI
     * @throws ExecutionException
     *             wrapping the I/O or parse error raised while loading the JWKS
     */
    private RSAKey getSigningKey(String str, URI uri) throws ExecutionException {
        if (str == null) {
            LOG.error("JWS header has no kid, unable to select a signing key from {}", uri);
            return null;
        }
        RSAKey cached = this.keyCache.getIfPresent(str);
        if (cached != null) {
            return cached;
        }
        try {
            // Only RSA keys that carry a key ID can be cached; a JWKS may also
            // publish other key types, which are skipped.
            JWKSet.load(uri.toURL()).getKeys().forEach(jwk -> {
                if (jwk instanceof RSAKey && jwk.getKeyID() != null) {
                    this.keyCache.put(jwk.getKeyID(), (RSAKey) jwk);
                }
            });
        } catch (IOException | ParseException e) {
            throw new ExecutionException("unable to load JWKS from " + uri, e);
        }
        RSAKey loaded = this.keyCache.getIfPresent(str);
        if (loaded == null) {
            LOG.error("kid={} was not found in the key cache or {}", str, uri);
        }
        return loaded;
    }

    /**
     * Creates the key cache, reports which optional validations are disabled
     * and installs the default assertion-required predicate when none was
     * injected.
     */
    @PostConstruct
    public void init() {
        this.keyCache = CacheBuilder.newBuilder().maximumSize(MAX_NUMBER_OF_KEYS).build();
        if (this.audience == null) {
            LOG.warn("`authz.audience` was not specified, will accept any audience");
        }
        if (this.issuer == null) {
            LOG.warn("`authz.issuer` was not specified, will accept any issuer");
        }
        if (this.claimsProcessor == null) {
            LOG.warn("JwtClaimsProcessor was not defined, will not peform any claims validation");
        }
        if (this.assertionRequiredPredicate == null) {
            LOG.debug("assertionRequiredPredicate was not defined, default annotation based predicate will be used");
            this.assertionRequiredPredicate = new DefaultAssertionRequiredPredicate();
        }
    }

    /**
     * Sets the predicate that decides whether a resource requires an assertion.
     *
     * @param jwtAssertionRequiredPredicate
     *            predicate, may be {@code null}
     */
    @Autowired(required = false)
    public void setAssertionRequiredFunction(JwtAssertionRequiredPredicate jwtAssertionRequiredPredicate) {
        this.assertionRequiredPredicate = jwtAssertionRequiredPredicate;
    }

    /**
     * Sets the optional processor applied to the claims after the standard
     * checks have passed.
     *
     * @param jwtClaimsProcessor
     *            claims processor, may be {@code null}
     */
    @Autowired(required = false)
    public void setClaimsProcessor(JwtClaimsProcessor jwtClaimsProcessor) {
        this.claimsProcessor = jwtClaimsProcessor;
    }

    /**
     * Sets the provider of the keys used to decrypt JWE assertions.
     *
     * @param jwksProvider
     *            JWKS provider
     */
    @Autowired
    public void setJwksProvider(JwksProvider jwksProvider) {
        this.jwksProvider = jwksProvider;
    }
}
