package org.springframework.security.oauth2.client.web.reactive.function.client;

import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.time.temporal.TemporalAmount;
import java.util.Base64;
import java.util.Collection;
import java.util.Map;
import java.util.Optional;
import java.util.function.Consumer;

import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;

import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.security.core.Authentication;
import org.springframework.security.core.GrantedAuthority;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.oauth2.client.ClientAuthorizationRequiredException;
import org.springframework.security.oauth2.client.OAuth2AuthorizedClient;
import org.springframework.security.oauth2.client.authentication.OAuth2AuthenticationToken;
import org.springframework.security.oauth2.client.registration.ClientRegistration;
import org.springframework.security.oauth2.client.web.OAuth2AuthorizedClientRepository;
import org.springframework.security.oauth2.core.AuthorizationGrantType;
import org.springframework.security.oauth2.core.web.reactive.function.OAuth2BodyExtractors;
import org.springframework.security.web.http.SecurityHeaders;
import org.springframework.util.Assert;
import org.springframework.web.context.request.RequestContextHolder;
import org.springframework.web.context.request.ServletRequestAttributes;
import org.springframework.web.reactive.function.BodyInserters;
import org.springframework.web.reactive.function.client.ClientRequest;
import org.springframework.web.reactive.function.client.ClientResponse;
import org.springframework.web.reactive.function.client.ExchangeFilterFunction;
import org.springframework.web.reactive.function.client.ExchangeFunction;
import org.springframework.web.reactive.function.client.WebClient;

import reactor.core.publisher.Mono;
import reactor.core.scheduler.Schedulers;

/* loaded from: input_file:org/springframework/security/oauth2/client/web/reactive/function/client/ServletOAuth2AuthorizedClientExchangeFilterFunction.class */
/**
 * An {@link ExchangeFilterFunction} that sends an OAuth 2.0 access token as a bearer token on
 * outgoing {@link WebClient} requests, for use from Servlet-based applications.
 *
 * <p>The {@link OAuth2AuthorizedClient} whose token is used comes from a request attribute set via
 * {@link #oauth2AuthorizedClient(OAuth2AuthorizedClient)}. When an
 * {@link OAuth2AuthorizedClientRepository} is supplied, {@link #defaultRequest()} can resolve that
 * client from a client registration id, which is either set explicitly or taken from an
 * {@link OAuth2AuthenticationToken}. In addition, an access token that is about to expire, and whose
 * client holds a refresh token, is refreshed with the {@code refresh_token} grant before the request
 * is sent. The refreshed client is saved back to the repository.
 */
public final class ServletOAuth2AuthorizedClientExchangeFilterFunction implements ExchangeFilterFunction {
    private static final String OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME = OAuth2AuthorizedClient.class.getName();
    private static final String CLIENT_REGISTRATION_ID_ATTR_NAME = OAuth2AuthorizedClient.class.getName().concat(".CLIENT_REGISTRATION_ID");
    private static final String AUTHENTICATION_ATTR_NAME = Authentication.class.getName();
    private static final String HTTP_SERVLET_REQUEST_ATTR_NAME = HttpServletRequest.class.getName();
    private static final String HTTP_SERVLET_RESPONSE_ATTR_NAME = HttpServletResponse.class.getName();

    private Clock clock = Clock.systemUTC();

    // A token is refreshed once "now" is past (expiresAt - skew), i.e. this long before it really expires.
    private Duration accessTokenExpiresSkew = Duration.ofMinutes(1);

    // Null when built with the no-arg constructor: no default client lookup and no refresh then.
    private OAuth2AuthorizedClientRepository authorizedClientRepository;

    /**
     * An {@link Authentication} that only knows a principal name. It is used when saving a refreshed
     * client for a request that carries no {@link Authentication} attribute. Every method except
     * {@link #getName()} throws {@link UnsupportedOperationException}.
     */
    private static final class PrincipalNameAuthentication implements Authentication {
        private final String username;

        private PrincipalNameAuthentication(String username) {
            this.username = username;
        }

        @Override
        public Collection<? extends GrantedAuthority> getAuthorities() {
            throw unsupported();
        }

        @Override
        public Object getCredentials() {
            throw unsupported();
        }

        @Override
        public Object getDetails() {
            throw unsupported();
        }

        @Override
        public Object getPrincipal() {
            throw unsupported();
        }

        @Override
        public boolean isAuthenticated() {
            throw unsupported();
        }

        @Override
        public void setAuthenticated(boolean isAuthenticated) throws IllegalArgumentException {
            throw unsupported();
        }

        @Override
        public String getName() {
            return this.username;
        }

        private UnsupportedOperationException unsupported() {
            return new UnsupportedOperationException("Not Supported");
        }
    }

