package org.springframework.yarn.am.allocate;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.yarn.api.protocolrecords.AllocateRequest;
import org.apache.hadoop.yarn.api.protocolrecords.AllocateResponse;
import org.apache.hadoop.yarn.api.records.Container;
import org.apache.hadoop.yarn.api.records.ContainerId;
import org.apache.hadoop.yarn.api.records.ContainerStatus;
import org.apache.hadoop.yarn.api.records.Priority;
import org.apache.hadoop.yarn.api.records.Resource;
import org.apache.hadoop.yarn.api.records.ResourceRequest;
import org.apache.hadoop.yarn.util.Records;
import org.springframework.util.StringUtils;
import org.springframework.yarn.am.allocate.AllocationGroup;
import org.springframework.yarn.am.allocate.DefaultAllocateCountTracker;
import org.springframework.yarn.listener.CompositeContainerAllocatorListener;
import org.springframework.yarn.listener.ContainerAllocatorListener;
import org.springframework.yarn.support.compat.ResourceCompat;

/* loaded from: input_file:BOOT-INF/lib/spring-yarn-core-2.4.0.RELEASE.jar:org/springframework/yarn/am/allocate/DefaultContainerAllocator.class */
public class DefaultContainerAllocator extends AbstractPollingAllocator implements ContainerAllocator {

    private static final Log log = LogFactory.getLog(DefaultContainerAllocator.class);

    /** Tracker id used for requests that may be placed on any node. */
    private static final String TRACKER_ANY = "any";

    /** Tracker id used for requests that target specific hosts. */
    private static final String TRACKER_HOST = "host";

    /** Node label expression applied to the default allocation group. */
    private String labelExpression;

    /** Composite listener notified about allocated and completed containers. */
    private final CompositeContainerAllocatorListener allocatorListener = new CompositeContainerAllocatorListener();

    /** Priority of the default allocation group. */
    private int priority = 0;

    /** Virtual cores requested per container by the default allocation group. */
    private int virtualcores = 1;

    /** Memory requested per container by the default allocation group (passed to Resource#setMemory). */
    private int memory = 64;

    /**
     * Locality flag of the default allocation group. When true, rack requests that
     * accompany host requests, and non-"any" any-requests, are sent with relaxed
     * locality disabled (see createRequests).
     */
    private boolean locality = false;

    /** Response id returned by the last allocate call, echoed on the next request. */
    private final AtomicInteger requestId = new AtomicInteger();

    /**
     * Progress reported to the resource manager. Volatile because setProgress is a
     * public method while doContainerRequest is driven by the poller; the two may
     * run on different threads.
     */
    private volatile float applicationProgress = 0.0f;

    /** Container ids to send in the release list of the next allocate request. */
    private final Queue<ContainerId> releaseContainers = new ConcurrentLinkedQueue<ContainerId>();

    /**
     * Ids of allocated containers that were not wanted and were released again.
     * Their completion events are not forwarded to listeners; an id is dropped from
     * this set once its completion has been seen.
     * NOTE(review): only touched from the polling callbacks — confirm they share a thread.
     */
    private final Set<ContainerId> garbageContainers = new HashSet<ContainerId>();

    /** Set whenever tracker state changed and the ask list has to be rebuilt. */
    private final AtomicBoolean allocationDirty = new AtomicBoolean();

    /** Ask list sent when nothing changed since the previous allocate call. */
    private final List<ResourceRequest> EMPTY = new ArrayList<ResourceRequest>();

    /** Allocation groups, looked up either by group id or by container priority. */
    private final AllocationGroups allocationGroups = new AllocationGroups();

    /**
     * Initializes the parent allocator, then registers the default allocation group
     * built from this bean's priority/label/cores/memory/locality properties.
     */
    @Override
    protected void onInit() throws Exception {
        super.onInit();
        internalInit();
    }

    /**
     * Requests the given number of containers with no host or rack preference.
     *
     * @param count number of containers to request
     */
    @Override
    public void allocateContainers(int count) {
        if (log.isDebugEnabled()) {
            log.debug("Incoming count: " + count);
        }
        ContainerAllocateData data = new ContainerAllocateData();
        data.addAny(count);
        allocateContainers(data);
    }

    /**
     * Registers a listener that is notified about allocated and completed containers.
     *
     * @param containerAllocatorListener the listener to add
     */
    @Override
    public void addListener(ContainerAllocatorListener containerAllocatorListener) {
        this.allocatorListener.register(containerAllocatorListener);
    }

