package io.trino.execution.scheduler;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Suppliers;
import com.google.common.base.Ticker;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.SetMultimap;
import com.google.common.util.concurrent.ListenableFuture;
import io.airlift.log.Logger;
import io.trino.execution.NodeTaskMap;
import io.trino.execution.RemoteTask;
import io.trino.execution.resourcegroups.IndexedPriorityQueue;
import io.trino.execution.scheduler.NodeSchedulerConfig;
import io.trino.metadata.InternalNode;
import io.trino.metadata.InternalNodeManager;
import io.trino.metadata.Split;
import io.trino.spi.HostAddress;
import io.trino.spi.SplitWeight;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import jakarta.annotation.Nullable;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/execution/scheduler/UniformNodeSelector.class */
public class UniformNodeSelector implements NodeSelector {
    private static final Logger log = Logger.get(UniformNodeSelector.class);
    private final InternalNodeManager nodeManager;
    private final NodeTaskMap nodeTaskMap;
    private final boolean includeCoordinator;
    private final AtomicReference<Supplier<NodeMap>> nodeMap;
    private final int minCandidates;
    private final long maxSplitsWeightPerNode;
    private final long minPendingSplitsWeightPerTask;
    private final int maxUnacknowledgedSplitsPerTask;
    private final NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy;
    private final boolean optimizedLocalScheduling;
    private final QueueSizeAdjuster queueSizeAdjuster;

    /* JADX INFO: Access modifiers changed from: package-private */
    /* loaded from: input_file:io/trino/execution/scheduler/UniformNodeSelector$QueueSizeAdjuster.class */
    /**
     * Keeps a per-node limit on pending split weight and adapts it between scheduling passes.
     *
     * <p>Each node starts at {@code minPendingSplitsWeightPerTask}. The limit is doubled (capped at
     * {@code maxAdjustedPendingSplitsWeightPerTask}) when a node was reported full during the
     * previous pass and has no queued weight now, and is divided by 1.5 (floored at the minimum)
     * once {@link #SCALE_DOWN_INTERVAL} has elapsed since its last adjustment. Adjustment is
     * disabled entirely when the minimum and maximum limits are equal.
     */
    public static class QueueSizeAdjuster {
        /** Minimum time, in nanoseconds, between an adjustment and a subsequent scale-down. */
        private static final long SCALE_DOWN_INTERVAL = TimeUnit.SECONDS.toNanos(1);

        private final Ticker ticker;
        private final Map<String, TaskAdjustmentInfo> taskAdjustmentInfos = new HashMap<>();
        private final Set<String> previousScheduleFullTasks = new HashSet<>();
        private final long minPendingSplitsWeightPerTask;
        private final long maxAdjustedPendingSplitsWeightPerTask;

        /** Mutable adjustment state for a single node: the current limit and when it last changed. */
        public class TaskAdjustmentInfo {
            private long adjustedMaxSplitsWeightPerTask;
            private Optional<Long> lastAdjustmentNanos = Optional.empty();

            /**
             * @param j initial pending-split weight limit
             */
            public TaskAdjustmentInfo(long j) {
                adjustedMaxSplitsWeightPerTask = j;
            }

            /**
             * @return the current pending-split weight limit
             */
            public long getAdjustedMaxSplitsWeightPerTask() {
                return adjustedMaxSplitsWeightPerTask;
            }

            /**
             * Stores a new limit and stamps the adjustment with the enclosing adjuster's ticker reading.
             *
             * @param j the new pending-split weight limit
             */
            public void setAdjustedMaxSplitsWeightPerTask(long j) {
                adjustedMaxSplitsWeightPerTask = j;
                lastAdjustmentNanos = Optional.of(ticker.read());
            }

            /**
             * @return the ticker reading of the last adjustment, or empty if the limit was never adjusted
             */
            public Optional<Long> getLastAdjustmentNanos() {
                return lastAdjustmentNanos;
            }
        }

        /**
         * Creates an adjuster driven by the system ticker.
         *
         * @param j minimum (and initial) pending-split weight per task
         * @param j2 maximum pending-split weight per task after adjustment
         */
        private QueueSizeAdjuster(long j, long j2) {
            this(j, j2, Ticker.systemTicker());
        }

