package org.springframework.cloud.sleuth.instrument.rsocket;

import io.rsocket.frame.FrameType;
import java.net.URI;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import org.assertj.core.api.BDDAssertions;
import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;
import org.springframework.boot.WebApplicationType;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.cloud.sleuth.Span;
import org.springframework.cloud.sleuth.TraceContext;
import org.springframework.cloud.sleuth.Tracer;
import org.springframework.cloud.sleuth.exporter.FinishedSpan;
import org.springframework.cloud.sleuth.test.TestSpanHandler;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.messaging.handler.annotation.MessageMapping;
import org.springframework.messaging.handler.annotation.Payload;
import org.springframework.messaging.rsocket.RSocketRequester;
import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.stereotype.Controller;
import org.springframework.util.MimeType;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.util.context.ContextView;

/**
 * Base integration test for Sleuth's RSocket instrumentation.
 *
 * <p>Each test boots a reactive Spring Boot application whose RSocket server is exposed
 * over WebSocket at {@code /rsocket} on a random port. It then drives the four RSocket
 * interaction models (fire-and-forget, request-response, request-stream and
 * request-channel) against {@link TestController}.
 *
 * <p>Subclasses contribute extra configuration through {@link #testConfiguration()},
 * presumably the tracer implementation under test (TODO confirm against the subclasses).
 * The resulting context must provide a {@link Tracer} and a {@link TestSpanHandler} bean,
 * because both are looked up here.
 */
public abstract class TraceRSocketTests {

    /** Trace id (and, by default, span id) sent in the B3 header of non-sampled requests. */
    public static final String EXPECTED_TRACE_ID = "b919095138aa4c6e";

    /** Upper bound on how long a test waits for the controller to record an incoming frame. */
    private static final long FRAME_TIMEOUT_SECONDS = 10;

    /**
     * Metadata mime type used to send a single-header B3 value.
     *
     * <p>{@code toString()} is overridden to return exactly {@code "b3"}. This is presumably
     * because the requester encodes the metadata mime type from its string form.
     * TODO confirm against the metadata encoder.
     */
    private static final MimeType B3_MIME_TYPE = new MimeType("b3") {
        @Override
        public String toString() {
            return "b3";
        }
    };

    @Configuration(proxyBeanMethods = false)
    @EnableAutoConfiguration
    static class MyConfig {

        @Bean
        TestController controller(Tracer tracer) {
            return new TestController(tracer);
        }
    }

    /**
     * RSocket controller mapped under {@code api.c2}. Each handler records two things.
     *
     * <ul>
     * <li>The tracer's current span, captured when the handler method is invoked.</li>
     * <li>The Reactor context and the frame type, captured when the returned publisher is
     * subscribed.</li>
     * </ul>
     */
    @MessageMapping("api.c2")
    @Controller
    public static class TestController {

        final Tracer tracer;

        // Written on RSocket handler threads and read from the test / Awaitility polling
        // threads, hence volatile.
        volatile Span span;

        volatile ContextView interceptedContext;

        BlockingQueue<FrameType> receivedFrames = new LinkedBlockingDeque<>();

        TestController(Tracer tracer) {
            this.tracer = tracer;
        }

        BlockingQueue<FrameType> getReceivedFrames() {
            return this.receivedFrames;
        }

        Span getSpan() {
            return this.span;
        }

        /** Forgets state captured by the previous request. */
        void reset() {
            this.span = null;
            this.interceptedContext = null;
        }

        @MessageMapping("fnf")
        Mono<Void> testFnf() {
            this.span = this.tracer.currentSpan();
            return Mono.deferContextual(contextView -> {
                this.interceptedContext = contextView;
                this.receivedFrames.offer(FrameType.REQUEST_FNF);
                return Mono.empty();
            });
        }