    /**
     * Adds container requests to the count trackers of the matching allocation group.
     * Requests for an unknown group id fall back to the default ("") group. The
     * any/host/rack parts of the data are routed to their own trackers, which are
     * created on first use. If anything was added, the ask list is rebuilt on the next poll.
     *
     * @param containerAllocateData the requested containers, optionally with a group id
     */
    @Override
    public void allocateContainers(ContainerAllocateData containerAllocateData) {
        log.info("Incoming containerAllocateData: " + containerAllocateData);
        String id = StringUtils.hasText(containerAllocateData.getId()) ? containerAllocateData.getId() : "";
        AllocationGroup allocationGroup = this.allocationGroups.get(id);
        if (allocationGroup == null) {
            allocationGroup = this.allocationGroups.get("");
        }
        // Non-short-circuit OR: every category must be processed, in this order.
        boolean added = addToTracker(id, allocationGroup, TRACKER_ANY, containerAllocateData.byAny());
        added |= addToTracker(id, allocationGroup, TRACKER_HOST, containerAllocateData.byHosts());
        added |= addToTracker(id, allocationGroup, AllocationGroup.GROUP_RACK, containerAllocateData.byRacks());
        if (added) {
            this.allocationDirty.set(true);
        }
    }

    /**
     * Adds one category of requests to its tracker in the given group, creating and
     * reserving the tracker if it does not exist yet.
     *
     * @return true if {@code data} had anything to add
     */
    private boolean addToTracker(String id, AllocationGroup allocationGroup, String trackerId,
            ContainerAllocateData data) {
        if (!data.hasData()) {
            return false;
        }
        DefaultAllocateCountTracker tracker = allocationGroup.getAllocateCountTracker(trackerId);
        if (tracker == null) {
            this.allocationGroups.reserve(id, trackerId);
            tracker = new DefaultAllocateCountTracker(trackerId, getConfiguration());
            allocationGroup.setAllocateCountTracker(trackerId, tracker);
        }
        log.info("State allocateCountTracker before adding allocation data: " + tracker);
        tracker.addContainers(data);
        log.info("State allocateCountTracker after adding allocation data: " + tracker);
        return true;
    }

    /**
     * Queues every given container for release on the next allocate call.
     *
     * @param list containers to release
     */
    @Override
    public void releaseContainers(List<Container> list) {
        for (Container container : list) {
            releaseContainer(container.getId());
        }
    }

    /**
     * Queues a single container for release on the next allocate call.
     *
     * @param containerId id of the container to release
     */
    @Override
    public void releaseContainer(ContainerId containerId) {
        log.info("Adding new container to be released containerId=" + containerId);
        this.releaseContainers.add(containerId);
    }

    /**
     * Builds and sends one allocate request. The ask list is rebuilt only when the
     * tracker state is dirty. Queued releases are drained into the release list, and
     * the current progress is attached.
     *
     * If the call to the resource manager fails, the drained release ids are queued
     * again and the dirty flag is restored, so the next poll retries them instead of
     * silently dropping them. The exception is then rethrown unchanged.
     *
     * @return the resource manager's response
     */
    @Override
    protected AllocateResponse doContainerRequest() {
        boolean dirty = this.allocationDirty.getAndSet(false);
        List<ResourceRequest> asks = dirty ? createRequests() : this.EMPTY;

        List<ContainerId> released = new ArrayList<ContainerId>();
        for (ContainerId id = this.releaseContainers.poll(); id != null; id = this.releaseContainers.poll()) {
            released.add(id);
        }

        if (log.isDebugEnabled()) {
            log.debug("Requesting containers using " + asks.size() + " requests.");
            for (ResourceRequest resourceRequest : asks) {
                log.debug("ResourceRequest: " + resourceRequest + " with count=" + resourceRequest.getNumContainers() + " with hostName=" + resourceRequest.getResourceName());
            }
            log.debug("Releasing containers " + released.size());
            for (ContainerId id : released) {
                log.debug("Release container=" + id);
            }
            log.debug("Request id will be: " + this.requestId.get());
        }

        AllocateRequest allocateRequest = (AllocateRequest) Records.newRecord(AllocateRequest.class);
        allocateRequest.setResponseId(this.requestId.get());
        allocateRequest.setAskList(asks);
        allocateRequest.setReleaseList(released);
        allocateRequest.setProgress(this.applicationProgress);

        AllocateResponse response;
        try {
            response = getRmTemplate().allocate(allocateRequest);
        } catch (RuntimeException e) {
            // The request did not go through: keep the releases and the pending ask
            // for the next poll. NOTE(review): resending assumes createRequests() may
            // safely be called again for the same tracker state — confirm in
            // DefaultAllocateCountTracker#getAllocateCounts.
            this.releaseContainers.addAll(released);
            if (dirty) {
                this.allocationDirty.set(true);
            }
            throw e;
        }
        this.requestId.set(response.getResponseId());
        return response;
    }