        /**
         * Creates an adjuster with an explicit time source.
         *
         * @param j minimum (and initial) pending-split weight per task
         * @param j2 maximum pending-split weight per task after adjustment
         * @param ticker time source used to stamp and age adjustments; must not be null
         */
        @VisibleForTesting
        QueueSizeAdjuster(long j, long j2, Ticker ticker) {
            this.ticker = Objects.requireNonNull(ticker, "ticker is null");
            this.maxAdjustedPendingSplitsWeightPerTask = j2;
            this.minPendingSplitsWeightPerTask = j;
        }

        /**
         * Re-evaluates the limit of every node that has a task in {@code list}, then forgets which
         * nodes were reported full. Does nothing when adjustment is disabled.
         *
         * @param list tasks whose nodes should be re-evaluated
         * @param nodeAssignmentStats source of the currently queued split weight per node
         */
        public void update(List<RemoteTask> list, NodeAssignmentStats nodeAssignmentStats) {
            if (!isEnabled()) {
                return;
            }
            for (RemoteTask task : list) {
                String nodeId = task.getNodeId();
                TaskAdjustmentInfo info = taskAdjustmentInfos.computeIfAbsent(
                        nodeId,
                        ignored -> new TaskAdjustmentInfo(minPendingSplitsWeightPerTask));
                Optional<Long> lastAdjustment = info.getLastAdjustmentNanos();

                boolean wasFullAndIsDrained = previousScheduleFullTasks.contains(nodeId)
                        && nodeAssignmentStats.getQueuedSplitsWeightForStage(nodeId) == 0;
                if (wasFullAndIsDrained) {
                    // The queue was full last pass but is empty now: allow more pending work.
                    long doubled = info.getAdjustedMaxSplitsWeightPerTask() * 2;
                    info.setAdjustedMaxSplitsWeightPerTask(Math.min(maxAdjustedPendingSplitsWeightPerTask, doubled));
                }
                else if (lastAdjustment.isPresent() && ticker.read() - lastAdjustment.get() >= SCALE_DOWN_INTERVAL) {
                    // No pressure for a full interval: shrink back towards the minimum.
                    double shrunk = info.getAdjustedMaxSplitsWeightPerTask() / 1.5d;
                    info.setAdjustedMaxSplitsWeightPerTask((long) Math.max(minPendingSplitsWeightPerTask, shrunk));
                }
            }
            previousScheduleFullTasks.clear();
        }

        /**
         * @param str node identifier
         * @return the node's current limit, or the minimum when the node has no adjustment state yet
         */
        public long getAdjustedMaxPendingSplitsWeightPerTask(String str) {
            TaskAdjustmentInfo info = taskAdjustmentInfos.get(str);
            if (info == null) {
                return minPendingSplitsWeightPerTask;
            }
            return info.getAdjustedMaxSplitsWeightPerTask();
        }

        /**
         * Records that the node's queue was found full, making it eligible for a scale-up on the
         * next {@link #update}. Ignored when adjustment is disabled.
         *
         * @param str node identifier
         */
        public void scheduleAdjustmentForNode(String str) {
            if (!isEnabled()) {
                return;
            }
            previousScheduleFullTasks.add(str);
        }

        /** Adjustment only makes sense when there is room between the minimum and maximum limits. */
        private boolean isEnabled() {
            return minPendingSplitsWeightPerTask != maxAdjustedPendingSplitsWeightPerTask;
        }
    }

    /**
     * Creates a selector that owns a fresh {@link QueueSizeAdjuster} driven by the system ticker.
     *
     * @param internalNodeManager node manager used to look up the current node
     * @param nodeTaskMap task bookkeeping used when computing assignment statistics
     * @param z whether the coordinator may be used for scheduling
     * @param supplier supplier of the node map snapshot
     * @param i number of random candidate nodes considered per split
     * @param j maximum total split weight per node
     * @param j2 minimum pending split weight per task (also the adjuster's lower bound)
     * @param j3 maximum pending split weight per task after adjustment (the adjuster's upper bound)
     * @param i2 maximum unacknowledged splits per task; must be positive
     * @param splitsBalancingPolicy policy used to pick among free nodes
     * @param z2 whether optimized local scheduling is enabled
     */
    public UniformNodeSelector(
            InternalNodeManager internalNodeManager,
            NodeTaskMap nodeTaskMap,
            boolean z,
            Supplier<NodeMap> supplier,
            int i,
            long j,
            long j2,
            long j3,
            int i2,
            NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy,
            boolean z2) {
        this(
                internalNodeManager,
                nodeTaskMap,
                z,
                supplier,
                i,
                j,
                j2,
                i2,
                splitsBalancingPolicy,
                z2,
                new QueueSizeAdjuster(j2, j3));
    }

