package org.springframework.cloud.gateway.rsocket.socketacceptor;

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.RSocket;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.springframework.cloud.gateway.rsocket.autoconfigure.GatewayRSocketProperties;
import org.springframework.cloud.gateway.rsocket.core.GatewayRSocket;
import org.springframework.cloud.gateway.rsocket.filter.RSocketFilter;
import org.springframework.cloud.gateway.rsocket.support.Metadata;
import reactor.core.publisher.Mono;

/* loaded from: input_file:org/springframework/cloud/gateway/rsocket/socketacceptor/GatewaySocketAcceptorTests.class */
/**
 * Tests how {@link GatewaySocketAcceptor} drives its chain of {@link SocketAcceptorFilter}s
 * when a connection is accepted: every filter running, no filters at all, a filter that
 * short-circuits the chain, a filter that completes asynchronously, and a filter that fails.
 */
public class GatewaySocketAcceptorTests {

    private static final Log logger = LogFactory.getLog(GatewaySocketAcceptorTests.class);

    private final GatewayRSocketProperties properties = new GatewayRSocketProperties();

    private GatewayRSocket.Factory factory;

    private ConnectionSetupPayload setupPayload;

    private RSocket sendingSocket;

    private MeterRegistry meterRegistry;

    /**
     * Creates fresh mocks and a fresh meter registry for each test, and stubs the factory so
     * that any {@link Metadata} yields a mock {@link GatewayRSocket}.
     */
    @Before
    public void init() {
        this.factory = Mockito.mock(GatewayRSocket.Factory.class);
        this.setupPayload = Mockito.mock(ConnectionSetupPayload.class);
        this.sendingSocket = Mockito.mock(RSocket.class);
        this.meterRegistry = new SimpleMeterRegistry();
        Mockito.when(this.factory.create(ArgumentMatchers.any(Metadata.class)))
                .thenReturn(Mockito.mock(GatewayRSocket.class));
    }

    /** Every filter in a chain of pass-through filters is invoked and a socket is produced. */
    @Test
    public void multipleFilters() {
        TestFilter first = new TestFilter();
        TestFilter second = new TestFilter();
        TestFilter third = new TestFilter();

        // Duration.ZERO: nothing in this chain defers work, so the result is expected to be
        // available immediately; a timeout here would surface as a test failure.
        RSocket socket = acceptorWith(Arrays.asList(first, second, third))
                .accept(this.setupPayload, this.sendingSocket)
                .block(Duration.ZERO);

        Assertions.assertThat(first.invoked()).isTrue();
        Assertions.assertThat(second.invoked()).isTrue();
        Assertions.assertThat(third.invoked()).isTrue();
        Assertions.assertThat(socket).isNotNull();
    }

    /** With no filters configured the acceptor still produces a socket. */
    @Test
    public void zeroFilters() {
        RSocket socket = acceptorWith(Collections.emptyList())
                .accept(this.setupPayload, this.sendingSocket)
                .block(Duration.ZERO);

        Assertions.assertThat(socket).isNotNull();
    }

    /**
     * A filter that completes empty stops the chain: later filters are not invoked and no
     * socket is produced.
     */
    @Test
    public void shortcircuitFilter() {
        TestFilter before = new TestFilter();
        ShortcircuitingFilter shortcircuit = new ShortcircuitingFilter();
        TestFilter after = new TestFilter();

        RSocket socket = acceptorWith(Arrays.asList(before, shortcircuit, after))
                .accept(this.setupPayload, this.sendingSocket)
                .block(Duration.ZERO);

        Assertions.assertThat(before.invoked()).isTrue();
        Assertions.assertThat(shortcircuit.invoked()).isTrue();
        Assertions.assertThat(after.invoked()).isFalse();
        // block() yields null when the Mono completes without emitting a value.
        Assertions.assertThat(socket).isNull();
    }

    /** A filter that resumes the chain after a delay is invoked and a socket is produced. */
    @Test
    public void asyncFilter() {
        AsyncFilter filter = new AsyncFilter();

        // The filter delays by 100 ms, so a real (non-zero) timeout is required here.
        RSocket socket = acceptorWith(Collections.singletonList(filter))
                .accept(this.setupPayload, this.sendingSocket)
                .block(Duration.ofSeconds(5));

        Assertions.assertThat(filter.invoked()).isTrue();
        Assertions.assertThat(socket).isNotNull();
    }

    /** An error signalled by a filter propagates out of the blocking call. */
    @Test(expected = IllegalStateException.class)
    public void handleErrorFromFilter() {
        acceptorWith(Collections.singletonList(new ExceptionFilter()))
                .accept(this.setupPayload, this.sendingSocket)
                .block(Duration.ofSeconds(5));
    }

    /**
     * Builds the acceptor under test from the shared fixtures and the given filter chain.
     *
     * @param filters the filters to run, in order, when a connection is accepted
     * @return a new acceptor wired with this test's factory, meter registry and properties
     */
    private GatewaySocketAcceptor acceptorWith(java.util.List<SocketAcceptorFilter> filters) {
        return new GatewaySocketAcceptor(this.factory, filters, this.meterRegistry, this.properties);
    }

    /**
     * Pass-through filter that records whether it was invoked. Subclasses customise the
     * behaviour by overriding {@link #doFilter}.
     */
    private static class TestFilter implements SocketAcceptorFilter {

        // volatile: AsyncFilter resumes the chain from a Mono.delay callback, so the flag may
        // be read from a thread other than the one that wrote it.
        private volatile boolean invoked;

        /**
         * @return {@code true} once {@link #filter} has been called on this instance
         */
        public boolean invoked() {
            return this.invoked;
        }

        /** Marks this filter as invoked, then delegates to {@link #doFilter}. */
        @Override
        public Mono<RSocketFilter.Success> filter(SocketAcceptorExchange exchange, SocketAcceptorFilterChain chain) {
            this.invoked = true;
            return doFilter(exchange, chain);
        }

        /**
         * Hook for subclasses; the default simply continues the chain.
         *
         * @param exchange the exchange being filtered
         * @param chain the remainder of the filter chain
         * @return the result of continuing the chain
         */
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange, SocketAcceptorFilterChain chain) {
            return chain.filter(exchange);
        }
    }

    /** Filter that ends the chain by completing empty instead of delegating. */
    private static class ShortcircuitingFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange, SocketAcceptorFilterChain chain) {
            return Mono.empty();
        }
    }

    /** Filter that continues the chain only after a short delayed step has completed. */
    private static class AsyncFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(SocketAcceptorExchange exchange, SocketAcceptorFilterChain chain) {
            return doAsyncWork().flatMap(result -> {
                logger.debug("Async result: " + result);
                return chain.filter(exchange);
            });
        }

        /**
         * @return a Mono that emits {@code "123"} after a 100 ms delay
         */
        private Mono<String> doAsyncWork() {
            return Mono.delay(Duration.ofMillis(100)).map(tick -> "123");
        }
    }

    /** Filter that always fails with an {@link IllegalStateException}. */
    private static class ExceptionFilter implements SocketAcceptorFilter {

        @Override
        public Mono<RSocketFilter.Success> filter(SocketAcceptorExchange exchange, SocketAcceptorFilterChain chain) {
            return Mono.error(new IllegalStateException("boo"));
        }
    }
}
