package org.springframework.batch.integration.partition;

import ch.qos.logback.core.spi.AbstractComponentTracker;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Set;
import java.util.concurrent.Callable;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import javax.sql.DataSource;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.batch.core.StepExecution;
import org.springframework.batch.core.explore.JobExplorer;
import org.springframework.batch.core.explore.support.JobExplorerFactoryBean;
import org.springframework.batch.core.partition.PartitionHandler;
import org.springframework.batch.core.partition.StepExecutionSplitter;
import org.springframework.batch.poller.DirectPoller;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.integration.MessageTimeoutException;
import org.springframework.integration.annotation.Aggregator;
import org.springframework.integration.annotation.MessageEndpoint;
import org.springframework.integration.annotation.Payloads;
import org.springframework.integration.channel.QueueChannel;
import org.springframework.integration.core.MessagingTemplate;
import org.springframework.integration.support.MessageBuilder;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageChannel;
import org.springframework.messaging.PollableChannel;
import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils;

/**
 * A {@link PartitionHandler} that sends one {@link StepExecutionRequest} message per partition through a
 * {@link MessagingTemplate} and then collects the finished partition {@link StepExecution}s.
 * <p>
 * Results are gathered in one of two ways, decided in {@link #afterPropertiesSet()}:
 * <ul>
 * <li>If neither a {@link DataSource} nor a {@link JobExplorer} is configured, the handler blocks on a single
 * reply from the reply channel (the one set via {@link #setReplyChannel(PollableChannel)}, or a new
 * {@link QueueChannel} created per {@link #handle} call) and treats its payload as the collection of results.
 * NOTE(review): the reply is presumably produced by {@link #aggregate(List)} downstream — confirm in config.</li>
 * <li>Otherwise the handler polls the job repository through the {@link JobExplorer} every
 * {@code pollInterval} milliseconds until every partition is no longer running.</li>
 * </ul>
 */
@MessageEndpoint
public class MessageChannelPartitionHandler implements PartitionHandler, InitializingBean {

    private static final Log logger = LogFactory.getLog(MessageChannelPartitionHandler.class);

    /** Default delay in milliseconds between job-repository polls (10 seconds). */
    private static final long DEFAULT_POLL_INTERVAL = 10000L;

    private MessagingTemplate messagingGateway;
    private String stepName;
    private JobExplorer jobExplorer;
    private DataSource dataSource;
    private PollableChannel replyChannel;
    private int gridSize = 1;
    private long pollInterval = DEFAULT_POLL_INTERVAL;
    private boolean pollRepositoryForResults = false;
    /** Maximum time in milliseconds to wait when polling the repository; negative means wait forever. */
    private long timeout = -1;

    /**
     * Validates mandatory properties and selects the result-collection strategy.
     * <p>
     * Repository polling is enabled when a {@link DataSource} or a {@link JobExplorer} is set. If only a
     * {@link DataSource} is set, a {@link JobExplorer} is built from it with {@link JobExplorerFactoryBean}.
     *
     * @throws IllegalArgumentException if no step name was set
     * @throws IllegalStateException if no messaging template was set
     * @throws Exception if the {@link JobExplorerFactoryBean} fails to initialize
     */
    @Override
    public void afterPropertiesSet() throws Exception {
        Assert.notNull(this.stepName, "A step name must be provided for the remote workers.");
        Assert.state(this.messagingGateway != null, "The MessagingOperations must be set");

        this.pollRepositoryForResults = this.dataSource != null || this.jobExplorer != null;
        if (this.pollRepositoryForResults) {
            logger.debug("MessageChannelPartitionHandler is configured to poll the job repository for slave results");
        }

        // An explicitly injected JobExplorer wins; only derive one from the DataSource when none was given.
        if (this.dataSource != null && this.jobExplorer == null) {
            JobExplorerFactoryBean jobExplorerFactoryBean = new JobExplorerFactoryBean();
            jobExplorerFactoryBean.setDataSource(this.dataSource);
            jobExplorerFactoryBean.afterPropertiesSet();
            this.jobExplorer = jobExplorerFactoryBean.getObject();
        }
    }

