package org.springframework.cloud.function.observability;

import io.micrometer.observation.Observation;
import io.micrometer.tracing.Span;
import io.micrometer.tracing.Tracer;
import io.micrometer.tracing.handler.TracingObservationHandler;
import io.micrometer.tracing.propagation.Propagator;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry;
import org.springframework.cloud.function.context.message.MessageUtils;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.MessagingException;
import org.springframework.messaging.support.ErrorMessage;
import org.springframework.messaging.support.GenericMessage;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.messaging.support.MessageHeaderAccessor;
import org.springframework.util.StringUtils;

/**
 * {@link TracingObservationHandler} for {@link FunctionContext} observations.
 *
 * <p>On start it extracts tracing headers from the input {@link Message}, records a
 * short-lived {@code CONSUMER} span and starts a child span that is attached to the
 * tracing context of the observation. On stop it ends that span, creates a
 * {@code PRODUCER} span for the function output, injects the tracing headers into the
 * output message and publishes the result through
 * {@code FunctionContext#setModifiedOutput}.
 */
public class FunctionTracingObservationHandler implements TracingObservationHandler<FunctionContext> {

    private static final Log log = LogFactory.getLog(FunctionTracingObservationHandler.class);

    /** Remote service name of consumer spans; also the fallback for producer spans. */
    private static final String REMOTE_SERVICE_NAME = "broker";

    private final Tracer tracer;

    private final Propagator propagator;

    private final Propagator.Getter<MessageHeaderAccessor> getter;

    private final Propagator.Setter<MessageHeaderAccessor> setter;

    /**
     * Pairs an instrumented output message with the producer span created for it.
     */
    public static class MessageAndSpan {

        final Message<?> msg;

        final Span span;

        MessageAndSpan(Message<?> message, Span span) {
            this.msg = message;
            this.span = span;
        }

        @Override
        public String toString() {
            return "MessageAndSpan{msg=" + this.msg + ", span=" + this.span + "}";
        }
    }

    /**
     * Groups a rebuilt input message with the consumer span (parent) and the span that
     * was started as its child.
     */
    public static class MessageAndSpans {

        final Message<?> msg;

        final Span parentSpan;

        final Span childSpan;

        MessageAndSpans(Message<?> message, Span span, Span span2) {
            this.msg = message;
            this.parentSpan = span;
            this.childSpan = span2;
        }

        @Override
        public String toString() {
            return "MessageAndSpans{msg=" + this.msg + ", parentSpan=" + this.parentSpan + ", childSpan=" + this.childSpan + "}";
        }
    }

    /**
     * Creates the handler.
     *
     * @param tracer tracer used to create spans
     * @param propagator propagator used to extract, inject and list the tracing header fields
     * @param messageHeaderPropagatorGetter reads tracing headers from a {@link MessageHeaderAccessor}
     * @param messageHeaderPropagatorSetter writes tracing headers to a {@link MessageHeaderAccessor}
     */
    public FunctionTracingObservationHandler(Tracer tracer, Propagator propagator, MessageHeaderPropagatorGetter messageHeaderPropagatorGetter, MessageHeaderPropagatorSetter messageHeaderPropagatorSetter) {
        this.tracer = tracer;
        this.propagator = propagator;
        this.getter = messageHeaderPropagatorGetter;
        this.setter = messageHeaderPropagatorSetter;
    }

    /**
     * Starts the span for the function invocation and stores it in the tracing context.
     *
     * <p>For a supplier without input a fresh span is started. Otherwise the tracing
     * headers of the input message are used to build the consumer/child span pair, which
     * is kept in the context under {@code MessageAndSpans.class} for {@link #onStop}.
     *
     * @param functionContext the observation context of the invocation
     */
    public void onStart(FunctionContext functionContext) {
        Message<?> message = (Message<?>) functionContext.getInput();
        SimpleFunctionRegistry.FunctionInvocationWrapper targetFunction = functionContext.getTargetFunction();
        Span span;
        if (message == null && targetFunction.isSupplier()) {
            if (log.isDebugEnabled()) {
                log.debug("Creating a span for a supplier");
            }
            // No MessageAndSpans exists on this path, so nothing is stored in the context:
            // a null value must not be put there, and onStop treats a missing entry as
            // "no consumer span".
            span = this.tracer.nextSpan().start();
        } else {
            if (log.isDebugEnabled()) {
                log.debug("Will retrieve the tracing headers from the message");
            }
            MessageAndSpans messageAndSpans = wrapInputMessage(functionContext, message);
            if (log.isDebugEnabled()) {
                log.debug("Wrapped input msg " + messageAndSpans);
            }
            span = messageAndSpans.childSpan;
            functionContext.put(MessageAndSpans.class, messageAndSpans);
        }
        getTracingContext(functionContext).setSpan(span);
    }