    /**
     * Creates a selector with an explicitly supplied {@link QueueSizeAdjuster}.
     *
     * @param internalNodeManager node manager used to look up the current node; must not be null
     * @param nodeTaskMap task bookkeeping used when computing assignment statistics; must not be null
     * @param z whether the coordinator may be used for scheduling
     * @param supplier supplier of the node map snapshot; must not be null
     * @param i number of random candidate nodes considered per split
     * @param j maximum total split weight per node
     * @param j2 minimum pending split weight per task
     * @param i2 maximum unacknowledged splits per task; must be positive
     * @param splitsBalancingPolicy policy used to pick among free nodes; must not be null
     * @param z2 whether optimized local scheduling is enabled
     * @param queueSizeAdjuster adjuster for the per-node pending split limit; must not be null
     * @throws NullPointerException if any reference argument is null
     * @throws IllegalArgumentException if {@code i2} is not positive
     */
    @VisibleForTesting
    UniformNodeSelector(InternalNodeManager internalNodeManager, NodeTaskMap nodeTaskMap, boolean z, Supplier<NodeMap> supplier, int i, long j, long j2, int i2, NodeSchedulerConfig.SplitsBalancingPolicy splitsBalancingPolicy, boolean z2, QueueSizeAdjuster queueSizeAdjuster) {
        this.nodeManager = Objects.requireNonNull(internalNodeManager, "nodeManager is null");
        this.nodeTaskMap = Objects.requireNonNull(nodeTaskMap, "nodeTaskMap is null");
        this.includeCoordinator = z;
        // Fail fast here: a null supplier would otherwise only surface as an NPE on first use.
        this.nodeMap = new AtomicReference<>(Objects.requireNonNull(supplier, "nodeMap is null"));
        this.minCandidates = i;
        this.maxSplitsWeightPerNode = j;
        this.minPendingSplitsWeightPerTask = j2;
        this.maxUnacknowledgedSplitsPerTask = i2;
        Preconditions.checkArgument(i2 > 0, "maxUnacknowledgedSplitsPerTask must be > 0, found: %s", i2);
        this.splitsBalancingPolicy = Objects.requireNonNull(splitsBalancingPolicy, "splitsBalancingPolicy is null");
        this.optimizedLocalScheduling = z2;
        // Same reasoning as for the supplier: computeAssignments dereferences this unconditionally.
        this.queueSizeAdjuster = Objects.requireNonNull(queueSizeAdjuster, "queueSizeAdjuster is null");
    }