    /**
     * Sets how long (ms) to wait for all partitions when polling the repository. Negative (the default)
     * waits indefinitely. This value is not used when waiting on the reply channel.
     *
     * @param timeout timeout in milliseconds
     */
    public void setTimeout(long timeout) {
        this.timeout = timeout;
    }

    /**
     * Sets the {@link JobExplorer} used to poll partition status; setting it enables repository polling.
     *
     * @param jobExplorer explorer over the job repository
     */
    public void setJobExplorer(JobExplorer jobExplorer) {
        this.jobExplorer = jobExplorer;
    }

    /**
     * Sets the delay (ms) between repository polls. Defaults to 10000.
     *
     * @param pollInterval poll interval in milliseconds
     */
    public void setPollInterval(long pollInterval) {
        this.pollInterval = pollInterval;
    }

    /**
     * Sets a {@link DataSource} from which a {@link JobExplorer} is built (unless one is set explicitly);
     * setting it enables repository polling.
     *
     * @param dataSource data source of the job repository
     */
    public void setDataSource(DataSource dataSource) {
        this.dataSource = dataSource;
    }

    /**
     * Sets the template used to send partition requests and to receive the aggregated reply. Required.
     *
     * @param messagingTemplate the messaging template
     */
    public void setMessagingOperations(MessagingTemplate messagingTemplate) {
        this.messagingGateway = messagingTemplate;
    }

    /**
     * Sets the grid size hint passed to the {@link StepExecutionSplitter}. Defaults to 1.
     *
     * @param gridSize requested number of partitions
     */
    public void setGridSize(int gridSize) {
        this.gridSize = gridSize;
    }

    /**
     * Sets the name of the step the remote workers should execute. Required.
     *
     * @param stepName remote step name
     */
    public void setStepName(String stepName) {
        this.stepName = stepName;
    }

    /**
     * Aggregator endpoint that simply returns the collected reply payloads as one list. Partial results are
     * released when the group expires.
     *
     * @param messages payloads of the reply messages in the group
     * @return the same list, unchanged
     */
    @Aggregator(sendPartialResultsOnExpiry = true)
    public List<?> aggregate(@Payloads List<?> messages) {
        return messages;
    }

    /**
     * Sets the channel on which the aggregated reply is received. If unset (and repository polling is off),
     * a fresh {@link QueueChannel} is created for every {@link #handle} call.
     *
     * @param replyChannel the reply channel
     */
    public void setReplyChannel(PollableChannel replyChannel) {
        this.replyChannel = replyChannel;
    }

    /**
     * Splits the master step into partitions, sends one request message per partition, and waits for the
     * results using the strategy chosen in {@link #afterPropertiesSet()}.
     *
     * @param stepExecutionSplitter splitter producing the partition step executions
     * @param masterStepExecution the master (partitioned) step execution
     * @return the finished partition step executions, or {@code null} when the splitter produced no
     *         partitions (kept as {@code null} for compatibility with existing callers)
     * @throws MessageTimeoutException if no aggregated reply arrives on the reply channel
     * @throws java.util.concurrent.TimeoutException if repository polling exceeds {@code timeout}
     * @throws Exception propagated from the splitter, the messaging layer, or the poller
     */
    @Override
    public Collection<StepExecution> handle(StepExecutionSplitter stepExecutionSplitter,
            StepExecution masterStepExecution) throws Exception {
        Set<StepExecution> split = stepExecutionSplitter.split(masterStepExecution, this.gridSize);
        if (CollectionUtils.isEmpty(split)) {
            return null;
        }

        PollableChannel replyTo = this.replyChannel;
        if (!this.pollRepositoryForResults && replyTo == null) {
            replyTo = new QueueChannel();
        }

        // Sequence numbers are 0-based here, as in the original implementation.
        int sequenceNumber = 0;
        for (StepExecution partition : split) {
            StepExecutionRequest request =
                    new StepExecutionRequest(this.stepName, partition.getJobExecutionId(), partition.getId());
            Message<StepExecutionRequest> message = createMessage(sequenceNumber++, split.size(), request, replyTo);
            if (logger.isDebugEnabled()) {
                logger.debug("Sending request: " + message);
            }
            this.messagingGateway.send(message);
        }

        if (this.pollRepositoryForResults) {
            return pollReplies(masterStepExecution, split);
        }
        return receiveReplies(replyTo);
    }