    /**
     * Creates a filter without an {@link OAuth2AuthorizedClientRepository}. Only an explicitly
     * supplied {@link OAuth2AuthorizedClient} attribute is used, and tokens are never refreshed.
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction() {
    }

    /**
     * Creates a filter backed by the given repository.
     *
     * @param oAuth2AuthorizedClientRepository the repository used to load authorized clients in
     *        {@link #defaultRequest()} and to save clients after a token refresh
     */
    public ServletOAuth2AuthorizedClientExchangeFilterFunction(OAuth2AuthorizedClientRepository oAuth2AuthorizedClientRepository) {
        this.authorizedClientRepository = oAuth2AuthorizedClientRepository;
    }

    /**
     * Returns a customizer that installs both {@link #defaultRequest()} and this filter on a
     * {@link WebClient.Builder}.
     *
     * @return the builder customizer
     */
    public Consumer<WebClient.Builder> oauth2Configuration() {
        return builder -> builder.defaultRequest(defaultRequest()).filter(this);
    }

    /**
     * Returns a default-request customizer that fills in any missing attributes. It sets the current
     * servlet request and response, the current {@link Authentication}, and, when a repository is
     * configured, the {@link OAuth2AuthorizedClient}.
     *
     * @return the request customizer
     * @throws ClientAuthorizationRequiredException (when the customizer runs) if a client registration
     *         id is resolved but the repository has no authorized client for it
     */
    public Consumer<WebClient.RequestHeadersSpec<?>> defaultRequest() {
        return requestHeadersSpec -> requestHeadersSpec.attributes(attrs -> {
            // Order matters: the authorized-client lookup reads the request and authentication set first.
            populateDefaultRequestResponse(attrs);
            populateDefaultAuthentication(attrs);
            populateDefaultOAuth2AuthorizedClient(attrs);
        });
    }

