package org.springframework.graphql.web.webflux;

import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;

import graphql.ErrorType;
import graphql.ExecutionResult;
import graphql.GraphqlErrorBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.Disposable;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.core.ParameterizedTypeReference;
import org.springframework.core.ResolvableType;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.core.io.buffer.DataBufferUtils;
import org.springframework.graphql.web.WebGraphQlHandler;
import org.springframework.graphql.web.WebInput;
import org.springframework.graphql.web.WebOutput;
import org.springframework.http.MediaType;
import org.springframework.http.codec.DecoderHttpMessageReader;
import org.springframework.http.codec.EncoderHttpMessageWriter;
import org.springframework.http.codec.ServerCodecConfigurer;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.util.MimeType;
import org.springframework.util.MimeTypeUtils;
import org.springframework.web.reactive.socket.CloseStatus;
import org.springframework.web.reactive.socket.HandshakeInfo;
import org.springframework.web.reactive.socket.WebSocketHandler;
import org.springframework.web.reactive.socket.WebSocketMessage;
import org.springframework.web.reactive.socket.WebSocketSession;

/* loaded from: input_file:org/springframework/graphql/web/webflux/GraphQlWebSocketHandler.class */
public class GraphQlWebSocketHandler implements WebSocketHandler {
    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);
    private static final List<String> SUB_PROTOCOL_LIST = Arrays.asList("graphql-transport-ws", "graphql-ws");
    static final ResolvableType MAP_RESOLVABLE_TYPE = ResolvableType.forType(new ParameterizedTypeReference<Map<String, Object>>() { // from class: org.springframework.graphql.web.webflux.GraphQlWebSocketHandler.1
    });
    private final WebGraphQlHandler graphQlHandler;
    private final Decoder<?> decoder;
    private final Encoder<?> encoder;
    private final Duration initTimeoutDuration;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/web/webflux/GraphQlWebSocketHandler$GraphQlStatus.class */
    public static class GraphQlStatus {
        static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");
        static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");
        static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");
        static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        /* JADX INFO: Access modifiers changed from: package-private */
        public static <V> Flux<V> close(WebSocketSession webSocketSession, CloseStatus closeStatus) {
            return webSocketSession.close(closeStatus).thenMany(Mono.empty());
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/web/webflux/GraphQlWebSocketHandler$MessageType.class */
    public enum MessageType {
        CONNECTION_INIT("connection_init"),
        CONNECTION_ACK("connection_ack"),
        SUBSCRIBE("subscribe"),
        NEXT("next"),
        ERROR("error"),
        COMPLETE("complete");

        private static final Map<String, MessageType> messageTypes = new HashMap(6);
        private final String type;

        MessageType(String str) {
            this.type = str;
        }

        public String getType() {
            return this.type;
        }

        @Nullable
        public static MessageType resolve(@Nullable String str) {
            if (str != null) {
                return messageTypes.get(str);
            }
            return null;
        }

        static {
            for (MessageType messageType : values()) {
                messageTypes.put(messageType.getType(), messageType);
            }
        }
    }

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:org/springframework/graphql/web/webflux/GraphQlWebSocketHandler$SubscriptionExistsException.class */
    public static class SubscriptionExistsException extends RuntimeException {
        private SubscriptionExistsException() {
        }
    }

    public GraphQlWebSocketHandler(WebGraphQlHandler webGraphQlHandler, ServerCodecConfigurer serverCodecConfigurer, Duration duration) {
        Assert.notNull(webGraphQlHandler, "WebGraphQlHandler is required");
        this.graphQlHandler = webGraphQlHandler;
        this.decoder = initDecoder(serverCodecConfigurer);
        this.encoder = initEncoder(serverCodecConfigurer);
        this.initTimeoutDuration = duration;
    }

    private static Decoder<?> initDecoder(ServerCodecConfigurer serverCodecConfigurer) {
        return (Decoder) serverCodecConfigurer.getReaders().stream().filter(httpMessageReader -> {
            return httpMessageReader.canRead(MAP_RESOLVABLE_TYPE, MediaType.APPLICATION_JSON);
        }).map(httpMessageReader2 -> {
            return ((DecoderHttpMessageReader) httpMessageReader2).getDecoder();
        }).findFirst().orElseThrow(() -> {
            return new IllegalArgumentException("No JSON Decoder");
        });
    }

    private static Encoder<?> initEncoder(ServerCodecConfigurer serverCodecConfigurer) {
        return (Encoder) serverCodecConfigurer.getWriters().stream().filter(httpMessageWriter -> {
            return httpMessageWriter.canWrite(MAP_RESOLVABLE_TYPE, MediaType.APPLICATION_JSON);
        }).map(httpMessageWriter2 -> {
            return ((EncoderHttpMessageWriter) httpMessageWriter2).getEncoder();
        }).findFirst().orElseThrow(() -> {
            return new IllegalArgumentException("No JSON Encoder");
        });
    }

    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    public Mono<Void> handle(WebSocketSession webSocketSession) {
        HandshakeInfo handshakeInfo = webSocketSession.getHandshakeInfo();
        if ("graphql-ws".equalsIgnoreCase(handshakeInfo.getSubProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            return webSocketSession.close(GraphQlStatus.INVALID_MESSAGE_STATUS);
        }
        AtomicBoolean atomicBoolean = new AtomicBoolean();
        ConcurrentHashMap concurrentHashMap = new ConcurrentHashMap();
        Mono.delay(this.initTimeoutDuration).then(Mono.defer(() -> {
            return atomicBoolean.compareAndSet(false, true) ? webSocketSession.close(GraphQlStatus.INIT_TIMEOUT_STATUS) : Mono.empty();
        })).subscribe();
        return webSocketSession.send(webSocketSession.receive().flatMap(webSocketMessage -> {
            Subscription subscription;
            Map<String, Object> decode = decode(webSocketMessage);
            String str = (String) decode.get("id");
            MessageType resolve = MessageType.resolve((String) decode.get("type"));
            if (resolve == null) {
                return GraphQlStatus.close(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            }
            switch (resolve) {
                case SUBSCRIBE:
                    if (!atomicBoolean.get()) {
                        return GraphQlStatus.close(webSocketSession, GraphQlStatus.UNAUTHORIZED_STATUS);
                    }
                    if (str == null) {
                        return GraphQlStatus.close(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
                    }
                    WebInput webInput = new WebInput(handshakeInfo.getUri(), handshakeInfo.getHeaders(), getPayload(decode), null, str);
                    if (logger.isDebugEnabled()) {
                        logger.debug("Executing: " + webInput);
                    }
                    return this.graphQlHandler.handleRequest(webInput).flatMapMany(webOutput -> {
                        return handleWebOutput(webSocketSession, str, concurrentHashMap, webOutput);
                    }).doOnTerminate(() -> {
                    });
                case COMPLETE:
                    if (str != null && (subscription = (Subscription) concurrentHashMap.remove(str)) != null) {
                        subscription.cancel();
                    }
                    return this.graphQlHandler.handleWebSocketCompletion().thenMany(Flux.empty());
                case CONNECTION_INIT:
                    return !atomicBoolean.compareAndSet(false, true) ? GraphQlStatus.close(webSocketSession, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS) : this.graphQlHandler.handleWebSocketInitialization(getPayload(decode)).defaultIfEmpty(Collections.emptyMap()).flatMapMany(obj -> {
                        return Flux.just(encode(webSocketSession, null, MessageType.CONNECTION_ACK, obj));
                    }).onErrorResume(th -> {
                        return GraphQlStatus.close(webSocketSession, GraphQlStatus.UNAUTHORIZED_STATUS);
                    });
                default:
                    return GraphQlStatus.close(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            }
        }));
    }

    private Map<String, Object> decode(WebSocketMessage webSocketMessage) {
        return (Map) this.decoder.decode(DataBufferUtils.retain(webSocketMessage.getPayload()), MAP_RESOLVABLE_TYPE, (MimeType) null, (Map) null);
    }

    private static Map<String, Object> getPayload(Map<String, Object> map) {
        Map<String, Object> map2 = (Map) map.get("payload");
        return map2 != null ? map2 : Collections.emptyMap();
    }

    private Flux<WebSocketMessage> handleWebOutput(WebSocketSession webSocketSession, String str, Map<String, Subscription> map, WebOutput webOutput) {
        Flux just;
        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready" + (!CollectionUtils.isEmpty(webOutput.getErrors()) ? " with errors: " + webOutput.getErrors() : "") + ".");
        }
        if (webOutput.getData() instanceof Publisher) {
            just = Flux.from((Publisher) webOutput.getData()).doOnSubscribe(subscription -> {
                if (((Subscription) map.putIfAbsent(str, subscription)) != null) {
                    throw new SubscriptionExistsException();
                }
            });
        } else {
            just = CollectionUtils.isEmpty(webOutput.getErrors()) ? Flux.just(webOutput) : Flux.error(new IllegalStateException("Execution failed: " + webOutput.getErrors()));
        }
        return just.map(executionResult -> {
            return encode(webSocketSession, str, MessageType.NEXT, executionResult.toSpecification());
        }).concatWith(Mono.fromCallable(() -> {
            return encode(webSocketSession, str, MessageType.COMPLETE, null);
        })).onErrorResume(th -> {
            if (th instanceof SubscriptionExistsException) {
                return GraphQlStatus.close(webSocketSession, new CloseStatus(4409, "Subscriber for " + str + " already exists"));
            }
            return Mono.just(encode(webSocketSession, str, MessageType.ERROR, Collections.singletonList(GraphqlErrorBuilder.newError().errorType(ErrorType.DataFetchingException).message(th.getMessage(), new Object[0]).build().toSpecification())));
        });
    }

    private <T> WebSocketMessage encode(WebSocketSession webSocketSession, @Nullable String str, MessageType messageType, @Nullable Object obj) {
        HashMap hashMap = new HashMap(3);
        if (str != null) {
            hashMap.put("id", str);
        }
        hashMap.put("type", messageType.getType());
        if (obj != null) {
            hashMap.put("payload", obj);
        }
        return new WebSocketMessage(WebSocketMessage.Type.TEXT, this.encoder.encodeValue(hashMap, webSocketSession.bufferFactory(), MAP_RESOLVABLE_TYPE, MimeTypeUtils.APPLICATION_JSON, (Map) null));
    }
}