    /**
     * Ends the function span and instruments the function output.
     *
     * <p>The span is renamed to the function definition and ended. If the function
     * produced an output, a producer span is created (parented to the consumer span when
     * one was recorded in {@link #onStart}, otherwise to the function span), tracing
     * headers are injected into the output message and the message is set as the
     * modified output. A {@code null} output is left untouched.
     *
     * @param functionContext the observation context of the invocation
     */
    public void onStop(FunctionContext functionContext) {
        MessageAndSpans messageAndSpans = (MessageAndSpans) functionContext.get(MessageAndSpans.class);
        Span functionSpan = getRequiredSpan(functionContext);
        functionSpan.name(functionContext.getTargetFunction().getFunctionDefinition()).end();

        Object output = functionContext.getOutput();
        if (output == null) {
            // A message cannot be built around a null payload; there is nothing to send.
            if (log.isDebugEnabled()) {
                log.debug("Function output is null - skipping output instrumentation");
            }
            return;
        }

        Message<?> outputMessage = toMessage(output);
        if (log.isDebugEnabled()) {
            log.debug("Will instrument the output message");
        }
        Span parentSpan = messageAndSpans != null ? messageAndSpans.parentSpan : functionSpan;
        MessageAndSpan wrappedOutput = wrapOutputMessage(outputMessage, parentSpan, functionContext);
        if (log.isDebugEnabled()) {
            log.debug("Wrapped output msg " + wrappedOutput);
        }
        wrappedOutput.span.end();
        functionContext.setModifiedOutput(wrappedOutput.msg);
    }

    /**
     * Returns {@code obj} itself when it already is a {@link Message}, otherwise wraps it
     * as the payload of a new message.
     */
    private Message<?> toMessage(Object obj) {
        if (obj instanceof Message) {
            return (Message<?>) obj;
        }
        return MessageBuilder.withPayload(obj).build();
    }

    /**
     * Extracts the tracing context from the message headers, records the consumer span,
     * starts its child span and rebuilds the message without the tracing headers.
     */
    private MessageAndSpans wrapInputMessage(FunctionContext functionContext, Message<?> message) {
        MessageHeaderAccessor headers = mutableHeaderAccessor(message);
        Span consumerSpan = consumerSpan(functionContext, this.propagator.extract(headers, this.getter));
        if (log.isDebugEnabled()) {
            log.debug("Built a consumer span " + consumerSpan);
        }
        Span childSpan = this.tracer.nextSpan(consumerSpan).name(functionContext.getContextualName()).start();
        clearTracingHeaders(headers);
        if (message instanceof ErrorMessage) {
            ErrorMessage errorMessage = (ErrorMessage) message;
            // Keep the original message, the same way outputMessage does for error messages.
            Message<?> rebuiltError = new ErrorMessage(errorMessage.getPayload(), headers.getMessageHeaders(), errorMessage.getOriginalMessage());
            return new MessageAndSpans(rebuiltError, consumerSpan, childSpan);
        }
        headers.setImmutable();
        Message<?> rebuilt = new GenericMessage<Object>(message.getPayload(), headers.getMessageHeaders());
        return new MessageAndSpans(rebuilt, consumerSpan, childSpan);
    }

    /**
     * Builds, tags and immediately ends the {@code CONSUMER} span named {@code "handle"}.
     * The ended span is returned so it can serve as a parent for later spans.
     */
    private Span consumerSpan(FunctionContext functionContext, Span.Builder builder) {
        builder.kind(Span.Kind.CONSUMER).name("handle");
        builder.remoteServiceName(REMOTE_SERVICE_NAME);
        Span consumerSpan = builder.start();
        tagSpan(functionContext, consumerSpan);
        consumerSpan.end();
        return consumerSpan;
    }

    /**
     * Returns the message's own accessor when it is mutable; otherwise a new mutable
     * accessor that is configured to stay mutable.
     */
    private MessageHeaderAccessor mutableHeaderAccessor(Message<?> message) {
        MessageHeaderAccessor existing = MessageHeaderAccessor.getAccessor(message, MessageHeaderAccessor.class);
        if (existing != null && existing.isMutable()) {
            return existing;
        }
        MessageHeaderAccessor created = MessageHeaderAccessor.getMutableAccessor(message);
        created.setLeaveMutable(true);
        return created;
    }

    /** Removes every header field known to the propagator from the accessor. */
    private void clearTracingHeaders(MessageHeaderAccessor messageHeaderAccessor) {
        MessageHeaderPropagatorSetter.removeHeaders(messageHeaderAccessor, this.propagator.fields());
    }

