package org.springframework.cloud.gateway.rsocket.core;

import io.micrometer.core.instrument.MeterRegistry;
import io.micrometer.core.instrument.Tags;
import io.micrometer.core.instrument.Timer;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.ResponderRSocket;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;
import java.util.logging.Level;
import java.util.stream.Collectors;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.reactivestreams.Publisher;
import org.springframework.cloud.gateway.rsocket.autoconfigure.BrokerProperties;
import org.springframework.cloud.gateway.rsocket.common.metadata.TagsMetadata;
import org.springframework.cloud.gateway.rsocket.core.GatewayExchange;
import org.springframework.cloud.gateway.rsocket.route.Route;
import org.springframework.cloud.gateway.rsocket.route.Routes;
import org.springframework.cloud.gateway.rsocket.routing.LoadBalancerFactory;
import org.springframework.messaging.rsocket.MetadataExtractor;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.core.publisher.SignalType;

/* loaded from: input_file:org/springframework/cloud/gateway/rsocket/core/GatewayRSocket.class */
/**
 * Gateway-side RSocket that, for every incoming interaction, looks up a route for the
 * request, runs that route's filter chain, selects one or more destination RSockets through
 * the {@link LoadBalancerFactory} and forwards the request to each of them.
 *
 * <p>When no route is found, the request is handed to an RSocket obtained from the
 * {@link PendingRequestRSocketFactory} instead.
 *
 * <p>Metrics are recorded through the {@code count(...)} helpers and the
 * {@code meterRegistry} field, both of which are declared outside this class.
 */
public class GatewayRSocket extends AbstractGatewayRSocket {
    private static final Log log = LogFactory.getLog(GatewayRSocket.class);
    private final Routes routes;
    private final PendingRequestRSocketFactory pendingFactory;
    private final LoadBalancerFactory loadBalancerFactory;

    /* JADX INFO: Access modifiers changed from: package-private */
    /**
     * Creates the gateway RSocket.
     *
     * @param routes source of routes, queried once per request
     * @param pendingRequestRSocketFactory creates the fallback RSocket used when no route is found
     * @param loadBalancerFactory selects the destination RSocket(s) for a request
     * @param meterRegistry passed to the superclass constructor
     * @param brokerProperties passed to the superclass constructor
     * @param metadataExtractor passed to the superclass constructor
     * @param tagsMetadata passed to the superclass constructor
     */
    public GatewayRSocket(Routes routes, PendingRequestRSocketFactory pendingRequestRSocketFactory, LoadBalancerFactory loadBalancerFactory, MeterRegistry meterRegistry, BrokerProperties brokerProperties, MetadataExtractor metadataExtractor, TagsMetadata tagsMetadata) {
        super(meterRegistry, brokerProperties, metadataExtractor, tagsMetadata);
        this.routes = routes;
        this.pendingFactory = pendingRequestRSocketFactory;
        this.loadBalancerFactory = loadBalancerFactory;
    }

    /**
     * Returns the factory used to create the fallback RSocket.
     *
     * @return the factory supplied at construction time
     */
    protected PendingRequestRSocketFactory getPendingFactory() {
        return this.pendingFactory;
    }

    /**
     * Forwards a fire-and-forget request to every selected destination RSocket.
     *
     * @param payload the request payload
     * @return a {@code Mono} that completes once all forwarded requests have completed
     */
    public Mono<Void> fireAndForget(Payload payload) {
        GatewayExchange exchange = createExchange(GatewayExchange.Type.FIRE_AND_FORGET, payload);
        return findRSocketOrCreatePending(exchange).flatMap(rSockets -> {
            retain(payload, rSockets);
            List<Mono<Void>> results = rSockets.stream()
                    .map(rSocket -> rSocket.fireAndForget(payload))
                    .collect(Collectors.toList());
            return Flux.merge(results).then();
        }).doOnError(t -> count(exchange, "error"))
                .doFinally(s -> count(exchange, ""));
    }

