package org.springframework.cloud.gateway.rsocket.core;

import io.micrometer.core.instrument.simple.SimpleMeterRegistry;
import io.netty.buffer.Unpooled;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.util.DefaultPayload;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.assertj.core.api.Assertions;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentMatchers;
import org.mockito.Mockito;
import org.springframework.cloud.gateway.rsocket.autoconfigure.GatewayRSocketProperties;
import org.springframework.cloud.gateway.rsocket.filter.RSocketFilter;
import org.springframework.cloud.gateway.rsocket.registry.LoadBalancedRSocket;
import org.springframework.cloud.gateway.rsocket.registry.Registry;
import org.springframework.cloud.gateway.rsocket.route.Route;
import org.springframework.cloud.gateway.rsocket.route.Routes;
import org.springframework.cloud.gateway.rsocket.support.Metadata;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.MonoProcessor;
import reactor.test.StepVerifier;

/* loaded from: input_file:org/springframework/cloud/gateway/rsocket/core/GatewayRSocketTests.class */
/**
 * Exercises how a {@link GatewayRSocket} drives the {@link GatewayFilter}s attached to a
 * matched {@link Route}: several filters in sequence, no filters at all, a filter that
 * ends the chain early, an asynchronous filter, and a filter that fails.
 *
 * <p>The {@link Registry} and the downstream {@link RSocket} are Mockito mocks wired up in
 * {@link #init()}; no real network connection is involved.
 */
public class GatewayRSocketTests {

    private static final Log logger = LogFactory.getLog(GatewayRSocketTests.class);

    /** Mocked registry handed to every gateway socket under test. */
    private Registry registry;

    /** Request payload sent through the gateway in every test. */
    private Payload incomingPayload;

    /**
     * Filter that records its invocation and then delegates to the rest of the chain.
     * Subclasses customise the behaviour by overriding
     * {@link #doFilter(GatewayExchange, GatewayFilterChain)}.
     */
    private static class TestFilter implements GatewayFilter {

        // volatile: the flag may be written on a Reactor thread and read by the test thread.
        private volatile boolean invoked;

        /**
         * @return {@code true} once {@link #filter(GatewayExchange, GatewayFilterChain)} has
         * been called at least once
         */
        public boolean invoked() {
            return this.invoked;
        }

        /**
         * Marks this filter as invoked and hands over to
         * {@link #doFilter(GatewayExchange, GatewayFilterChain)}.
         */
        public Mono<RSocketFilter.Success> filter(GatewayExchange gatewayExchange, GatewayFilterChain gatewayFilterChain) {
            this.invoked = true;
            return doFilter(gatewayExchange, gatewayFilterChain);
        }

        /**
         * Default behaviour: continue with the remaining filters of the chain.
         */
        public Mono<RSocketFilter.Success> doFilter(GatewayExchange gatewayExchange, GatewayFilterChain gatewayFilterChain) {
            return gatewayFilterChain.filter(gatewayExchange);
        }
    }

    /**
     * Filter that performs a delayed (100 ms) asynchronous step before continuing the chain.
     */
    private static class AsyncFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(GatewayExchange gatewayExchange, GatewayFilterChain gatewayFilterChain) {
            return doAsyncWork().flatMap(result -> {
                GatewayRSocketTests.logger.debug("Async result: " + result);
                return gatewayFilterChain.filter(gatewayExchange);
            });
        }