    /**
     * Freezes the node map: the supplier is evaluated once and replaced by a supplier that always
     * returns that same instance.
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public void lockDownNodes() {
        NodeMap snapshot = this.nodeMap.get().get();
        this.nodeMap.set(Suppliers.ofInstance(snapshot));
    }

    /**
     * @return all nodes of the current node map, as filtered by {@code NodeScheduler.getAllNodes}
     *         using this selector's include-coordinator setting
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public List<InternalNode> allNodes() {
        NodeMap current = this.nodeMap.get().get();
        return NodeScheduler.getAllNodes(current, this.includeCoordinator);
    }

    /**
     * @return the node reported by the node manager as the current node
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public InternalNode selectCurrentNode() {
        InternalNode current = this.nodeManager.getCurrentNode();
        return current;
    }

    /**
     * Selects up to {@code i} nodes from a randomized view of the current node map.
     *
     * @param i maximum number of nodes to return
     * @param set nodes passed to {@code NodeScheduler.randomizedNodes} — presumably an exclusion
     *        set; confirm against that helper
     * @return the selected nodes
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public List<InternalNode> selectRandomNodes(int i, Set<InternalNode> set) {
        NodeMap current = this.nodeMap.get().get();
        return NodeScheduler.selectNodes(
                i,
                NodeScheduler.randomizedNodes(current, this.includeCoordinator, set));
    }

    /**
     * Assigns the given splits to nodes while respecting the per-node weight limit, the per-task
     * pending-weight limit and the per-task unacknowledged-split limit.
     *
     * <p>Placement happens in up to three steps:
     * <ol>
     *   <li>If optimized local scheduling is enabled, each remotely accessible split that
     *       advertises addresses is offered to the least-loaded node at those addresses that is
     *       still under its limits. A split placed this way is not considered again.</li>
     *   <li>Every remaining split is offered to its candidate nodes: {@code minCandidates} random
     *       nodes for remotely accessible splits, otherwise the nodes at the split's addresses.
     *       A free node is preferred; failing that, the candidate with the lowest queued weight
     *       that is still below its adjusted pending limit is used.</li>
     *   <li>If step 1 placed anything, the assignment is rebalanced via
     *       {@code equateDistribution}.</li>
     * </ol>
     *
     * <p>Splits that cannot be placed are omitted from the result.
     *
     * @param set splits to place
     * @param list tasks already running for the stage
     * @return the computed assignment together with the future produced by
     *         {@code NodeScheduler.toWhenHasSplitQueueSpaceFuture} for the blocked nodes
     * @throws TrinoException with {@code NO_NODES_AVAILABLE} if a split has no candidate nodes
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public SplitPlacementResult computeAssignments(Set<Split> set, List<RemoteTask> list) {
        Multimap<InternalNode, Split> assignment = HashMultimap.create();
        NodeMap nodeMapSnapshot = this.nodeMap.get().get();
        NodeAssignmentStats assignmentStats = new NodeAssignmentStats(this.nodeTaskMap, nodeMapSnapshot, list);
        this.queueSizeAdjuster.update(list, assignmentStats);

        Set<InternalNode> blockedExactNodes = new HashSet<>();
        boolean splitWaitingForAnyNode = false;
        boolean splitsToBeRedistributed = false;

        List<InternalNode> filteredNodes = NodeScheduler.filterNodes(nodeMapSnapshot, this.includeCoordinator, ImmutableSet.of());
        ResettableRandomizedIterator<InternalNode> randomCandidates = new ResettableRandomizedIterator<>(filteredNodes);
        Set<InternalNode> schedulableNodes = new HashSet<>(filteredNodes);

        // Step 1: try to keep splits on a node that hosts their data.
        Set<Split> remainingSplits;
        if (this.optimizedLocalScheduling) {
            remainingSplits = new HashSet<>(set.size());
            for (Split split : set) {
                if (split.isRemotelyAccessible() && !split.getAddresses().isEmpty()) {
                    Optional<InternalNode> localNode = NodeScheduler.selectExactNodes(nodeMapSnapshot, split.getAddresses(), this.includeCoordinator).stream()
                            .filter(owner -> assignmentStats.getTotalSplitsWeight(owner) < this.maxSplitsWeightPerNode
                                    && assignmentStats.getUnacknowledgedSplitCountForStage(owner) < this.maxUnacknowledgedSplitsPerTask)
                            .min(Comparator.comparingLong(assignmentStats::getTotalSplitsWeight));
                    if (localNode.isPresent()) {
                        assignment.put(localNode.get(), split);
                        assignmentStats.addAssignedSplit(localNode.get(), split.getSplitWeight());
                        splitsToBeRedistributed = true;
                        // Already placed: must not fall through to the remaining set, otherwise
                        // the split would be assigned (and its weight counted) a second time.
                        continue;
                    }
                }
                remainingSplits.add(split);
            }
        }
        else {
            remainingSplits = set;
        }

        // Step 2: place whatever is left on random (or exact, if not remotely accessible) candidates.
        for (Split split : remainingSplits) {
            randomCandidates.reset();

            List<InternalNode> candidateNodes;
            if (split.isRemotelyAccessible()) {
                candidateNodes = NodeScheduler.selectNodes(this.minCandidates, randomCandidates);
            }
            else {
                candidateNodes = NodeScheduler.selectExactNodes(nodeMapSnapshot, split.getAddresses(), this.includeCoordinator);
            }
            if (candidateNodes.isEmpty()) {
                log.debug("No nodes available to schedule %s. Available nodes %s", split, nodeMapSnapshot.getNodesByHost().keys());
                throw new TrinoException(StandardErrorCode.NO_NODES_AVAILABLE, "No nodes available to run query");
            }

            InternalNode chosenNode = chooseNodeForSplit(assignmentStats, candidateNodes);
            if (chosenNode == null) {
                // No free node: fall back to the candidate with the least queued weight that is
                // still below its (adjusted) pending limit, and flag every full candidate.
                long minWeight = Long.MAX_VALUE;
                for (InternalNode node : candidateNodes) {
                    long queuedWeight = assignmentStats.getQueuedSplitsWeightForStage(node);
                    long adjustedMaxPendingWeight = this.queueSizeAdjuster.getAdjustedMaxPendingSplitsWeightPerTask(node.getNodeIdentifier());
                    if (queuedWeight <= minWeight
                            && queuedWeight < adjustedMaxPendingWeight
                            && assignmentStats.getUnacknowledgedSplitCountForStage(node) < this.maxUnacknowledgedSplitsPerTask) {
                        chosenNode = node;
                        minWeight = queuedWeight;
                    }
                    if (queuedWeight >= adjustedMaxPendingWeight) {
                        this.queueSizeAdjuster.scheduleAdjustmentForNode(node.getNodeIdentifier());
                    }
                }
            }

            if (chosenNode != null) {
                assignment.put(chosenNode, split);
                assignmentStats.addAssignedSplit(chosenNode, split.getSplitWeight());
                continue;
            }

            // The split could not be placed: none of its candidates is schedulable any more.
            candidateNodes.forEach(schedulableNodes::remove);
            if (split.isRemotelyAccessible()) {
                splitWaitingForAnyNode = true;
            }
            else if (!splitWaitingForAnyNode) {
                // The exact node set only matters while no split is waiting for an arbitrary node.
                blockedExactNodes.addAll(candidateNodes);
            }
            if (splitWaitingForAnyNode && schedulableNodes.isEmpty()) {
                // Every node has been ruled out; trying further splits cannot succeed.
                break;
            }
        }

        long lowWatermark = NodeScheduler.calculateLowWatermark(this.minPendingSplitsWeightPerTask);
        ListenableFuture<Void> blocked = splitWaitingForAnyNode
                ? NodeScheduler.toWhenHasSplitQueueSpaceFuture(list, lowWatermark)
                : NodeScheduler.toWhenHasSplitQueueSpaceFuture(blockedExactNodes, list, lowWatermark);

        // Step 3: locality-driven placement can be lopsided, so even it out.
        if (splitsToBeRedistributed) {
            equateDistribution(assignment, assignmentStats, nodeMapSnapshot, this.includeCoordinator);
        }
        return new SplitPlacementResult(blocked, assignment);
    }

    /**
     * Bucketed variant: delegates entirely to {@code NodeScheduler.selectDistributionNodes} with
     * this selector's current node map and limits.
     *
     * @param set splits to place
     * @param list tasks already running for the stage
     * @param bucketNodeMap bucket-to-node mapping forwarded to the delegate
     * @return the placement computed by the delegate
     */
    @Override // io.trino.execution.scheduler.NodeSelector
    public SplitPlacementResult computeAssignments(Set<Split> set, List<RemoteTask> list, BucketNodeMap bucketNodeMap) {
        NodeMap current = this.nodeMap.get().get();
        return NodeScheduler.selectDistributionNodes(
                current,
                this.nodeTaskMap,
                this.maxSplitsWeightPerNode,
                this.minPendingSplitsWeightPerTask,
                this.maxUnacknowledgedSplitsPerTask,
                set,
                list,
                bucketNodeMap);
    }

