package org.springframework.cloud.gateway.rsocket.socketacceptor;

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.RSocket;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.springframework.cloud.gateway.rsocket.autoconfigure.BrokerProperties;
import org.springframework.cloud.gateway.rsocket.common.metadata.Metadata;
import org.springframework.cloud.gateway.rsocket.common.metadata.RouteSetup;
import org.springframework.cloud.gateway.rsocket.common.metadata.TagsMetadata;
import org.springframework.cloud.gateway.rsocket.common.test.MetadataEncoder;
import org.springframework.cloud.gateway.rsocket.core.GatewayRSocket;
import org.springframework.cloud.gateway.rsocket.core.GatewayRSocketFactory;
import org.springframework.cloud.gateway.rsocket.filter.RSocketFilter;
import org.springframework.core.codec.Decoder;
import org.springframework.core.codec.Encoder;
import org.springframework.messaging.rsocket.DefaultMetadataExtractor;
import org.springframework.messaging.rsocket.PayloadUtils;
import org.springframework.messaging.rsocket.RSocketStrategies;
import reactor.core.publisher.Mono;

/**
 * Tests for {@link GatewaySocketAcceptor}: verifies how the configured
 * {@link SocketAcceptorFilter} chain is invoked when a connection is accepted, and how
 * the chain's outcome (pass-through, short-circuit, asynchronous completion, error)
 * determines the {@link RSocket} emitted by {@code accept}.
 */
public class GatewaySocketAcceptorTests {

    private static final Log logger = LogFactory.getLog(GatewaySocketAcceptorTests.class);

    private GatewayRSocketFactory factory;

    private ConnectionSetupPayload setupPayload;

    private RSocket sendingSocket;

    private MeterRegistry meterRegistry;

    private final BrokerProperties properties = new BrokerProperties();

    // Strategies able to encode/decode RouteSetup metadata. They are shared by the
    // extractor under test and by the encoder that builds the mocked setup payload.
    // Declared before metadataExtractor because that initializer reads it.
    private final RSocketStrategies rSocketStrategies = RSocketStrategies.builder()
            .decoder(new RouteSetup.Decoder())
            .encoder(new RouteSetup.Encoder())
            .build();

    private final DefaultMetadataExtractor metadataExtractor =
            new DefaultMetadataExtractor(this.rSocketStrategies.decoders());

    /**
     * Filter that records whether it was invoked and, by default, simply delegates to the
     * rest of the chain. Subclasses customise the behaviour via {@link #doFilter}.
     */
    private static class TestFilter implements SocketAcceptorFilter {

        private volatile boolean invoked;

        /** @return {@code true} once {@link #filter} has been called on this instance */
        public boolean invoked() {
            return this.invoked;
        }

        @Override
        public Mono<RSocketFilter.Success> filter(SocketAcceptorExchange exchange,
                SocketAcceptorFilterChain chain) {
            this.invoked = true;
            return doFilter(exchange, chain);
        }