        /** Emits the fixed value {@code "123"} after a 100 ms delay. */
        private Mono<String> doAsyncWork() {
            return Mono.delay(Duration.ofMillis(100L)).map(tick -> "123");
        }
    }

    /**
     * Filter that always fails with an {@link IllegalStateException}.
     */
    private static class ExceptionFilter implements GatewayFilter {

        public Mono<RSocketFilter.Success> filter(GatewayExchange gatewayExchange, GatewayFilterChain gatewayFilterChain) {
            return Mono.error(new IllegalStateException("boo"));
        }
    }

    /**
     * Filter that completes empty without calling the chain, so later filters never run.
     */
    private static class ShortcircuitingFilter extends TestFilter {

        @Override
        public Mono<RSocketFilter.Success> doFilter(GatewayExchange gatewayExchange, GatewayFilterChain gatewayFilterChain) {
            return Mono.empty();
        }
    }

    /**
     * {@link GatewayRSocket} variant whose pending-request sockets are all built around a
     * single {@link MonoProcessor} that the test can access and signal directly.
     */
    private static class TestGatewayRSocket extends GatewayRSocket {

        private final MonoProcessor<RSocket> processor;

        TestGatewayRSocket(Registry registry, Routes routes) {
            super(registry, routes, new SimpleMeterRegistry(), new GatewayRSocketProperties(),
                    GatewayRSocketTests.getMetadata());
            this.processor = MonoProcessor.create();
        }

        /**
         * Builds a {@link PendingRequestRSocket} backed by the test-controlled processor.
         * The second callback tags the exchange with the {@code id} entry of the metadata
         * it receives.
         */
        PendingRequestRSocket constructPendingRSocket(GatewayExchange gatewayExchange) {
            return new PendingRequestRSocket(
                    registeredEvent -> getRouteMono(registeredEvent, gatewayExchange),
                    metadata -> {
                        gatewayExchange.setTags(gatewayExchange.getTags().and("responder.id", metadata.get("id")));
                    },
                    this.processor);
        }

        /** @return the processor shared by every pending socket created by this instance */
        public MonoProcessor<RSocket> getProcessor() {
            return this.processor;
        }
    }

    /**
     * {@link Routes} implementation exposing exactly one route ({@code "route1"}) whose
     * predicate always matches and which carries the supplied filters.
     */
    private static class TestRoutes implements Routes {

        private final Route route;

        // Retained from the original fixture; assigned once and not read within this class.
        private final List<GatewayFilter> filters;

        TestRoutes() {
            this(Collections.<GatewayFilter>emptyList());
        }

        TestRoutes(GatewayFilter... gatewayFilterArr) {
            this(Arrays.asList(gatewayFilterArr));
        }

        TestRoutes(List<GatewayFilter> list) {
            this.filters = list;
            this.route = Route.builder()
                    .id("route1")
                    .routingMetadata(Metadata.from("mock").build())
                    .predicate(gatewayExchange -> Mono.just(true))
                    .filters(list)
                    .build();
        }

        public Flux<Route> getRoutes() {
            return Flux.just(this.route);
        }
    }

    /**
     * Wires the mocks: the registry returns a load-balanced socket for any metadata, that
     * socket chooses a mocked {@link RSocket}, and the mocked socket answers every
     * request-response call with a {@code "response"} payload.
     */
    @Before
    public void init() {
        this.registry = Mockito.mock(Registry.class);
        this.incomingPayload = DefaultPayload.create(Unpooled.EMPTY_BUFFER,
                Metadata.from("mock").with("id", "mock1").encode());

        RSocket rSocket = Mockito.mock(RSocket.class);
        LoadBalancedRSocket loadBalancedRSocket = Mockito.mock(LoadBalancedRSocket.class);

        Mockito.when(this.registry.getRegistered(ArgumentMatchers.any(Metadata.class)))
                .thenReturn(loadBalancedRSocket);
        Mockito.when(loadBalancedRSocket.choose())
                .thenReturn(Mono.just(new LoadBalancedRSocket.EnrichedRSocket(rSocket, getMetadata())));
        Mockito.when(rSocket.requestResponse(ArgumentMatchers.any(Payload.class)))
                .thenReturn(Mono.just(DefaultPayload.create("response")));
    }

    /** All three filters on the route are invoked and a response payload is produced. */
    @Test
    public void multipleFilters() {
        TestFilter first = new TestFilter();
        TestFilter second = new TestFilter();
        TestFilter third = new TestFilter();
        TestGatewayRSocket gateway = new TestGatewayRSocket(this.registry, new TestRoutes(first, second, third));

        // Zero wait budget: presumably the fully mocked pipeline resolves synchronously.
        Payload response = gateway.requestResponse(this.incomingPayload).block(Duration.ZERO);

        Assertions.assertThat(first.invoked()).isTrue();
        Assertions.assertThat(second.invoked()).isTrue();
        Assertions.assertThat(third.invoked()).isTrue();
        Assertions.assertThat(response).isNotNull();
    }

    /** A route without any filter still yields a response payload. */
    @Test
    public void zeroFilters() {
        TestGatewayRSocket gateway = new TestGatewayRSocket(this.registry, new TestRoutes());

        Payload response = gateway.requestResponse(this.incomingPayload).block(Duration.ZERO);

        Assertions.assertThat(response).isNotNull();
    }

    /**
     * A filter that completes empty stops the chain: filters before it (and itself) run,
     * the one after it does not, and the response completes without emitting a payload.
     */
    @Test
    public void shortcircuitFilter() {
        TestFilter before = new TestFilter();
        ShortcircuitingFilter shortcircuit = new ShortcircuitingFilter();
        TestFilter after = new TestFilter();
        TestGatewayRSocket gateway = new TestGatewayRSocket(this.registry,
                new TestRoutes(before, shortcircuit, after));

        Mono<Payload> response = gateway.requestResponse(this.incomingPayload);

        // Signal the shared processor with no value so that a pending request socket
        // (presumably created when the chain is cut short) terminates instead of waiting.
        gateway.processor.onNext(null);

        StepVerifier.withVirtualTime(() -> response)
                .expectSubscription()
                .verifyComplete();

        Assertions.assertThat(before.invoked()).isTrue();
        Assertions.assertThat(shortcircuit.invoked()).isTrue();
        Assertions.assertThat(after.invoked()).isFalse();
    }

    /** A filter doing delayed asynchronous work is invoked and the request still succeeds. */
    @Test
    public void asyncFilter() {
        AsyncFilter asyncFilter = new AsyncFilter();
        TestGatewayRSocket gateway = new TestGatewayRSocket(this.registry, new TestRoutes(asyncFilter));

        Payload response = gateway.requestResponse(this.incomingPayload).block(Duration.ofSeconds(5L));

        Assertions.assertThat(asyncFilter.invoked()).isTrue();
        Assertions.assertThat(response).isNotNull();
    }

    /** An error raised by a filter propagates to the caller of {@code requestResponse}. */
    @Test(expected = IllegalStateException.class)
    public void handleErrorFromFilter() {
        TestGatewayRSocket gateway = new TestGatewayRSocket(this.registry,
                new TestRoutes(new ExceptionFilter()));

        gateway.requestResponse(this.incomingPayload).block(Duration.ofSeconds(5L));
    }

    /** Metadata describing the mocked target service ({@code service} / {@code service1}). */
    private static Metadata getMetadata() {
        return Metadata.from("service").with("id", "service1").build();
    }
}