    /**
     * Polls the job repository until every partition in {@code split} is no longer running.
     *
     * @param masterStepExecution master step execution (its job execution id is used for lookups)
     * @param split the partition step executions that were dispatched
     * @return the refreshed, finished partition step executions
     * @throws Exception if lookup fails, the wait is interrupted, or {@code timeout} elapses
     */
    private Collection<StepExecution> pollReplies(final StepExecution masterStepExecution,
            final Set<StepExecution> split) throws Exception {
        final Collection<StepExecution> finished = new ArrayList<StepExecution>(split.size());

        Callable<Collection<StepExecution>> callback = new Callable<Collection<StepExecution>>() {
            @Override
            public Collection<StepExecution> call() throws Exception {
                for (StepExecution partition : split) {
                    if (finished.contains(partition)) {
                        continue; // already observed as finished on an earlier poll
                    }
                    StepExecution latest = jobExplorer.getStepExecution(
                            masterStepExecution.getJobExecutionId(), partition.getId());
                    if (!latest.getStatus().isRunning()) {
                        finished.add(latest);
                    }
                }

                if (logger.isDebugEnabled()) {
                    // Report the partitions still outstanding, not the total.
                    logger.debug(String.format("Currently waiting on %s partitions to finish",
                            split.size() - finished.size()));
                }

                // Returning null tells the poller to try again after pollInterval.
                return finished.size() == split.size() ? finished : null;
            }
        };

        Future<Collection<StepExecution>> results =
                new DirectPoller<Collection<StepExecution>>(this.pollInterval).poll(callback);
        if (this.timeout >= 0) {
            return results.get(this.timeout, TimeUnit.MILLISECONDS);
        }
        return results.get();
    }

    /**
     * Receives the single aggregated reply from {@code replyTo} using the template's receive timeout.
     *
     * @param replyTo channel to receive from
     * @return the reply payload, assumed to be the collection of finished partition step executions
     * @throws MessageTimeoutException if the template returns no message
     */
    @SuppressWarnings("unchecked") // payload type is fixed by the aggregation flow, not checkable at runtime
    private Collection<StepExecution> receiveReplies(PollableChannel replyTo) {
        Message<?> reply = this.messagingGateway.receive(replyTo);
        if (reply == null) {
            throw new MessageTimeoutException("Timeout occurred before all partitions returned");
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Received replies: " + reply);
        }
        return (Collection<StepExecution>) reply.getPayload();
    }

    /**
     * Builds the request message for one partition. All partitions of one master step share the correlation
     * id {@code "<jobExecutionId>:<stepName>"} so replies can be grouped.
     *
     * @param sequenceNumber position of this partition in the batch (0-based)
     * @param sequenceSize total number of partitions sent
     * @param stepExecutionRequest the request payload
     * @param replyChannel channel for replies; may be {@code null} when polling the repository
     * @return the message to send
     */
    private Message<StepExecutionRequest> createMessage(int sequenceNumber, int sequenceSize,
            StepExecutionRequest stepExecutionRequest, PollableChannel replyChannel) {
        return MessageBuilder.withPayload(stepExecutionRequest)
                .setSequenceNumber(sequenceNumber)
                .setSequenceSize(sequenceSize)
                .setCorrelationId(stepExecutionRequest.getJobExecutionId() + ":" + stepExecutionRequest.getStepName())
                .setReplyChannel(replyChannel)
                .build();
    }
}