        /** Hook for subclasses; the default continues the chain. */
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange,
                SocketAcceptorFilterChain chain) {
            return chain.filter(exchange);
        }
    }

    /** Filter that stops the chain by completing empty instead of delegating. */
    private static class ShortcircuitingFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange,
                SocketAcceptorFilterChain chain) {
            return Mono.empty();
        }
    }

    /** Filter that performs a delayed (100 ms) async step before continuing the chain. */
    private static class AsyncFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange,
                SocketAcceptorFilterChain chain) {
            return doAsyncWork().flatMap(result -> {
                logger.debug("Async result: " + result);
                return chain.filter(exchange);
            });
        }

        private Mono<String> doAsyncWork() {
            return Mono.delay(Duration.ofMillis(100L)).map(tick -> "123");
        }
    }

    /** Filter that always fails with {@code IllegalStateException("boo")}. */
    private static class ExceptionFilter implements SocketAcceptorFilter {

        @Override
        public Mono<RSocketFilter.Success> filter(SocketAcceptorExchange exchange,
                SocketAcceptorFilterChain chain) {
            return Mono.error(new IllegalStateException("boo"));
        }
    }

    /**
     * Creates fresh mocks for every test and stubs the setup payload. The payload carries
     * composite metadata that contains a RouteSetup for "myservice".
     */
    @Before
    public void init() {
        this.factory = Mockito.mock(GatewayRSocketFactory.class);
        this.setupPayload = Mockito.mock(ConnectionSetupPayload.class);
        this.sendingSocket = Mockito.mock(RSocket.class);
        this.meterRegistry = new SimpleMeterRegistry();

        this.metadataExtractor.metadataToExtract(RouteSetup.ROUTE_SETUP_MIME_TYPE, RouteSetup.class,
                "routesetup");

        Mockito.when(this.factory.create(ArgumentMatchers.any(TagsMetadata.class)))
                .thenReturn(Mockito.mock(GatewayRSocket.class));
        Mockito.when(this.setupPayload.metadataMimeType())
                .thenReturn(Metadata.COMPOSITE_MIME_TYPE.toString());
        Mockito.when(this.setupPayload.hasMetadata()).thenReturn(true);

        MetadataEncoder metadataEncoder = new MetadataEncoder(Metadata.COMPOSITE_MIME_TYPE,
                this.rSocketStrategies);
        metadataEncoder.metadata(RouteSetup.of(1L, "myservice").build(),
                RouteSetup.ROUTE_SETUP_MIME_TYPE);
        Mockito.when(this.setupPayload.metadata())
                .thenReturn(PayloadUtils.createPayload(
                        MetadataEncoder.emptyDataBuffer(this.rSocketStrategies),
                        metadataEncoder.encode()).metadata());
    }

    /**
     * Builds an acceptor over the given filters (in order) and accepts the stubbed setup.
     *
     * @param filters the filter chain to install
     * @return the (unsubscribed) result of {@code GatewaySocketAcceptor.accept}
     */
    private Mono<RSocket> accept(SocketAcceptorFilter... filters) {
        return new GatewaySocketAcceptor(this.factory, Arrays.asList(filters), this.meterRegistry,
                this.properties, this.metadataExtractor)
                .accept(this.setupPayload, this.sendingSocket);
    }

    /** Every filter in a pass-through chain runs and a socket is produced. */
    @Test
    public void multipleFilters() {
        TestFilter first = new TestFilter();
        TestFilter second = new TestFilter();
        TestFilter third = new TestFilter();

        // Duration.ZERO: the synchronous chain must have completed by subscription time.
        RSocket socket = accept(first, second, third).block(Duration.ZERO);

        Assertions.assertThat(first.invoked()).isTrue();
        Assertions.assertThat(second.invoked()).isTrue();
        Assertions.assertThat(third.invoked()).isTrue();
        Assertions.assertThat(socket).isNotNull();
    }

    /** An empty filter chain still produces a socket. */
    @Test
    public void zeroFilters() {
        RSocket socket = accept().block(Duration.ZERO);

        Assertions.assertThat(socket).isNotNull();
    }

    /** A filter that completes empty stops the chain and no socket is emitted. */
    @Test
    public void shortcircuitFilter() {
        TestFilter before = new TestFilter();
        ShortcircuitingFilter shortcircuit = new ShortcircuitingFilter();
        TestFilter after = new TestFilter();

        RSocket socket = accept(before, shortcircuit, after).block(Duration.ZERO);

        Assertions.assertThat(before.invoked()).isTrue();
        Assertions.assertThat(shortcircuit.invoked()).isTrue();
        Assertions.assertThat(after.invoked()).isFalse();
        Assertions.assertThat(socket).isNull();
    }

    /** A filter that continues the chain asynchronously still yields a socket. */
    @Test
    public void asyncFilter() {
        AsyncFilter asyncFilter = new AsyncFilter();

        RSocket socket = accept(asyncFilter).block(Duration.ofSeconds(5L));

        Assertions.assertThat(asyncFilter.invoked()).isTrue();
        Assertions.assertThat(socket).isNotNull();
    }

    /** An error signalled by a filter is propagated to the subscriber of {@code accept}. */
    @Test
    public void handleErrorFromFilter() {
        // Assert the message too: block(Duration) throws IllegalStateException on timeout,
        // so checking only the exception type could pass for the wrong reason.
        Assertions.assertThatThrownBy(
                () -> accept(new ExceptionFilter()).block(Duration.ofSeconds(5L)))
                .isInstanceOf(IllegalStateException.class)
                .hasMessage("boo");
    }
}