    /**
     * Sets the authorized client whose access token is sent, or removes the attribute when
     * {@code null}.
     *
     * @param oAuth2AuthorizedClient the client to use, may be {@code null}
     * @return a request-attributes customizer
     */
    public static Consumer<Map<String, Object>> oauth2AuthorizedClient(OAuth2AuthorizedClient oAuth2AuthorizedClient) {
        return attrs -> {
            if (oAuth2AuthorizedClient == null) {
                attrs.remove(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
            } else {
                attrs.put(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME, oAuth2AuthorizedClient);
            }
        };
    }

    /**
     * Sets the client registration id used by {@link #defaultRequest()} to look up the authorized client.
     *
     * @param str the client registration id
     * @return a request-attributes customizer
     */
    public static Consumer<Map<String, Object>> clientRegistrationId(String str) {
        return attrs -> attrs.put(CLIENT_REGISTRATION_ID_ATTR_NAME, str);
    }

    /**
     * Sets the {@link Authentication} used to load and save the authorized client.
     *
     * @param authentication the authentication
     * @return a request-attributes customizer
     */
    public static Consumer<Map<String, Object>> authentication(Authentication authentication) {
        return attrs -> attrs.put(AUTHENTICATION_ATTR_NAME, authentication);
    }

    /**
     * Sets the {@link HttpServletRequest} passed to the repository.
     *
     * @param httpServletRequest the servlet request
     * @return a request-attributes customizer
     */
    public static Consumer<Map<String, Object>> httpServletRequest(HttpServletRequest httpServletRequest) {
        return attrs -> attrs.put(HTTP_SERVLET_REQUEST_ATTR_NAME, httpServletRequest);
    }

    /**
     * Sets the {@link HttpServletResponse} passed to the repository when saving a refreshed client.
     *
     * @param httpServletResponse the servlet response
     * @return a request-attributes customizer
     */
    public static Consumer<Map<String, Object>> httpServletResponse(HttpServletResponse httpServletResponse) {
        return attrs -> attrs.put(HTTP_SERVLET_RESPONSE_ATTR_NAME, httpServletResponse);
    }

    /**
     * Sets how long before its actual expiry an access token is considered expired (default one minute).
     *
     * @param duration the skew, must not be {@code null}
     */
    public void setAccessTokenExpiresSkew(Duration duration) {
        Assert.notNull(duration, "accessTokenExpiresSkew cannot be null");
        this.accessTokenExpiresSkew = duration;
    }

    /**
     * Sends the request with the authorized client's access token as a bearer token. If that token
     * needs refreshing, it is refreshed first. A request without an authorized-client attribute is
     * passed through unchanged.
     */
    @Override
    public Mono<ClientResponse> filter(ClientRequest clientRequest, ExchangeFunction exchangeFunction) {
        Optional<OAuth2AuthorizedClient> attribute = clientRequest.attribute(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)
                .map(OAuth2AuthorizedClient.class::cast);
        return Mono.justOrEmpty(attribute)
                .flatMap(authorizedClient -> authorizedClient(clientRequest, exchangeFunction, authorizedClient))
                .map(authorizedClient -> bearer(clientRequest, authorizedClient))
                .flatMap(exchangeFunction::exchange)
                // Deferred so the fallback exchange is only assembled when no bearer request was sent;
                // an eager call would run downstream filters/exchange setup on every request.
                // NOTE(review): an empty refresh result also ends up here, i.e. the request goes out without a token.
                .switchIfEmpty(Mono.defer(() -> exchangeFunction.exchange(clientRequest)));
    }

    /** Adds the current servlet request/response from {@link RequestContextHolder} unless both are present. */
    private void populateDefaultRequestResponse(Map<String, Object> attrs) {
        if (attrs.containsKey(HTTP_SERVLET_REQUEST_ATTR_NAME) && attrs.containsKey(HTTP_SERVLET_RESPONSE_ATTR_NAME)) {
            return;
        }
        HttpServletRequest request = null;
        HttpServletResponse response = null;
        // Held as Object: only a ServletRequestAttributes context exposes the servlet request/response.
        Object context = RequestContextHolder.getRequestAttributes();
        if (context instanceof ServletRequestAttributes) {
            ServletRequestAttributes servletContext = (ServletRequestAttributes) context;
            request = servletContext.getRequest();
            response = servletContext.getResponse();
        }
        attrs.putIfAbsent(HTTP_SERVLET_REQUEST_ATTR_NAME, request);
        attrs.putIfAbsent(HTTP_SERVLET_RESPONSE_ATTR_NAME, response);
    }

    /** Adds the {@link SecurityContextHolder}'s authentication (possibly null) unless the key is already present. */
    private void populateDefaultAuthentication(Map<String, Object> attrs) {
        if (!attrs.containsKey(AUTHENTICATION_ATTR_NAME)) {
            attrs.put(AUTHENTICATION_ATTR_NAME, SecurityContextHolder.getContext().getAuthentication());
        }
    }

    /**
     * Loads the authorized client from the repository unless there is no repository or the attribute
     * is already present. The registration id comes from the explicit attribute, or failing that from
     * an {@link OAuth2AuthenticationToken}.
     *
     * @throws ClientAuthorizationRequiredException if an id is resolved but no client is stored for it
     */
    private void populateDefaultOAuth2AuthorizedClient(Map<String, Object> attrs) {
        if (this.authorizedClientRepository == null || attrs.containsKey(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME)) {
            return;
        }
        Authentication authentication = getAuthentication(attrs);
        String clientRegistrationId = getClientRegistrationId(attrs);
        if (clientRegistrationId == null && authentication instanceof OAuth2AuthenticationToken) {
            clientRegistrationId = ((OAuth2AuthenticationToken) authentication).getAuthorizedClientRegistrationId();
        }
        if (clientRegistrationId == null) {
            return;
        }
        OAuth2AuthorizedClient authorizedClient = this.authorizedClientRepository.loadAuthorizedClient(
                clientRegistrationId, authentication, getRequest(attrs));
        if (authorizedClient == null) {
            throw new ClientAuthorizationRequiredException(clientRegistrationId);
        }
        oauth2AuthorizedClient(authorizedClient).accept(attrs);
    }

    /** Returns the given client, or a refreshed copy when {@link #shouldRefresh} says so. */
    private Mono<OAuth2AuthorizedClient> authorizedClient(ClientRequest clientRequest, ExchangeFunction exchangeFunction, OAuth2AuthorizedClient oAuth2AuthorizedClient) {
        if (shouldRefresh(oAuth2AuthorizedClient)) {
            return refreshAuthorizedClient(clientRequest, exchangeFunction, oAuth2AuthorizedClient);
        }
        return Mono.just(oAuth2AuthorizedClient);
    }

    /**
     * Performs a {@code refresh_token} grant against the registration's token URI, builds a new
     * authorized client from the response, and saves it to the repository.
     * Only called when {@link #shouldRefresh} returned true, so the repository and refresh token are non-null.
     */
    private Mono<OAuth2AuthorizedClient> refreshAuthorizedClient(ClientRequest clientRequest, ExchangeFunction exchangeFunction, OAuth2AuthorizedClient oAuth2AuthorizedClient) {
        ClientRegistration clientRegistration = oAuth2AuthorizedClient.getClientRegistration();
        URI tokenUri = URI.create(clientRegistration.getProviderDetails().getTokenUri());
        ClientRequest refreshRequest = ClientRequest.create(HttpMethod.POST, tokenUri)
                .header(HttpHeaders.ACCEPT, "application/json")
                .headers(httpBasic(clientRegistration.getClientId(), clientRegistration.getClientSecret()))
                .body(refreshTokenBody(oAuth2AuthorizedClient.getRefreshToken().getTokenValue()))
                .build();
        return exchangeFunction.exchange(refreshRequest)
                .flatMap(response -> response.body(OAuth2BodyExtractors.oauth2AccessTokenResponse()))
                .map(tokenResponse -> new OAuth2AuthorizedClient(
                        clientRegistration,
                        oAuth2AuthorizedClient.getPrincipalName(),
                        tokenResponse.getAccessToken(),
                        // RFC 6749 section 6: the server MAY issue a new refresh token. If it does not,
                        // keep the current one so the client can still be refreshed later.
                        Optional.ofNullable(tokenResponse.getRefreshToken()).orElse(oAuth2AuthorizedClient.getRefreshToken())))
                .doOnNext(refreshed -> {
                    Map<String, Object> attrs = clientRequest.attributes();
                    Authentication principal = (Authentication) clientRequest.attribute(AUTHENTICATION_ATTR_NAME)
                            .orElseGet(() -> new PrincipalNameAuthentication(oAuth2AuthorizedClient.getPrincipalName()));
                    this.authorizedClientRepository.saveAuthorizedClient(refreshed, principal, getRequest(attrs), getResponse(attrs));
                })
                .publishOn(Schedulers.elastic());
    }

    /**
     * Builds a consumer that sets an HTTP Basic {@code Authorization} header, encoding
     * {@code username:password} as ISO-8859-1.
     * NOTE(review): a null password is encoded as the literal "null", and the credentials are not
     * form-url-encoded first (RFC 6749 section 2.3.1); behavior kept as-is.
     */
    private static Consumer<HttpHeaders> httpBasic(String str, String str2) {
        String credentials = str + ":" + str2;
        String encoded = Base64.getEncoder().encodeToString(credentials.getBytes(StandardCharsets.ISO_8859_1));
        return httpHeaders -> httpHeaders.set(HttpHeaders.AUTHORIZATION, "Basic " + encoded);
    }

    /**
     * Returns true when a repository is configured, the client has a refresh token, and the access
     * token is within {@code accessTokenExpiresSkew} of its expiry. A token without an expiry
     * instant is never refreshed, because its expiry cannot be determined.
     */
    private boolean shouldRefresh(OAuth2AuthorizedClient oAuth2AuthorizedClient) {
        if (this.authorizedClientRepository == null || oAuth2AuthorizedClient.getRefreshToken() == null) {
            return false;
        }
        Instant expiresAt = oAuth2AuthorizedClient.getAccessToken().getExpiresAt();
        if (expiresAt == null) {
            return false;
        }
        return this.clock.instant().isAfter(expiresAt.minus(this.accessTokenExpiresSkew));
    }

    /** Copies the request, adding the client's access token as a bearer token header. */
    private ClientRequest bearer(ClientRequest clientRequest, OAuth2AuthorizedClient oAuth2AuthorizedClient) {
        return ClientRequest.from(clientRequest)
                .headers(SecurityHeaders.bearerToken(oAuth2AuthorizedClient.getAccessToken().getTokenValue()))
                .build();
    }

    /** Form body for the {@code refresh_token} grant. */
    private static BodyInserters.FormInserter<String> refreshTokenBody(String str) {
        return BodyInserters.fromFormData("grant_type", AuthorizationGrantType.REFRESH_TOKEN.getValue())
                .with("refresh_token", str);
    }

    /** Reads the authorized-client attribute, or null. */
    static OAuth2AuthorizedClient getOAuth2AuthorizedClient(Map<String, Object> map) {
        return (OAuth2AuthorizedClient) map.get(OAUTH2_AUTHORIZED_CLIENT_ATTR_NAME);
    }

    /** Reads the client-registration-id attribute, or null. */
    static String getClientRegistrationId(Map<String, Object> map) {
        return (String) map.get(CLIENT_REGISTRATION_ID_ATTR_NAME);
    }

    /** Reads the authentication attribute, or null. */
    static Authentication getAuthentication(Map<String, Object> map) {
        return (Authentication) map.get(AUTHENTICATION_ATTR_NAME);
    }

    /** Reads the servlet-request attribute, or null. */
    static HttpServletRequest getRequest(Map<String, Object> map) {
        return (HttpServletRequest) map.get(HTTP_SERVLET_REQUEST_ATTR_NAME);
    }

    /** Reads the servlet-response attribute, or null. */
    static HttpServletResponse getResponse(Map<String, Object> map) {
        return (HttpServletResponse) map.get(HTTP_SERVLET_RESPONSE_ATTR_NAME);
    }
}