    /**
     * Creates the producer span as a child of {@code span}, injects its context into the
     * headers and builds the outgoing message.
     */
    private MessageAndSpan wrapOutputMessage(Message<?> message, Span span, FunctionContext functionContext) {
        Message<?> retrievedMessage = getMessage(message);
        MessageHeaderAccessor headers = mutableHeaderAccessor(retrievedMessage);
        Span.Builder producerSpanBuilder = this.tracer.spanBuilder().setParent(span.context());
        clearTracingHeaders(headers);
        Span producerSpan = createProducerSpan(functionContext, headers, producerSpanBuilder);
        this.propagator.inject(producerSpan.context(), headers, this.setter);
        if (log.isDebugEnabled()) {
            // Log the span that was actually created rather than its builder.
            log.debug("Created a new span output message " + producerSpan);
        }
        return new MessageAndSpan(outputMessage(message, retrievedMessage, headers), producerSpan);
    }

    /**
     * Returns the failed message carried by a {@link MessagingException} payload when
     * there is one, otherwise the message itself.
     */
    private Message<?> getMessage(Message<?> message) {
        Object payload = message.getPayload();
        if (payload instanceof MessagingException) {
            Message<?> failedMessage = ((MessagingException) payload).getFailedMessage();
            if (failedMessage != null) {
                return failedMessage;
            }
        }
        return message;
    }

    /**
     * Starts the {@code PRODUCER} span named {@code "send"} and tags it unless it is a
     * no-op span.
     */
    private Span createProducerSpan(FunctionContext functionContext, MessageHeaderAccessor messageHeaderAccessor, Span.Builder builder) {
        builder.kind(Span.Kind.PRODUCER).name("send").remoteServiceName(toRemoteServiceName(messageHeaderAccessor));
        Span producerSpan = builder.start();
        if (!producerSpan.isNoop()) {
            tagSpan(functionContext, producerSpan);
        }
        return producerSpan;
    }

    /**
     * Resolves the remote service name from the {@code MessageUtils.TARGET_PROTOCOL}
     * header, falling back to {@code "broker"} when the header is missing, empty or not
     * a {@link String}.
     */
    private String toRemoteServiceName(MessageHeaderAccessor messageHeaderAccessor) {
        Object targetProtocol = messageHeaderAccessor.getHeader(MessageUtils.TARGET_PROTOCOL);
        if (targetProtocol instanceof String && StringUtils.hasLength((String) targetProtocol)) {
            return (String) targetProtocol;
        }
        return REMOTE_SERVICE_NAME;
    }

    /**
     * Builds the outgoing message.
     *
     * <p>For a regular message all instrumented headers are copied and the payload of
     * {@code message2} is used. For an {@link ErrorMessage} only the propagator's header
     * fields are copied and the error payload and original message are kept.
     */
    private Message<?> outputMessage(Message<?> message, Message<?> message2, MessageHeaderAccessor messageHeaderAccessor) {
        MessageHeaderAccessor targetHeaders = mutableHeaderAccessor(message);
        if (message instanceof ErrorMessage) {
            ErrorMessage errorMessage = (ErrorMessage) message;
            targetHeaders.copyHeaders(MessageHeaderPropagatorSetter.copyHeaders(messageHeaderAccessor.getMessageHeaders(), this.propagator.fields()));
            return new ErrorMessage(errorMessage.getPayload(), resolveOutputHeaders(targetHeaders), errorMessage.getOriginalMessage());
        }
        targetHeaders.copyHeaders(messageHeaderAccessor.getMessageHeaders());
        return new GenericMessage<Object>(message2.getPayload(), resolveOutputHeaders(targetHeaders));
    }

    /**
     * Returns the accessor's headers as-is for WebSocket messages and a fresh
     * {@link MessageHeaders} copy otherwise.
     */
    private MessageHeaders resolveOutputHeaders(MessageHeaderAccessor accessor) {
        if (isWebSockets(accessor)) {
            return accessor.getMessageHeaders();
        }
        return new MessageHeaders(accessor.getMessageHeaders());
    }

    /** Checks for the {@code stompCommand} or {@code simpMessageType} header. */
    private boolean isWebSockets(MessageHeaderAccessor messageHeaderAccessor) {
        MessageHeaders headers = messageHeaderAccessor.getMessageHeaders();
        return headers.containsKey("stompCommand") || headers.containsKey("simpMessageType");
    }

    /**
     * Returns the tracer this handler creates spans with.
     *
     * @return the tracer passed to the constructor
     */
    public Tracer getTracer() {
        return this.tracer;
    }

    /**
     * Accepts only {@link FunctionContext} instances whose input is a {@link Message}.
     *
     * @param context the observation context to test
     * @return {@code true} when this handler should process the context
     */
    public boolean supportsContext(Observation.Context context) {
        return (context instanceof FunctionContext) && (((FunctionContext) context).getInput() instanceof Message);
    }
}
