package org.springframework.graphql.web.webmvc;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URI;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import graphql.ErrorType;
import graphql.ExecutionResult;
import graphql.GraphqlErrorBuilder;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
import reactor.core.publisher.BaseSubscriber;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.scheduler.Scheduler;
import reactor.core.scheduler.Schedulers;

import org.springframework.graphql.web.WebGraphQlHandler;
import org.springframework.graphql.web.WebInput;
import org.springframework.graphql.web.WebOutput;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.SubProtocolCapable;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.handler.ExceptionWebSocketHandlerDecorator;
import org.springframework.web.socket.handler.TextWebSocketHandler;

/* loaded from: input_file:org/springframework/graphql/web/webmvc/GraphQlWebSocketHandler.class */
/**
 * WebSocket handler for GraphQL requests over the {@code graphql-transport-ws} sub-protocol
 * (see https://github.com/enisdenjo/graphql-ws).
 *
 * <p>Each incoming text frame is decoded with the configured {@link HttpMessageConverter} into a
 * map with {@code "type"}, {@code "id"} and {@code "payload"} entries. Outgoing frames are encoded
 * the same way. Sessions that negotiate the legacy {@code subscriptions-transport-ws} sub-protocol
 * are closed with status 4400 right after the connection is established.
 */
public class GraphQlWebSocketHandler extends TextWebSocketHandler implements SubProtocolCapable {

    private static final Log logger = LogFactory.getLog(GraphQlWebSocketHandler.class);

    // Unmodifiable so callers of getSubProtocols() cannot alter the advertised protocols.
    private static final List<String> SUB_PROTOCOL_LIST = Collections.unmodifiableList(
            Arrays.asList("graphql-transport-ws", "subscriptions-transport-ws"));

    private final WebGraphQlHandler graphQlHandler;

    private final Duration initTimeoutDuration;

    private final HttpMessageConverter<?> converter;

    /** Per-session state, keyed by {@link WebSocketSession#getId()}. */
    private final Map<String, SessionState> sessionInfoMap = new ConcurrentHashMap<>();

    /**
     * Create a handler.
     * @param webGraphQlHandler the handler that executes GraphQL requests; required
     * @param httpMessageConverter the converter used to read and write JSON messages; required
     * @param duration how long to wait after the connection is established for a
     * {@code connection_init} message before closing the session with status 4408
     */
    public GraphQlWebSocketHandler(WebGraphQlHandler webGraphQlHandler,
            HttpMessageConverter<?> httpMessageConverter, Duration duration) {

        Assert.notNull(webGraphQlHandler, "WebGraphQlHandler is required");
        Assert.notNull(httpMessageConverter, "HttpMessageConverter for JSON is required");
        this.graphQlHandler = webGraphQlHandler;
        this.initTimeoutDuration = duration;
        this.converter = httpMessageConverter;
    }

    @Override
    public List<String> getSubProtocols() {
        return SUB_PROTOCOL_LIST;
    }

