package org.springframework.cloud.function.grpc;

import com.google.protobuf.GeneratedMessageV3;
import io.grpc.Status;
import io.grpc.stub.ServerCallStreamObserver;
import io.grpc.stub.StreamObserver;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.springframework.cloud.function.context.FunctionCatalog;
import org.springframework.cloud.function.context.FunctionProperties;
import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry;
import org.springframework.context.SmartLifecycle;
import org.springframework.messaging.Message;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Sinks;

/* loaded from: input_file:org/springframework/cloud/function/grpc/MessageHandlingHelper.class */
public class MessageHandlingHelper<T extends GeneratedMessageV3> implements SmartLifecycle {
    private final List<GrpcMessageConverter<?>> grpcConverters;
    private final FunctionProperties funcProperties;
    private final FunctionCatalog functionCatalog;
    private boolean running;
    private Log logger = LogFactory.getLog(MessageHandlingHelper.class);
    private final ExecutorService executor = Executors.newCachedThreadPool();

    public MessageHandlingHelper(List<GrpcMessageConverter<?>> list, FunctionCatalog functionCatalog, FunctionProperties functionProperties) {
        this.grpcConverters = list;
        this.funcProperties = functionProperties;
        this.functionCatalog = functionCatalog;
    }

    public void requestReply(T t, StreamObserver<T> streamObserver) {
        Message<byte[]> springMessage = toSpringMessage(t);
        streamObserver.onNext(toGrpcMessage((Message) resolveFunction(springMessage.getHeaders()).apply(springMessage), t.getClass()));
        streamObserver.onCompleted();
    }

    public void serverStream(T t, StreamObserver<T> streamObserver) {
        Message<byte[]> springMessage = toSpringMessage(t);
        Flux.from((Publisher) resolveFunction(springMessage.getHeaders()).apply(springMessage)).doOnNext(message -> {
            streamObserver.onNext(toGrpcMessage(message, t.getClass()));
        }).doOnComplete(() -> {
            streamObserver.onCompleted();
        }).subscribe();
    }