    /**
     * Matches newly allocated containers against the tracker of their priority.
     * Containers the tracker accepts are returned. Rejected containers, and containers
     * whose priority has no known group or tracker, are marked as garbage and queued
     * for release.
     *
     * @param list containers allocated by the resource manager
     * @return the containers that should be handed to listeners
     */
    @Override
    protected List<Container> preProcessAllocatedContainers(List<Container> list) {
        List<Container> accepted = new ArrayList<Container>();
        for (Container container : list) {
            Integer containerPriority = Integer.valueOf(container.getPriority().getPriority());
            AllocationGroup allocationGroup = this.allocationGroups.get(containerPriority);
            DefaultAllocateCountTracker tracker = allocationGroup != null
                    ? allocationGroup.getAllocateCountTracker(containerPriority)
                    : null;
            if (tracker == null) {
                // Previously an NPE here aborted the whole batch; treat it like any
                // other unwanted container instead.
                log.warn("No allocation tracker for priority " + containerPriority
                        + ", releasing container " + container.getId());
                this.garbageContainers.add(container.getId());
                this.releaseContainers.add(container.getId());
                continue;
            }
            if (log.isDebugEnabled()) {
                log.debug("State allocateCountTracker before handling allocated container: " + tracker);
            }
            Container processed = tracker.processAllocatedContainer(container);
            if (processed != null) {
                accepted.add(processed);
            } else {
                this.garbageContainers.add(container.getId());
                this.releaseContainers.add(container.getId());
            }
            if (log.isDebugEnabled()) {
                log.debug("State allocateCountTracker after handling allocated container: " + tracker);
            }
        }
        return accepted;
    }

    /**
     * Forwards allocated containers to the registered listeners.
     *
     * @param list containers that passed pre-processing
     */
    @Override
    protected void handleAllocatedContainers(List<Container> list) {
        this.allocatorListener.allocated(list);
    }

    /**
     * Forwards completed container statuses to listeners, except for containers this
     * allocator released as garbage. Those ids are removed from the garbage set when
     * seen, so the set does not grow for the lifetime of the application.
     * NOTE(review): assumes a completion status is reported once per container.
     *
     * @param list completed container statuses from the resource manager
     */
    @Override
    protected void handleCompletedContainers(List<ContainerStatus> list) {
        List<ContainerStatus> forwarded = new ArrayList<ContainerStatus>();
        for (ContainerStatus status : list) {
            if (!this.garbageContainers.remove(status.getContainerId())) {
                forwarded.add(status);
            }
        }
        this.allocatorListener.completed(forwarded);
    }

    /**
     * Sets the progress value sent with subsequent allocate requests.
     *
     * @param f application progress
     */
    @Override
    public void setProgress(float f) {
        this.applicationProgress = f;
    }

    /**
     * Registers (or updates) the allocation group {@code str} with the given container
     * allocation values. A blank id denotes the default group.
     *
     * @param str group id, blank for the default group
     * @param num priority of the group
     * @param str2 node label expression
     * @param num2 virtual cores per container
     * @param num3 memory per container
     * @param bool locality flag
     */
    public void setAllocationValues(String str, Integer num, String str2, Integer num2, Integer num3, Boolean bool) {
        if (log.isTraceEnabled()) {
            log.trace("setAllocationValues 1: id=" + str + " priority=" + num + " labelExpression=" + str2 + " cores=" + num2 + " memory=" + num3 + " locality=" + bool);
        }
        String groupId = StringUtils.hasText(str) ? str : "";
        this.allocationGroups.add(groupId, num)
                .setContainerAllocationValues(new AllocationGroup.ContainerAllocationValues(num, str2, num2, num3, bool));
    }

    /** @return priority of the default allocation group */
    public int getPriority() {
        return this.priority;
    }

    /** @param i priority of the default allocation group */
    public void setPriority(int i) {
        this.priority = i;
    }

    /** @param str node label expression of the default allocation group */
    public void setLabelExpression(String str) {
        this.labelExpression = str;
    }

    /** @return virtual cores per container of the default allocation group */
    public int getVirtualcores() {
        return this.virtualcores;
    }

    /** @param i virtual cores per container of the default allocation group */
    public void setVirtualcores(int i) {
        this.virtualcores = i;
    }

    /** @return memory per container of the default allocation group */
    public int getMemory() {
        return this.memory;
    }

    /** @param i memory per container of the default allocation group */
    public void setMemory(int i) {
        this.memory = i;
    }