    /**
     * Picks, among the candidates that are still free for this stage, the node with the lowest
     * weight according to the configured balancing policy: queued weight for the stage under
     * {@code STAGE}, total split weight under {@code NODE}. On ties the later candidate wins.
     *
     * @param nodeAssignmentStats current assignment statistics
     * @param list candidate nodes
     * @return the chosen node, or {@code null} when no candidate is free
     * @throws UnsupportedOperationException if the balancing policy is neither STAGE nor NODE
     */
    @Nullable
    private InternalNode chooseNodeForSplit(NodeAssignmentStats nodeAssignmentStats, List<InternalNode> list) {
        List<InternalNode> freeNodes = getFreeNodesForStage(nodeAssignmentStats, list);

        boolean balanceByStage;
        switch (this.splitsBalancingPolicy) {
            case STAGE:
                balanceByStage = true;
                break;
            case NODE:
                balanceByStage = false;
                break;
            default:
                throw new UnsupportedOperationException("Unsupported split balancing policy " + String.valueOf(this.splitsBalancingPolicy));
        }

        InternalNode lightest = null;
        long lowestWeight = Long.MAX_VALUE;
        for (InternalNode candidate : freeNodes) {
            long weight = balanceByStage
                    ? nodeAssignmentStats.getQueuedSplitsWeightForStage(candidate)
                    : nodeAssignmentStats.getTotalSplitsWeight(candidate);
            // "<=" on purpose: with equal weights the last candidate is selected.
            if (weight <= lowestWeight) {
                lightest = candidate;
                lowestWeight = weight;
            }
        }
        return lightest;
    }