        @MessageMapping("rr")
        Mono<String> testRR() {
            this.span = this.tracer.currentSpan();
            return Mono.deferContextual(contextView -> {
                this.interceptedContext = contextView;
                this.receivedFrames.offer(FrameType.REQUEST_RESPONSE);
                return Mono.just("response");
            });
        }

        @MessageMapping("rs")
        Flux<String> testRS() {
            this.span = this.tracer.currentSpan();
            return Flux.deferContextual(contextView -> {
                this.interceptedContext = contextView;
                this.receivedFrames.offer(FrameType.REQUEST_STREAM);
                return Flux.just("stream");
            });
        }

        @MessageMapping("rc")
        Flux<String> testRC(@Payload Flux<String> flux) {
            this.span = this.tracer.currentSpan();
            return Flux.deferContextual(contextView -> {
                this.interceptedContext = contextView;
                this.receivedFrames.offer(FrameType.REQUEST_CHANNEL);
                return flux;
            });
        }
    }

    /**
     * Sends each interaction type through a requester built outside the application context.
     *
     * <p>Sampled requests must produce a server-side span named
     * {@code "<FRAME_TYPE> <route>"} carrying controller tags. Requests carrying a
     * non-sampled B3 header must be handled in the trace given by
     * {@link #expectedTraceId()}.
     */
    @Test
    public void should_instrument_responder() throws Exception {
        // try-with-resources: the server is shut down even when an assertion fails.
        try (ConfigurableApplicationContext run = startApplication()) {
            TestSpanHandler testSpanHandler = run.getBean(TestSpanHandler.class);
            TestController testController = run.getBean(TestController.class);
            // Built from the static builder rather than the context's RSocketRequester.Builder
            // bean, presumably so that only the responder side is instrumented.
            RSocketRequester requester = RSocketRequester.builder()
                    .rsocketStrategies(run.getBean(RSocketStrategies.class))
                    .websocket(rsocketEndpoint(serverPort(run)));

            whenRequestFnFIsSent(requester, "api.c2.fnf").block();
            thenSpanWasReportedWithTags(testSpanHandler, "api.c2.fnf", awaitReceivedFrame(testController));
            resetState(testSpanHandler, testController);

            whenRequestResponseIsSent(requester, "api.c2.rr").block();
            thenSpanWasReportedWithTags(testSpanHandler, "api.c2.rr", awaitReceivedFrame(testController));
            resetState(testSpanHandler, testController);

            whenRequestStreamIsSent(requester, "api.c2.rs").blockLast();
            thenSpanWasReportedWithTags(testSpanHandler, "api.c2.rs", awaitReceivedFrame(testController));
            resetState(testSpanHandler, testController);

            whenRequestChannelIsSent(requester, "api.c2.rc").blockLast();
            thenSpanWasReportedWithTags(testSpanHandler, "api.c2.rc", awaitReceivedFrame(testController));
            resetState(testSpanHandler, testController);

            whenNonSampledRequestFnfIsSent(requester);
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, expectedTraceId());
            resetState(testSpanHandler, testController);

            whenNonSampledRequestResponseIsSent(requester);
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, expectedTraceId());
            resetState(testSpanHandler, testController);