    /** @return locality flag of the default allocation group */
    public boolean isLocality() {
        return this.locality;
    }

    /** @param z locality flag of the default allocation group */
    public void setLocality(boolean z) {
        this.locality = z;
    }

    /**
     * Builds the ask list from every tracker of every allocation group. A tracker
     * without an explicit priority uses its group's priority. Host requests always
     * relax locality. A rack request relaxes locality unless the same tracker also
     * produced host requests and locality is enforced. An any-request relaxes locality
     * if it comes from the "any" tracker or locality is not enforced.
     */
    private List<ResourceRequest> createRequests() {
        List<ResourceRequest> requests = new ArrayList<ResourceRequest>();
        for (AllocationGroup allocationGroup : this.allocationGroups.getGroups()) {
            AllocationGroup.ContainerAllocationValues values = allocationGroup.getContainerAllocationValues();
            for (DefaultAllocateCountTracker tracker : allocationGroup.getAllocateCountTrackers()) {
                DefaultAllocateCountTracker.AllocateCountInfo counts = tracker.getAllocateCounts();
                Integer trackerPriority = allocationGroup.getPriority(tracker.getId());
                int requestPriority = trackerPriority != null ? trackerPriority.intValue() : values.priority;
                if (log.isTraceEnabled()) {
                    log.trace("trace 1 " + values.locality);
                    log.trace("trace 2 tracker id:" + tracker.getId());
                }
                boolean hasHostRequests = false;
                for (Map.Entry<String, Integer> entry : counts.hostsInfo.entrySet()) {
                    if (log.isTraceEnabled()) {
                        log.trace("trace 3 entry key=" + entry.getKey() + " value=" + entry.getValue());
                    }
                    requests.add(getContainerResourceRequest(values, requestPriority, entry.getValue().intValue(), entry.getKey(), true));
                    hasHostRequests = true;
                }
                for (Map.Entry<String, Integer> entry : counts.racksInfo.entrySet()) {
                    if (log.isTraceEnabled()) {
                        log.trace("trace 4 entry key=" + entry.getKey() + " value=" + entry.getValue());
                    }
                    boolean relaxRack = !(hasHostRequests && values.locality);
                    requests.add(getContainerResourceRequest(values, requestPriority, entry.getValue().intValue(), entry.getKey(), relaxRack));
                }
                for (Map.Entry<String, Integer> entry : counts.anysInfo.entrySet()) {
                    if (log.isTraceEnabled()) {
                        log.trace("trace 5 entry key=" + entry.getKey() + " value=" + entry.getValue());
                    }
                    boolean relaxAny = TRACKER_ANY.equals(tracker.getId()) || !values.locality;
                    requests.add(getContainerResourceRequest(values, requestPriority, entry.getValue().intValue(), entry.getKey(), relaxAny));
                }
            }
        }
        return requests;
    }

    /**
     * Registers the default allocation group from this bean's properties and pushes
     * the current configuration into every existing count tracker.
     */
    private void internalInit() {
        setAllocationValues(null, Integer.valueOf(this.priority), this.labelExpression, Integer.valueOf(this.virtualcores), Integer.valueOf(this.memory), Boolean.valueOf(this.locality));
        for (DefaultAllocateCountTracker tracker : this.allocationGroups.getAllocateCountTrackers()) {
            tracker.setConfiguration(getConfiguration());
        }
    }

    /**
     * Creates a single YARN resource request.
     *
     * @param values capability and label settings of the allocation group
     * @param requestPriority request priority
     * @param numContainers number of containers asked for
     * @param resourceName host, rack or "any" resource name
     * @param relaxLocality whether the scheduler may relax locality for this request
     * @return the populated request
     */
    private ResourceRequest getContainerResourceRequest(AllocationGroup.ContainerAllocationValues values,
            int requestPriority, int numContainers, String resourceName, boolean relaxLocality) {
        ResourceRequest resourceRequest = (ResourceRequest) Records.newRecord(ResourceRequest.class);
        resourceRequest.setRelaxLocality(relaxLocality);
        resourceRequest.setResourceName(resourceName);
        resourceRequest.setNumContainers(numContainers);
        resourceRequest.setNodeLabelExpression(values.labelExpression);
        Priority yarnPriority = (Priority) Records.newRecord(Priority.class);
        yarnPriority.setPriority(requestPriority);
        resourceRequest.setPriority(yarnPriority);
        Resource capability = (Resource) Records.newRecord(Resource.class);
        capability.setMemory(values.memory);
        ResourceCompat.setVirtualCores(capability, values.virtualcores);
        resourceRequest.setCapability(capability);
        return resourceRequest;
    }
}