    /**
     * Forwards a request-channel interaction to every selected destination RSocket and merges
     * their responses.
     *
     * <p>Payloads flowing from the requester are counted with the tag
     * {@code source=requester}; payloads flowing back are counted with {@code source=responder}.
     *
     * @param payload the first payload of the channel, used to build the exchange
     * @param publisher the stream of payloads sent by the requester
     * @return the merged stream of payloads emitted by the destination RSockets
     */
    public Flux<Payload> requestChannel(Payload payload, Publisher<Payload> publisher) {
        GatewayExchange exchange = createExchange(GatewayExchange.Type.REQUEST_CHANNEL, payload);
        Tags responderTags = Tags.of("source", "responder");
        return findRSocketOrCreatePending(exchange).flatMapMany(rSockets -> {
            Tags requesterTags = Tags.of("source", "requester");
            // The inbound stream is shared between all destinations: publish() multicasts it
            // and refCount(n) connects upstream once n subscribers have arrived.
            // NOTE(review): refCount presumably rejects a non-positive count, so an empty
            // rSockets list would fail here - confirm that the list is never empty.
            Flux<Payload> requesterPayloads = Flux.from(publisher)
                    .doOnNext(inbound -> {
                        retain(inbound, rSockets);
                        count(exchange, "payload", requesterTags);
                    })
                    .doOnError(t -> count(exchange, "error", requesterTags))
                    .doFinally(s -> count(exchange, requesterTags))
                    .publish()
                    .refCount(rSockets.size());
            List<Flux<Payload>> results = rSockets.stream()
                    .map(rSocket -> rSocket instanceof ResponderRSocket
                            ? ((ResponderRSocket) rSocket).requestChannel(payload, requesterPayloads)
                            : rSocket.requestChannel(requesterPayloads))
                    .collect(Collectors.toList());
            return Flux.merge(results)
                    .log(GatewayRSocket.class.getName() + ".request-channel", Level.FINEST);
        }).doOnNext(outbound -> count(exchange, "payload", responderTags))
                .doOnError(t -> count(exchange, "error", responderTags))
                .doFinally(s -> count(exchange, responderTags));
    }

    /**
     * Forwards a request-response interaction to every selected destination RSocket and
     * returns the first response that arrives.
     *
     * <p>The time between subscription and termination is recorded in a timer named by
     * {@code getMetricName(exchange)} and tagged with the exchange's tags.
     *
     * @param payload the request payload
     * @return the first response emitted by any destination RSocket
     */
    public Mono<Payload> requestResponse(Payload payload) {
        AtomicReference<Timer.Sample> timer = new AtomicReference<>();
        GatewayExchange exchange = createExchange(GatewayExchange.Type.REQUEST_RESPONSE, payload);
        return findRSocketOrCreatePending(exchange).flatMap(rSockets -> {
            retain(payload, rSockets);
            List<Mono<Payload>> results = rSockets.stream()
                    .map(rSocket -> rSocket.requestResponse(payload))
                    .collect(Collectors.toList());
            return Flux.merge(results).next();
        }).doOnSubscribe(s -> timer.set(Timer.start(this.meterRegistry)))
                .doOnError(t -> count(exchange, "error"))
                .doFinally(s -> timer.get().stop(
                        this.meterRegistry.timer(getMetricName(exchange), exchange.getTags())));
    }

    /**
     * Forwards a request-stream interaction to every selected destination RSocket and merges
     * their response streams.
     *
     * @param payload the request payload
     * @return the merged stream of payloads emitted by the destination RSockets
     */
    public Flux<Payload> requestStream(Payload payload) {
        GatewayExchange exchange = createExchange(GatewayExchange.Type.REQUEST_STREAM, payload);
        return findRSocketOrCreatePending(exchange).flatMapMany(rSockets -> {
            retain(payload, rSockets);
            List<Flux<Payload>> results = rSockets.stream()
                    .map(rSocket -> rSocket.requestStream(payload))
                    .collect(Collectors.toList());
            return Flux.merge(results);
        }).doOnNext(outbound -> count(exchange, "payload"))
                .doOnError(t -> count(exchange, "error"))
                .doFinally(s -> count(exchange, Tags.empty()));
    }