    public StreamObserver<T> clientStream(final StreamObserver<T> streamObserver, final Class<T> cls) {
        final ServerCallStreamObserver serverCallStreamObserver = (ServerCallStreamObserver) streamObserver;
        serverCallStreamObserver.disableAutoInboundFlowControl();
        SimpleFunctionRegistry.FunctionInvocationWrapper resolveFunction = resolveFunction(null);
        AtomicBoolean atomicBoolean = new AtomicBoolean(false);
        serverCallStreamObserver.setOnReadyHandler(() -> {
            if (!serverCallStreamObserver.isReady() || atomicBoolean.get()) {
                return;
            }
            atomicBoolean.set(true);
            this.logger.info("gRPC Server receiving stream is ready.");
            serverCallStreamObserver.request(1);
        });
        if (!resolveFunction.isInputTypePublisher()) {
            throw new UnsupportedOperationException("The client streaming is not supported for functions that accept non-Publisher: " + resolveFunction);
        }
        if (resolveFunction.isOutputTypePublisher()) {
            throw new UnsupportedOperationException("The client streaming is not supported for functions that return Publisher: " + resolveFunction);
        }
        final Sinks.Many onBackpressureBuffer = Sinks.many().unicast().onBackpressureBuffer();
        Flux asFlux = onBackpressureBuffer.asFlux();
        final LinkedBlockingQueue linkedBlockingQueue = new LinkedBlockingQueue(1);
        this.executor.execute(() -> {
            Message message = (Message) resolveFunction.apply(asFlux);
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("Function invocation reply: " + message);
            }
            linkedBlockingQueue.offer(message);
        });
        return (StreamObserver<T>) new StreamObserver<T>() { // from class: org.springframework.cloud.function.grpc.MessageHandlingHelper.1
            public void onNext(T t) {
                if (MessageHandlingHelper.this.logger.isDebugEnabled()) {
                    MessageHandlingHelper.this.logger.debug("gRPC Server receiving: " + t);
                }
                onBackpressureBuffer.tryEmitNext(MessageHandlingHelper.this.toSpringMessage(t));
                serverCallStreamObserver.request(1);
            }

            public void onError(Throwable th) {
                th.printStackTrace();
                streamObserver.onError(Status.UNKNOWN.withDescription("Error handling request").withCause(th).asRuntimeException());
            }

            public void onCompleted() {
                MessageHandlingHelper.this.logger.info("gRPC Server has finished receiving data.");
                onBackpressureBuffer.tryEmitComplete();
                try {
                    streamObserver.onNext(MessageHandlingHelper.this.toGrpcMessage((Message) linkedBlockingQueue.poll(2147483647L, TimeUnit.MILLISECONDS), cls));
                } catch (InterruptedException e) {
                    Thread.currentThread().interrupt();
                } finally {
                    streamObserver.onCompleted();
                }
            }
        };
    }

    public StreamObserver<T> biStream(StreamObserver<T> streamObserver, Class<T> cls) {
        ServerCallStreamObserver<T> serverCallStreamObserver = (ServerCallStreamObserver) streamObserver;
        serverCallStreamObserver.disableAutoInboundFlowControl();
        SimpleFunctionRegistry.FunctionInvocationWrapper resolveFunction = resolveFunction(null);
        AtomicBoolean atomicBoolean = new AtomicBoolean(false);
        serverCallStreamObserver.setOnReadyHandler(() -> {
            if (!serverCallStreamObserver.isReady() || atomicBoolean.get()) {
                return;
            }
            atomicBoolean.set(true);
            this.logger.info("gRPC Server receiving stream is ready.");
            serverCallStreamObserver.request(1);
        });
        if (resolveFunction.isInputTypePublisher()) {
            if (resolveFunction.isOutputTypePublisher()) {
                return biStreamReactive(streamObserver, serverCallStreamObserver, cls);
            }
            UnsupportedOperationException unsupportedOperationException = new UnsupportedOperationException("The bi-directional streaming is not supported for functions that accept Publisher but return non-Publisher: " + resolveFunction);
            streamObserver.onCompleted();
            throw unsupportedOperationException;
        }
        if (!resolveFunction.isOutputTypePublisher()) {
            return biStreamImperative(streamObserver, serverCallStreamObserver, atomicBoolean);
        }
        UnsupportedOperationException unsupportedOperationException2 = new UnsupportedOperationException("The bidirection streaming is not supported for functions that accept non-Publisher but return Publisher: " + resolveFunction);
        streamObserver.onCompleted();
        throw unsupportedOperationException2;
    }

    private StreamObserver<T> biStreamReactive(final StreamObserver<T> streamObserver, final ServerCallStreamObserver<T> serverCallStreamObserver, Class<T> cls) {
        final Sinks.Many onBackpressureBuffer = Sinks.many().unicast().onBackpressureBuffer();
        Flux.from((Publisher) resolveFunction(null).apply(onBackpressureBuffer.asFlux())).subscribe(message -> {
            T grpcMessage = toGrpcMessage(message, cls);
            if (this.logger.isDebugEnabled()) {
                this.logger.debug("gRPC Server replying: " + grpcMessage);
            }
            streamObserver.onNext(grpcMessage);
        });
        return (StreamObserver<T>) new StreamObserver<T>() { // from class: org.springframework.cloud.function.grpc.MessageHandlingHelper.2
            public void onNext(T t) {
                if (MessageHandlingHelper.this.logger.isDebugEnabled()) {
                    MessageHandlingHelper.this.logger.debug("gRPC Server receiving: " + t);
                }
                onBackpressureBuffer.tryEmitNext(MessageHandlingHelper.this.toSpringMessage(t));
                serverCallStreamObserver.request(1);
            }

            public void onError(Throwable th) {
                th.printStackTrace();
                onBackpressureBuffer.tryEmitComplete();
                streamObserver.onError(Status.UNKNOWN.withDescription("Error handling request").withCause(th).asException());
            }

            public void onCompleted() {
                MessageHandlingHelper.this.logger.info("gRPC Server has finished receiving data.");
                onBackpressureBuffer.tryEmitComplete();
                streamObserver.onCompleted();
            }
        };
    }

    private StreamObserver<T> biStreamImperative(final StreamObserver<T> streamObserver, final ServerCallStreamObserver<T> serverCallStreamObserver, final AtomicBoolean atomicBoolean) {
        return (StreamObserver<T>) new StreamObserver<T>() { // from class: org.springframework.cloud.function.grpc.MessageHandlingHelper.3
            public void onNext(T t) {
                try {
                    Message springMessage = MessageHandlingHelper.this.toSpringMessage(t);
                    streamObserver.onNext(MessageHandlingHelper.this.toGrpcMessage((Message) MessageHandlingHelper.this.resolveFunction(springMessage.getHeaders()).apply(springMessage), t.getClass()));
                    if (serverCallStreamObserver.isReady()) {
                        serverCallStreamObserver.request(1);
                    } else {
                        atomicBoolean.set(false);
                    }
                } catch (Throwable th) {
                    th.printStackTrace();
                    streamObserver.onError(Status.UNKNOWN.withDescription("Error handling request").withCause(th).asException());
                }
            }

            public void onError(Throwable th) {
                th.printStackTrace();
                streamObserver.onCompleted();
            }

            public void onCompleted() {
                MessageHandlingHelper.this.logger.info("gRPC Server has finished receiving data.");
                streamObserver.onCompleted();
            }
        };
    }

    /* JADX INFO: Access modifiers changed from: private */
    public T toGrpcMessage(Message<byte[]> message, Class<T> cls) {
        Iterator<GrpcMessageConverter<?>> it = this.grpcConverters.iterator();
        while (it.hasNext()) {
            T t = (T) it.next().fromSpringMessage(message, cls);
            if (t != null) {
                return t;
            }
        }
        throw new IllegalStateException("Failed to convert Grpc Message to Spring Message: " + message);
    }

    public void start() {
        this.running = true;
    }

    public void stop() {
        this.executor.shutdown();
        try {
            Assert.isTrue(this.executor.awaitTermination(5000L, TimeUnit.MILLISECONDS), "gRPC Server executor timed out while stopping, since there are currently executing tasks");
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
        this.running = false;
    }

    public boolean isRunning() {
        return this.running;
    }

    /* JADX INFO: Access modifiers changed from: private */
    public Message<byte[]> toSpringMessage(GeneratedMessageV3 generatedMessageV3) {
        Iterator<GrpcMessageConverter<?>> it = this.grpcConverters.iterator();
        while (it.hasNext()) {
            Message<byte[]> springMessage = it.next().toSpringMessage(generatedMessageV3);
            if (springMessage != null) {
                return springMessage;
            }
        }
        throw new IllegalStateException("Failed to convert Grpc Message to Spring Message: " + generatedMessageV3);
    }

    /* JADX INFO: Access modifiers changed from: private */
    public SimpleFunctionRegistry.FunctionInvocationWrapper resolveFunction(Map<String, Object> map) {
        String definition = this.funcProperties.getDefinition();
        if (!CollectionUtils.isEmpty(map) && map.containsKey("spring.cloud.function.definition")) {
            definition = (String) map.get("spring.cloud.function.definition");
        }
        SimpleFunctionRegistry.FunctionInvocationWrapper functionInvocationWrapper = (SimpleFunctionRegistry.FunctionInvocationWrapper) this.functionCatalog.lookup(definition, new String[]{"application/json"});
        Assert.notNull(functionInvocationWrapper, "Failed to lookup function " + this.funcProperties.getDefinition());
        return functionInvocationWrapper;
    }
}