    /**
     * Filters the candidates down to nodes that are below both the per-node total weight limit and
     * the per-task unacknowledged split limit.
     *
     * @param nodeAssignmentStats current assignment statistics
     * @param list candidate nodes
     * @return an immutable list of the free candidates, in their original order
     */
    private List<InternalNode> getFreeNodesForStage(NodeAssignmentStats nodeAssignmentStats, List<InternalNode> list) {
        return list.stream()
                .filter(node -> nodeAssignmentStats.getTotalSplitsWeight(node) < this.maxSplitsWeightPerNode
                        && nodeAssignmentStats.getUnacknowledgedSplitCountForStage(node) < this.maxUnacknowledgedSplitsPerTask)
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Evens out an assignment by repeatedly moving one split from the most loaded assigned node to
     * the least loaded eligible node, until the weight gap between the two is at most the raw
     * weight of five standard splits.
     *
     * @param multimap the assignment to rebalance; modified in place
     * @param nodeAssignmentStats assignment statistics; updated for every moved split
     * @param nodeMap node map providing the eligible nodes and the host lookup
     * @param z whether coordinator nodes are eligible as targets
     */
    private void equateDistribution(Multimap<InternalNode, Split> multimap, NodeAssignmentStats nodeAssignmentStats, NodeMap nodeMap, boolean z) {
        if (multimap.isEmpty()) {
            return;
        }

        ImmutableList.Builder<InternalNode> eligible = ImmutableList.builder();
        for (InternalNode node : nodeMap.getNodesByHostAndPort().values()) {
            if (z || !nodeMap.getCoordinatorNodeIds().contains(node.getNodeIdentifier())) {
                eligible.add(node);
            }
        }
        List<InternalNode> allNodes = eligible.build();
        if (allNodes.size() < 2) {
            return;
        }

        // Queue of nodes holding assignments, prioritized by total weight.
        IndexedPriorityQueue<InternalNode> maxNodes = new IndexedPriorityQueue<>();
        for (InternalNode node : multimap.keySet()) {
            maxNodes.addOrUpdate(node, nodeAssignmentStats.getTotalSplitsWeight(node));
        }

        // Queue of all eligible nodes, prioritized by inverted weight so the lightest comes first.
        IndexedPriorityQueue<InternalNode> minNodes = new IndexedPriorityQueue<>();
        for (InternalNode node : allNodes) {
            minNodes.addOrUpdate(node, Long.MAX_VALUE - nodeAssignmentStats.getTotalSplitsWeight(node));
        }

        while (true) {
            if (maxNodes.isEmpty()) {
                return;
            }
            InternalNode maxNode = maxNodes.poll();
            InternalNode minNode = minNodes.poll();

            // Tolerate a small imbalance rather than giving up locality for marginal gains.
            long gap = nodeAssignmentStats.getTotalSplitsWeight(maxNode) - nodeAssignmentStats.getTotalSplitsWeight(minNode);
            if (gap <= SplitWeight.rawValueForStandardSplitCount(5)) {
                return;
            }

            Split moved = redistributeSplit(multimap, maxNode, minNode, nodeMap.getNodesByHost());
            nodeAssignmentStats.removeAssignedSplit(maxNode, moved.getSplitWeight());
            nodeAssignmentStats.addAssignedSplit(minNode, moved.getSplitWeight());

            // The donor stays a donor only while it still holds assignments.
            if (multimap.containsKey(maxNode)) {
                maxNodes.addOrUpdate(maxNode, nodeAssignmentStats.getTotalSplitsWeight(maxNode));
            }
            // Refresh both nodes' priorities with their new weights.
            maxNodes.addOrUpdate(minNode, nodeAssignmentStats.getTotalSplitsWeight(minNode));
            minNodes.addOrUpdate(minNode, Long.MAX_VALUE - nodeAssignmentStats.getTotalSplitsWeight(minNode));
            minNodes.addOrUpdate(maxNode, Long.MAX_VALUE - nodeAssignmentStats.getTotalSplitsWeight(maxNode));
        }
    }

    /**
     * Moves one split from {@code internalNode} to {@code internalNode2} within the assignment.
     * A split that has addresses and is not local to the source node is preferred; if there is
     * none, the first split of the source node is moved instead.
     *
     * @param multimap the assignment; modified in place
     * @param internalNode node to take the split from
     * @param internalNode2 node to give the split to
     * @param setMultimap nodes by host address, used for the locality check
     * @return the split that was moved
     */
    @VisibleForTesting
    public static Split redistributeSplit(Multimap<InternalNode, Split> multimap, InternalNode internalNode, InternalNode internalNode2, SetMultimap<InetAddress, InternalNode> setMultimap) {
        Iterator<Split> cursor = multimap.get(internalNode).iterator();
        Split chosen = null;
        // Stop on the first non-local split so the cursor stays positioned on it for removal.
        while (chosen == null && cursor.hasNext()) {
            Split candidate = cursor.next();
            boolean hasAddresses = !candidate.getAddresses().isEmpty();
            if (hasAddresses && !isSplitLocal(candidate.getAddresses(), internalNode.getHostAndPort(), setMultimap)) {
                chosen = candidate;
            }
        }
        if (chosen == null) {
            // Only local (or address-less) splits here: fall back to the first one.
            cursor = multimap.get(internalNode).iterator();
            chosen = cursor.next();
        }
        cursor.remove();
        multimap.put(internalNode2, chosen);
        return chosen;
    }

    /**
     * Tells whether a split is local to a node, i.e. whether any of the split's addresses refers
     * to that node.
     *
     * <p>An address matches when it equals the node's address outright, or when it carries no port
     * and resolves to a host on which a node with the given address runs. Addresses that cannot be
     * resolved are skipped.
     *
     * @param list addresses advertised by the split
     * @param hostAddress host-and-port address of the node
     * @param setMultimap nodes by resolved host address
     * @return {@code true} if at least one split address matches the node
     */
    private static boolean isSplitLocal(List<HostAddress> list, HostAddress hostAddress, SetMultimap<InetAddress, InternalNode> setMultimap) {
        for (HostAddress splitAddress : list) {
            if (hostAddress.equals(splitAddress)) {
                return true;
            }
            if (splitAddress.hasPort()) {
                // An address with a port can only match by equality, which was just ruled out;
                // no need to resolve it.
                continue;
            }
            InetAddress resolved;
            try {
                resolved = splitAddress.toInetAddress();
            }
            catch (UnknownHostException ignored) {
                // An unresolvable host cannot be matched against known nodes; try the next address.
                continue;
            }
            boolean nodeRunsOnHost = setMultimap.get(resolved).stream()
                    .anyMatch(node -> node.getHostAndPort().equals(hostAddress));
            if (nodeRunsOnHost) {
                return true;
            }
            // Keep looking: a later address may still match even though this one did not.
        }
        return false;
    }
}