    /**
     * Retains the payload once for every destination beyond the first, because the same
     * payload instance is handed to each RSocket in the list (presumably each of them
     * releases it once - confirm against the RSocket implementations in use).
     *
     * @param payload the payload about to be forwarded
     * @param list the destination RSockets the payload will be sent to
     */
    private void retain(Payload payload, List<RSocket> list) {
        int extraReferences = list.size() - 1;
        if (extraReferences > 0) {
            payload.retain(extraReferences);
        }
    }

    /**
     * Finds the route for the exchange and resolves its destination RSockets, falling back to
     * {@link #createPending(GatewayExchange)} when no route is found.
     *
     * @param gatewayExchange the exchange describing the current request
     * @return the destination RSockets, or a single pending RSocket when no route matched
     */
    private Mono<List<RSocket>> findRSocketOrCreatePending(GatewayExchange gatewayExchange) {
        return this.routes.findRoute(gatewayExchange)
                .log(GatewayRSocket.class.getName() + ".find route", Level.FINEST)
                .flatMap(route -> {
                    // Store the route on the exchange so later stages can read it.
                    gatewayExchange.getAttributes().put(GatewayExchange.ROUTE_ATTR, route);
                    return findRSocketOrCreatePending(gatewayExchange, route);
                })
                // Deferred so the fallback is only built (and its debug message only logged)
                // when the route lookup really completes empty; an un-deferred argument would
                // be evaluated while the pipeline is assembled, i.e. for every request.
                .switchIfEmpty(Mono.defer(() -> createPending(gatewayExchange)));
    }

    /**
     * Runs the route's filter chain and then asks the load balancer for destination RSockets.
     *
     * <p>When the routing metadata carries a {@code multicast} tag, every RSocket returned by
     * {@code LoadBalancerFactory.find} is used; otherwise the single result of
     * {@code LoadBalancerFactory.choose} is used.
     *
     * @param gatewayExchange the exchange describing the current request
     * @param route the route matched for the exchange
     * @return the list of selected RSockets (empty when the load balancer yields none)
     */
    private Mono<List<RSocket>> findRSocketOrCreatePending(GatewayExchange gatewayExchange, Route route) {
        return GatewayFilterChain.executeFilterChain(route.getFilters(), gatewayExchange)
                .log(GatewayRSocket.class.getName() + ".after filter chain", Level.FINEST)
                .flatMapMany(success -> {
                    if (gatewayExchange.getRoutingMetadata().getTags().containsKey(new TagsMetadata.Key("multicast"))) {
                        return Flux.fromIterable(this.loadBalancerFactory.find(gatewayExchange.getRoutingMetadata()));
                    }
                    return this.loadBalancerFactory.choose(gatewayExchange.getRoutingMetadata())
                            .flatMapMany(tuple -> Flux.just(tuple));
                })
                // The second tuple element is the RSocket itself.
                .map(tuple -> tuple.getT2())
                .cast(RSocket.class)
                .doOnNext(rSocket -> {
                    if (log.isDebugEnabled()) {
                        log.debug("Found RSocket: " + rSocket);
                    }
                })
                .collectList()
                .log(GatewayRSocket.class.getName() + ".find rsocket", Level.FINEST);
    }

    /**
     * Creates the fallback RSocket used when no route is found for the exchange.
     *
     * @param gatewayExchange the exchange describing the current request
     * @return a single-element list holding the RSocket produced by the pending factory
     */
    protected Mono<List<RSocket>> createPending(GatewayExchange gatewayExchange) {
        if (log.isDebugEnabled()) {
            log.debug("Unable to find destination RSocket for " + gatewayExchange.getRoutingMetadata());
        }
        return this.pendingFactory.create(gatewayExchange)
                .cast(RSocket.class)
                .map(Collections::singletonList);
    }
}