    /**
     * Reject the legacy sub-protocol. Otherwise, register per-session state and schedule the
     * {@code connection_init} timeout check.
     */
    @Override
    public void afterConnectionEstablished(WebSocketSession webSocketSession) {
        if ("subscriptions-transport-ws".equalsIgnoreCase(webSocketSession.getAcceptedProtocol())) {
            if (logger.isDebugEnabled()) {
                logger.debug("apollographql/subscriptions-transport-ws is not supported, nor maintained. "
                        + "Please, use https://github.com/enisdenjo/graphql-ws.");
            }
            GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }

        SessionState sessionState = new SessionState(webSocketSession.getId());
        this.sessionInfoMap.put(webSocketSession.getId(), sessionState);

        // Close with 4408 if the client has not sent connection_init within the timeout.
        Mono.delay(this.initTimeoutDuration)
                .then(Mono.fromRunnable(() -> {
                    if (sessionState.isConnectionInitNotProcessed()) {
                        GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INIT_TIMEOUT_STATUS);
                    }
                }))
                .subscribe();
    }

    /**
     * Decode a client message and dispatch it by its {@code "type"}.
     * <p>The session is closed with 4400 when {@code "type"} is missing, not a string or unknown,
     * when {@code "id"} is present but not a string, or when the type is a server-to-client type.
     */
    @Override
    @SuppressWarnings("unchecked")
    protected void handleTextMessage(WebSocketSession webSocketSession, TextMessage textMessage) throws Exception {
        Map<String, Object> message = decode(textMessage, Map.class);

        Object idValue = message.get("id");
        Object typeValue = message.get("type");
        // A non-string "id" or "type" would otherwise fail with ClassCastException (server error)
        // instead of the protocol's "Invalid message" close.
        if ((idValue != null && !(idValue instanceof String)) || !(typeValue instanceof String)) {
            GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }

        String id = (String) idValue;
        MessageType messageType = MessageType.resolve((String) typeValue);
        if (messageType == null) {
            GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }

        SessionState sessionState = getSessionInfo(webSocketSession);
        switch (messageType) {
            case SUBSCRIBE:
                handleSubscribe(webSocketSession, sessionState, id, message);
                return;
            case COMPLETE:
                handleComplete(sessionState, id);
                return;
            case CONNECTION_INIT:
                handleConnectionInit(webSocketSession, sessionState, message);
                return;
            default:
                // connection_ack, next and error are only ever sent by the server.
                GraphQlStatus.closeSession(webSocketSession, GraphQlStatus.INVALID_MESSAGE_STATUS);
        }
    }

    /** Execute a {@code subscribe} request and stream its encoded results back to the client. */
    private void handleSubscribe(WebSocketSession session, SessionState sessionState,
            @Nullable String id, Map<String, Object> message) {

        if (sessionState.isConnectionInitNotProcessed()) {
            GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
            return;
        }
        if (id == null) {
            GraphQlStatus.closeSession(session, GraphQlStatus.INVALID_MESSAGE_STATUS);
            return;
        }
        URI uri = session.getUri();
        Assert.notNull(uri, "Expected handshake url");
        WebInput webInput = new WebInput(uri, session.getHandshakeHeaders(), getPayload(message), null, id);
        if (logger.isDebugEnabled()) {
            logger.debug("Executing: " + webInput);
        }
        this.graphQlHandler.handleRequest(webInput)
                .flatMapMany(webOutput -> handleWebOutput(session, webInput.getId(), webOutput))
                .publishOn(sessionState.getScheduler())
                .subscribe(new SendMessageSubscriber(id, session, sessionState));
    }

    /**
     * Handle a client {@code complete} message. Cancels the registered subscription for
     * {@code id}, if any, then waits up to 10 seconds for the completion callback of the
     * {@link WebGraphQlHandler}.
     */
    private void handleComplete(SessionState sessionState, @Nullable String id) {
        if (id != null) {
            Subscription subscription = sessionState.getSubscriptions().remove(id);
            if (subscription != null) {
                subscription.cancel();
            }
        }
        this.graphQlHandler.handleWebSocketCompletion().block(Duration.ofSeconds(10L));
    }

    /**
     * Handle {@code connection_init}. A repeated init closes the session with 4429. Otherwise the
     * payload goes to the {@link WebGraphQlHandler}, and the result is sent as
     * {@code connection_ack}. Any failure closes the session with 4401. Blocks for up to 10 seconds.
     */
    private void handleConnectionInit(WebSocketSession session, SessionState sessionState,
            Map<String, Object> message) {

        if (sessionState.setConnectionInitProcessed()) {
            GraphQlStatus.closeSession(session, GraphQlStatus.TOO_MANY_INIT_REQUESTS_STATUS);
            return;
        }
        this.graphQlHandler.handleWebSocketInitialization(getPayload(message))
                .defaultIfEmpty(Collections.emptyMap())
                .publishOn(sessionState.getScheduler())
                .doOnNext(ackPayload -> {
                    try {
                        session.sendMessage(encode(null, MessageType.CONNECTION_ACK, ackPayload));
                    }
                    catch (IOException ex) {
                        throw new IllegalStateException(ex);
                    }
                })
                .onErrorResume(ex -> {
                    GraphQlStatus.closeSession(session, GraphQlStatus.UNAUTHORIZED_STATUS);
                    return Mono.empty();
                })
                .block(Duration.ofSeconds(10L));
    }

    /** Read a text frame into an instance of {@code targetClass} using the configured converter. */
    @SuppressWarnings("unchecked")
    private <T> T decode(TextMessage textMessage, Class<T> targetClass) throws IOException {
        return ((HttpMessageConverter<T>) this.converter).read(targetClass, new HttpInputMessageAdapter(textMessage));
    }

    /** Return the message's {@code "payload"} map, or an empty map if it is absent. */
    @SuppressWarnings("unchecked")
    private static Map<String, Object> getPayload(Map<String, Object> message) {
        Map<String, Object> payload = (Map<String, Object>) message.get("payload");
        return (payload != null ? payload : Collections.<String, Object>emptyMap());
    }

    /** Look up the state registered for the session; fails if none was registered. */
    private SessionState getSessionInfo(WebSocketSession webSocketSession) {
        SessionState sessionState = this.sessionInfoMap.get(webSocketSession.getId());
        Assert.notNull(sessionState, "No SessionInfo for " + webSocketSession);
        return sessionState;
    }

    /**
     * Turn an execution result into the frames to send for operation {@code id}.
     * <p>A subscription produces one {@code next} per item followed by {@code complete}. A query or
     * mutation produces a single {@code next} followed by {@code complete}. A failure produces a single
     * {@code error} frame whose payload is a one-element error array, as graphql-transport-ws
     * requires. A duplicate subscription id is propagated as {@link SubscriptionExistsException}
     * so that {@link SendMessageSubscriber} can close the session with 4409.
     */
    @SuppressWarnings("unchecked")
    private Flux<TextMessage> handleWebOutput(WebSocketSession webSocketSession, String id, WebOutput webOutput) {
        if (logger.isDebugEnabled()) {
            logger.debug("Execution result ready"
                    + (!CollectionUtils.isEmpty(webOutput.getErrors()) ? " with errors: " + webOutput.getErrors() : "")
                    + ".");
        }

        Flux<ExecutionResult> resultFlux;
        Object data = webOutput.getData();
        if (data instanceof Publisher) {
            // Subscription: register the upstream Subscription so a client "complete" can cancel it.
            resultFlux = Flux.from((Publisher<ExecutionResult>) data)
                    .doOnSubscribe(subscription -> {
                        Subscription existing =
                                getSessionInfo(webSocketSession).getSubscriptions().putIfAbsent(id, subscription);
                        if (existing != null) {
                            throw new SubscriptionExistsException();
                        }
                    });
        }
        else if (CollectionUtils.isEmpty(webOutput.getErrors())) {
            resultFlux = Flux.just(webOutput);
        }
        else {
            resultFlux = Flux.error(new IllegalStateException("Execution failed: " + webOutput.getErrors()));
        }

        return resultFlux
                .map(result -> encode(id, MessageType.NEXT, result.toSpecification()))
                .concatWith(Mono.fromCallable(() -> encode(id, MessageType.COMPLETE, null)))
                .onErrorResume(ex -> {
                    if (ex instanceof SubscriptionExistsException) {
                        // Must not complete normally: SendMessageSubscriber#hookOnComplete would remove
                        // the map entry owned by the already-running subscription with this id.
                        return Flux.error(ex);
                    }
                    return Mono.just(encode(id, MessageType.ERROR,
                            Collections.singletonList(toErrorSpecification(ex))));
                });
    }

    /** Build the GraphQL error specification map reported to the client for {@code ex}. */
    private static Map<String, Object> toErrorSpecification(Throwable ex) {
        // The error builder rejects a null message, so fall back to the exception type.
        String text = (ex.getMessage() != null ? ex.getMessage() : ex.getClass().getName());
        return GraphqlErrorBuilder.newError()
                .errorType(ErrorType.DataFetchingException)
                // Pass the text as an argument: GraphqlErrorBuilder#message may treat its first
                // argument as a format string, which would break on a message containing '%'.
                .message("%s", text)
                .build()
                .toSpecification();
    }

    /**
     * Serialize a protocol message with the configured converter.
     * @param id the operation id, omitted from the frame when {@code null}
     * @param messageType the message type
     * @param payload the payload, omitted from the frame when {@code null}
     * @throws IllegalStateException if the converter fails to write the message
     */
    @SuppressWarnings("unchecked")
    private TextMessage encode(@Nullable String id, MessageType messageType, @Nullable Object payload) {
        Map<String, Object> message = new HashMap<>(3);
        message.put("type", messageType.getType());
        if (id != null) {
            message.put("id", id);
        }
        if (payload != null) {
            message.put("payload", payload);
        }
        try {
            HttpOutputMessageAdapter outputMessage = new HttpOutputMessageAdapter();
            ((HttpMessageConverter<Object>) this.converter).write(message, (MediaType) null, outputMessage);
            return new TextMessage(outputMessage.toByteArray());
        }
        catch (IOException ex) {
            throw new IllegalStateException("Failed to write " + message + " as JSON", ex);
        }
    }

    @Override
    public void handleTransportError(WebSocketSession webSocketSession, Throwable th) {
        disposeSessionState(webSocketSession);
    }

    @Override
    public void afterConnectionClosed(WebSocketSession webSocketSession, CloseStatus closeStatus) {
        disposeSessionState(webSocketSession);
    }

    /** Remove and dispose the session's state, if still registered. */
    private void disposeSessionState(WebSocketSession session) {
        SessionState sessionState = this.sessionInfoMap.remove(session.getId());
        if (sessionState != null) {
            sessionState.dispose();
        }
    }

    @Override
    public boolean supportsPartialMessages() {
        return false;
    }


    /** Close statuses used by this handler for graphql-transport-ws protocol violations. */
    public static class GraphQlStatus {

        private static final CloseStatus INVALID_MESSAGE_STATUS = new CloseStatus(4400, "Invalid message");

        private static final CloseStatus UNAUTHORIZED_STATUS = new CloseStatus(4401, "Unauthorized");

        private static final CloseStatus INIT_TIMEOUT_STATUS = new CloseStatus(4408, "Connection initialisation timeout");

        private static final CloseStatus TOO_MANY_INIT_REQUESTS_STATUS = new CloseStatus(4429, "Too many initialisation requests");

        private GraphQlStatus() {
        }

        /** Close the session with the given status, logging (at debug) rather than propagating I/O errors. */
        static void closeSession(WebSocketSession webSocketSession, CloseStatus closeStatus) {
            try {
                webSocketSession.close(closeStatus);
            }
            catch (IOException ex) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Error while closing session with status: " + closeStatus, ex);
                }
            }
        }
    }


    /** Exposes the bytes of a {@link TextMessage} as an {@link HttpInputMessage} with no headers. */
    public static class HttpInputMessageAdapter extends ByteArrayInputStream implements HttpInputMessage {

        HttpInputMessageAdapter(TextMessage textMessage) {
            super(textMessage.asBytes());
        }

        @Override
        public InputStream getBody() {
            return this;
        }

        @Override
        public HttpHeaders getHeaders() {
            return HttpHeaders.EMPTY;
        }
    }


    /** In-memory {@link HttpOutputMessage} that collects the converter's output bytes. */
    public static class HttpOutputMessageAdapter extends ByteArrayOutputStream implements HttpOutputMessage {

        // One instance per message: Spring's message converters add default headers (e.g. Content-Type)
        // while writing, so a shared static instance would be mutated concurrently across sessions.
        private final HttpHeaders headers = new HttpHeaders();

        private HttpOutputMessageAdapter() {
        }

        @Override
        public OutputStream getBody() {
            return this;
        }

        @Override
        public HttpHeaders getHeaders() {
            return this.headers;
        }
    }


    /** graphql-transport-ws message types and their wire names. */
    public enum MessageType {

        CONNECTION_INIT("connection_init"),
        CONNECTION_ACK("connection_ack"),
        SUBSCRIBE("subscribe"),
        NEXT("next"),
        ERROR("error"),
        COMPLETE("complete");

        private static final Map<String, MessageType> messageTypes = new HashMap<>(6);

        static {
            for (MessageType messageType : values()) {
                messageTypes.put(messageType.getType(), messageType);
            }
        }

        private final String type;

        MessageType(String type) {
            this.type = type;
        }

        /** Return the wire name of this type. */
        public String getType() {
            return this.type;
        }

        /** Resolve a wire name to a type; returns {@code null} for {@code null} or unknown names. */
        @Nullable
        public static MessageType resolve(@Nullable String type) {
            return (type != null ? messageTypes.get(type) : null);
        }
    }


    /**
     * Sends encoded frames for one operation, requesting one item at a time so each frame is written
     * before the next is requested.
     */
    private static class SendMessageSubscriber extends BaseSubscriber<TextMessage> {

        private final String subscriptionId;

        private final WebSocketSession session;

        private final SessionState sessionState;

        SendMessageSubscriber(String subscriptionId, WebSocketSession session, SessionState sessionState) {
            this.subscriptionId = subscriptionId;
            this.session = session;
            this.sessionState = sessionState;
        }

        @Override
        protected void hookOnSubscribe(Subscription subscription) {
            subscription.request(1L);
        }

        @Override
        protected void hookOnNext(TextMessage nextMessage) {
            try {
                this.session.sendMessage(nextMessage);
                request(1L);
            }
            catch (IOException ex) {
                ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
            }
        }

        @Override
        protected void hookOnError(Throwable ex) {
            if (ex instanceof SubscriptionExistsException) {
                // Leave the subscriptions map alone: the entry for this id belongs to the existing
                // subscription and must remain so dispose() can cancel it.
                GraphQlStatus.closeSession(this.session,
                        new CloseStatus(4409, "Subscriber for " + this.subscriptionId + " already exists"));
                return;
            }
            ExceptionWebSocketHandlerDecorator.tryCloseWithError(this.session, ex, logger);
        }

        @Override
        protected void hookOnComplete() {
            this.sessionState.getSubscriptions().remove(this.subscriptionId);
        }
    }


    /** State of a single WebSocket session: init flag, active subscriptions and send scheduler. */
    public static class SessionState {

        // volatile: written while handling connection_init, and read by the init-timeout callback
        // that Mono.delay runs on a different (Reactor scheduler) thread.
        private volatile boolean connectionInitProcessed;

        private final Map<String, Subscription> subscriptions = new ConcurrentHashMap<>();

        private final Scheduler scheduler;

        SessionState(String sessionId) {
            this.scheduler = Schedulers.newSingle("GraphQL-WsSession-" + sessionId);
        }

        boolean isConnectionInitNotProcessed() {
            return !this.connectionInitProcessed;
        }

        /** Mark connection_init as processed and return whether it had already been processed. */
        synchronized boolean setConnectionInitProcessed() {
            boolean previous = this.connectionInitProcessed;
            this.connectionInitProcessed = true;
            return previous;
        }

        Map<String, Subscription> getSubscriptions() {
            return this.subscriptions;
        }

        Scheduler getScheduler() {
            return this.scheduler;
        }

        /** Cancel all registered subscriptions, then dispose the session's scheduler. */
        void dispose() {
            for (Subscription subscription : this.subscriptions.values()) {
                try {
                    subscription.cancel();
                }
                catch (Throwable ex) {
                    // Best effort: keep cancelling the rest and still dispose the scheduler.
                    if (logger.isDebugEnabled()) {
                        logger.debug("Failed to cancel subscription", ex);
                    }
                }
            }
            this.subscriptions.clear();
            this.scheduler.dispose();
        }
    }


    /** Signals a {@code subscribe} whose id matches a subscription that is still active. */
    private static class SubscriptionExistsException extends RuntimeException {

        private SubscriptionExistsException() {
        }
    }
}