            whenNonSampledRequestStreamIsSent(requester);
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, expectedTraceId());
            resetState(testSpanHandler, testController);

            whenNonSampledRequestChannelIsSent(requester);
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, expectedTraceId());
            resetState(testSpanHandler, testController);
        }
    }

    /** Trace id the non-sampled B3 header carries; subclasses may override. */
    protected String expectedTraceId() {
        return EXPECTED_TRACE_ID;
    }

    /** Span id the non-sampled B3 header carries; by default the same value as the trace id. */
    protected String expectedSpanId() {
        return EXPECTED_TRACE_ID;
    }

    /**
     * Sends each interaction type through the context's {@link RSocketRequester.Builder}.
     *
     * <p>For each request, a client span is started and placed in the Reactor context as the
     * {@link TraceContext}. The test then asserts that the controller ran in that span's trace.
     */
    @Test
    public void should_instrument_requester_and_responder() throws Exception {
        // try-with-resources: the server is shut down even when an assertion fails.
        try (ConfigurableApplicationContext run = startApplication()) {
            Tracer tracer = run.getBean(Tracer.class);
            TestSpanHandler testSpanHandler = run.getBean(TestSpanHandler.class);
            TestController testController = run.getBean(TestController.class);
            RSocketRequester requester = run.getBean(RSocketRequester.Builder.class)
                    .websocket(rsocketEndpoint(serverPort(run)));

            Span fnfParent = tracer.nextSpan().start();
            whenRequestFnFIsSent(requester, "api.c2.fnf")
                    .contextWrite(context -> context.put(TraceContext.class, fnfParent.context()))
                    .doFinally(signal -> fnfParent.end())
                    .block();
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, fnfParent.context().traceId());
            resetState(testSpanHandler, testController);

            Span rrParent = tracer.nextSpan().start();
            whenRequestResponseIsSent(requester, "api.c2.rr")
                    .contextWrite(context -> context.put(TraceContext.class, rrParent.context()))
                    .doFinally(signal -> rrParent.end())
                    .block();
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, rrParent.context().traceId());
            resetState(testSpanHandler, testController);

            Span rsParent = tracer.nextSpan().start();
            whenRequestStreamIsSent(requester, "api.c2.rs")
                    .contextWrite(context -> context.put(TraceContext.class, rsParent.context()))
                    .doFinally(signal -> rsParent.end())
                    .blockLast();
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, rsParent.context().traceId());
            resetState(testSpanHandler, testController);

            Span rcParent = tracer.nextSpan().start();
            whenRequestChannelIsSent(requester, "api.c2.rc")
                    .contextWrite(context -> context.put(TraceContext.class, rcParent.context()))
                    .doFinally(signal -> rcParent.end())
                    .blockLast();
            awaitReceivedFrame(testController);
            thenNoSpanWasReported(testSpanHandler, testController, rcParent.context().traceId());
            resetState(testSpanHandler, testController);
        }
    }

    /** Extra configuration class registered alongside {@link MyConfig}. */
    protected abstract Class<?> testConfiguration();

    /** Boots the reactive application with an RSocket-over-WebSocket server on a random port. */
    private ConfigurableApplicationContext startApplication() {
        return new SpringApplicationBuilder(MyConfig.class, testConfiguration())
                .web(WebApplicationType.REACTIVE)
                .properties("server.port=0",
                        "spring.rsocket.server.transport=websocket",
                        "spring.rsocket.server.mapping-path=/rsocket",
                        "spring.jmx.enabled=false",
                        "spring.application.name=TraceRSocketTests",
                        "security.basic.enabled=false",
                        "management.security.enabled=false")
                .run();
    }

    /** Reads the randomly assigned server port from the environment. */
    private static int serverPort(ConfigurableApplicationContext context) {
        return context.getBean(Environment.class).getProperty("local.server.port", Integer.class);
    }

    /** WebSocket endpoint of the RSocket server for the given port. */
    private static URI rsocketEndpoint(int port) {
        return URI.create("ws://localhost:" + port + "/rsocket");
    }

    /** Clears reported spans and the controller's captured state between interactions. */
    private static void resetState(TestSpanHandler testSpanHandler, TestController testController) {
        testSpanHandler.clear();
        testController.reset();
    }

    /**
     * Waits, with a bound, for the controller to record the next frame type.
     *
     * @return the recorded frame type
     * @throws AssertionError if no frame arrives within {@link #FRAME_TIMEOUT_SECONDS}
     * @throws InterruptedException if the test thread is interrupted while waiting
     */
    private static FrameType awaitReceivedFrame(TestController testController) throws InterruptedException {
        FrameType frame = testController.getReceivedFrames().poll(FRAME_TIMEOUT_SECONDS, TimeUnit.SECONDS);
        if (frame == null) {
            throw new AssertionError(
                    "Controller did not receive a frame within " + FRAME_TIMEOUT_SECONDS + " seconds");
        }
        return frame;
    }

    /**
     * Waits until a span named {@code "<frameType> <route>"} has been reported. Asserts that it
     * carries the controller class tag and a controller method tag.
     */
    private void thenSpanWasReportedWithTags(TestSpanHandler testSpanHandler, String route, FrameType frameType) {
        String expectedName = frameType.name() + " " + route;
        Awaitility.await().untilAsserted(() -> {
            FinishedSpan reported = testSpanHandler.reportedSpans().stream()
                    .filter(finishedSpan -> expectedName.equals(finishedSpan.getName()))
                    .findFirst()
                    .orElseThrow(() -> new AssertionError("Span with name [" + expectedName + "] not found"));
            BDDAssertions.then(reported.getTags()).containsEntry("messaging.controller.class",
                    "org.springframework.cloud.sleuth.instrument.rsocket.TraceRSocketTests$TestController");
            BDDAssertions.then(reported.getTags()).containsKey("messaging.controller.method");
        });
    }

    private Mono<Void> whenRequestFnFIsSent(RSocketRequester rSocketRequester, String route) {
        return rSocketRequester.route(route).send();
    }

    private Mono<String> whenRequestResponseIsSent(RSocketRequester rSocketRequester, String route) {
        return rSocketRequester.route(route).retrieveMono(String.class);
    }

    private Flux<String> whenRequestStreamIsSent(RSocketRequester rSocketRequester, String route) {
        return rSocketRequester.route(route).retrieveFlux(String.class);
    }

    private Flux<String> whenRequestChannelIsSent(RSocketRequester rSocketRequester, String route) {
        return rSocketRequester.route(route)
                .data(Flux.fromArray(new String[] { "test1", "test2" }))
                .retrieveFlux(String.class);
    }

    /** Single-header B3 value {@code traceId-spanId-0}; the trailing 0 marks the request as not sampled. */
    private String nonSampledB3Header() {
        return expectedTraceId() + "-" + expectedSpanId() + "-0";
    }

    private void whenNonSampledRequestFnfIsSent(RSocketRequester rSocketRequester) {
        rSocketRequester.route("api.c2.fnf")
                .metadata(nonSampledB3Header(), B3_MIME_TYPE)
                .send()
                .block();
    }

    private void whenNonSampledRequestResponseIsSent(RSocketRequester rSocketRequester) {
        rSocketRequester.route("api.c2.rr")
                .metadata(nonSampledB3Header(), B3_MIME_TYPE)
                .retrieveMono(String.class)
                .block();
    }

    private void whenNonSampledRequestStreamIsSent(RSocketRequester rSocketRequester) {
        rSocketRequester.route("api.c2.rs")
                .metadata(nonSampledB3Header(), B3_MIME_TYPE)
                .retrieveFlux(String.class)
                .blockLast();
    }

    private void whenNonSampledRequestChannelIsSent(RSocketRequester rSocketRequester) {
        rSocketRequester.route("api.c2.rc")
                .metadata(nonSampledB3Header(), B3_MIME_TYPE)
                .data(Flux.fromArray(new String[] { "test1", "test2" }))
                .retrieveFlux(String.class)
                .blockLast();
    }

    /**
     * Waits until the controller has captured a current span, then asserts that the span belongs
     * to {@code traceId}.
     *
     * <p>Despite its name, this method does not inspect {@code testSpanHandler}. It only checks
     * the trace id the controller observed.
     */
    private void thenNoSpanWasReported(TestSpanHandler testSpanHandler, TestController testController,
            String traceId) {
        Awaitility.await().untilAsserted(() -> {
            // Read once so the null check and the trace-id check see the same span.
            Span observed = testController.getSpan();
            BDDAssertions.then(observed).isNotNull();
            BDDAssertions.then(observed.context().traceId()).isEqualTo(traceId);
        });
    }
}
