package io.trino.plugin.raptor.legacy.storage;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.VerifyException;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.HashMultiset;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import com.google.common.collect.UnmodifiableIterator;
import com.google.inject.Inject;
import io.airlift.concurrent.Threads;
import io.airlift.log.Logger;
import io.airlift.stats.CounterStat;
import io.airlift.units.Duration;
import io.trino.plugin.base.CatalogName;
import io.trino.plugin.raptor.legacy.NodeSupplier;
import io.trino.plugin.raptor.legacy.backup.BackupService;
import io.trino.plugin.raptor.legacy.metadata.BucketNode;
import io.trino.plugin.raptor.legacy.metadata.Distribution;
import io.trino.plugin.raptor.legacy.metadata.ShardManager;
import io.trino.spi.NodeManager;
import jakarta.annotation.PostConstruct;
import jakarta.annotation.PreDestroy;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.weakref.jmx.Managed;
import org.weakref.jmx.Nested;

/**
 * Rebalances bucket-to-node assignments of Raptor distributions across active worker nodes.
 *
 * <p>When started (see {@link #start()}), the balancer runs on a fixed-delay schedule. Each run
 * ({@link #balance()}) does three things:
 * <ul>
 *   <li>Snapshots the cluster from the {@link NodeSupplier} and {@link ShardManager}.</li>
 *   <li>Computes, per distribution, a greedy set of bucket moves. The moves bring every active
 *       node's bucket count toward the floor/ceiling of an even split and move buckets off nodes
 *       that are no longer active.</li>
 *   <li>Persists the new assignments via {@link ShardManager#updateBucketAssignment}.</li>
 * </ul>
 *
 * <p>Target nodes are chosen by lowest projected assigned bytes. Per-node storage capacity is not
 * taken into account.
 */
public class BucketBalancer {
    private static final Logger log = Logger.get(BucketBalancer.class);

    private final NodeSupplier nodeSupplier;
    private final ShardManager shardManager;
    private final boolean enabled;
    private final Duration interval;
    private final boolean backupAvailable;
    private final boolean coordinator;
    private final ScheduledExecutorService executor;

    private final AtomicBoolean started = new AtomicBoolean();
    private final CounterStat bucketsBalanced = new CounterStat();
    private final CounterStat jobErrors = new CounterStat();

    /**
     * An immutable (distribution id, bucket number, node identifier) triple.
     *
     * <p>It describes either the node a bucket currently lives on or the node it should be moved to.
     */
    @VisibleForTesting
    public static class BucketAssignment {
        private final long distributionId;
        private final int bucketNumber;
        private final String nodeIdentifier;

        /**
         * @param distributionId id of the distribution the bucket belongs to
         * @param bucketNumber bucket number within the distribution
         * @param nodeIdentifier identifier of the node holding (or receiving) the bucket; must not be null
         * @throws NullPointerException if {@code nodeIdentifier} is null
         */
        public BucketAssignment(long distributionId, int bucketNumber, String nodeIdentifier) {
            this.distributionId = distributionId;
            this.bucketNumber = bucketNumber;
            this.nodeIdentifier = Objects.requireNonNull(nodeIdentifier, "nodeIdentifier is null");
        }

        public long getDistributionId() {
            return distributionId;
        }

        public int getBucketNumber() {
            return bucketNumber;
        }

        public String getNodeIdentifier() {
            return nodeIdentifier;
        }
    }

    /**
     * An immutable snapshot of the cluster used as input to one balancing run.
     *
     * <p>All constructor arguments are defensively copied into Guava immutable collections.
     */
    @VisibleForTesting
    public static class ClusterState {
        private final Set<String> activeNodes;
        private final Map<String, Long> assignedBytes;
        private final Multimap<Distribution, BucketAssignment> distributionAssignments;
        private final Map<Distribution, Long> distributionBucketSize;

        /**
         * @param activeNodes identifiers of nodes that may receive buckets
         * @param assignedBytes bytes currently assigned to each node (active or not)
         * @param distributionAssignments current bucket assignments, grouped by distribution
         * @param distributionBucketSize estimated size in bytes of one bucket of each distribution
         * @throws NullPointerException if any argument is null
         */
        public ClusterState(
                Set<String> activeNodes,
                Map<String, Long> assignedBytes,
                Multimap<Distribution, BucketAssignment> distributionAssignments,
                Map<Distribution, Long> distributionBucketSize) {
            this.activeNodes = ImmutableSet.copyOf(Objects.requireNonNull(activeNodes, "activeNodes is null"));
            this.assignedBytes = ImmutableMap.copyOf(Objects.requireNonNull(assignedBytes, "assignedBytes is null"));
            this.distributionAssignments = ImmutableMultimap.copyOf(Objects.requireNonNull(distributionAssignments, "distributionAssignments is null"));
            this.distributionBucketSize = ImmutableMap.copyOf(Objects.requireNonNull(distributionBucketSize, "distributionBucketSize is null"));
        }

        public Set<String> getActiveNodes() {
            return activeNodes;
        }

        public Map<String, Long> getAssignedBytes() {
            return assignedBytes;
        }

        public Multimap<Distribution, BucketAssignment> getDistributionAssignments() {
            return distributionAssignments;
        }

        public Map<Distribution, Long> getDistributionBucketSize() {
            return distributionBucketSize;
        }
    }

    @Inject
    public BucketBalancer(
            NodeManager nodeManager,
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            BucketBalancerConfig config,
            BackupService backupService,
            CatalogName catalogName) {
        this(
                nodeSupplier,
                shardManager,
                config.isBalancerEnabled(),
                config.getBalancerInterval(),
                backupService.isBackupAvailable(),
                nodeManager.getCurrentNode().isCoordinator(),
                catalogName.toString());
    }

    /**
     * @param nodeSupplier source of the current worker nodes
     * @param shardManager source of distributions and bucket assignments, and sink for moves
     * @param enabled whether the balancer is enabled by configuration
     * @param interval delay between scheduled balancing runs
     * @param backupAvailable whether a backup store is available
     * @param coordinator whether this node is the coordinator
     * @param catalogName used to name the balancer's daemon thread
     */
    public BucketBalancer(
            NodeSupplier nodeSupplier,
            ShardManager shardManager,
            boolean enabled,
            Duration interval,
            boolean backupAvailable,
            boolean coordinator,
            String catalogName) {
        this.nodeSupplier = Objects.requireNonNull(nodeSupplier, "nodeSupplier is null");
        this.shardManager = Objects.requireNonNull(shardManager, "shardManager is null");
        this.enabled = enabled;
        this.interval = Objects.requireNonNull(interval, "interval is null");
        this.backupAvailable = backupAvailable;
        this.coordinator = coordinator;
        this.executor = Executors.newSingleThreadScheduledExecutor(Threads.daemonThreadsNamed("bucket-balancer-" + catalogName));
    }

    /**
     * Schedules periodic balancing.
     *
     * <p>Scheduling happens only if the balancer is enabled, a backup store is available, and this
     * node is the coordinator. Repeated calls schedule the job at most once.
     */
    @PostConstruct
    public void start() {
        // Evaluated left to right so that 'started' is only flipped when all preconditions hold.
        if (!enabled || !backupAvailable || !coordinator || started.getAndSet(true)) {
            return;
        }
        long delayMillis = interval.toMillis();
        executor.scheduleWithFixedDelay(this::runBalanceJob, delayMillis, delayMillis, TimeUnit.MILLISECONDS);
    }

    /** Stops the balancer thread, interrupting a run in progress. */
    @PreDestroy
    public void shutdown() {
        executor.shutdownNow();
    }

    /** Number of bucket moves persisted so far. */
    @Managed
    @Nested
    public CounterStat getBucketsBalanced() {
        return bucketsBalanced;
    }

    /** Number of balancing runs that failed with an exception. */
    @Managed
    @Nested
    public CounterStat getJobErrors() {
        return jobErrors;
    }

    /** Submits a one-off balancing run to the balancer executor (JMX operation). */
    @Managed
    public void startBalanceJob() {
        executor.submit(this::runBalanceJob);
    }

    private void runBalanceJob() {
        // Catch everything: a task throwing from scheduleWithFixedDelay would cancel all future runs.
        try {
            balance();
        }
        catch (Throwable t) {
            log.error(t, "Error balancing buckets");
            jobErrors.update(1);
        }
    }

    /**
     * Runs one balancing pass: snapshot, compute moves, persist moves.
     *
     * @return the number of buckets moved
     */
    @VisibleForTesting
    synchronized int balance() {
        log.info("Bucket balancer started. Computing assignments...");
        Multimap<String, BucketAssignment> sourceToAllocationChanges = computeAssignmentChanges(fetchClusterState());

        log.info("Moving buckets...");
        int moved = updateAssignments(sourceToAllocationChanges);

        log.info("Bucket balancing finished. Moved %s buckets.", moved);
        return moved;
    }

    /**
     * Greedily computes bucket moves for every distribution in {@code clusterState}.
     *
     * <p>For each distribution, the target per-node bucket count is the floor/ceiling of
     * {@code buckets / activeNodes}. Buckets are moved off a node while either of these holds:
     * <ul>
     *   <li>the node is not active (all of its buckets are moved);</li>
     *   <li>the node holds more than the floor, unless the move would only shift a node at the
     *       ceiling onto a node at the floor.</li>
     * </ul>
     * Projected byte totals are updated as moves are chosen, so later choices account for earlier
     * ones, including those of earlier distributions.
     *
     * @return new assignments keyed by the node the bucket currently lives on (the move source)
     * @throws VerifyException if a bucket must move but no eligible target exists
     */
    private static Multimap<String, BucketAssignment> computeAssignmentChanges(ClusterState clusterState) {
        Multimap<String, BucketAssignment> sourceToAllocationChanges = HashMultimap.create();

        Map<String, Long> allocationBytes = new HashMap<>(clusterState.getAssignedBytes());
        Set<String> activeNodes = clusterState.getActiveNodes();

        for (Distribution distribution : clusterState.getDistributionAssignments().keySet()) {
            Collection<BucketAssignment> distributionAssignments = clusterState.getDistributionAssignments().get(distribution);

            // number of buckets of this distribution currently held by each node
            HashMultiset<String> allocationCounts = HashMultiset.create();
            for (BucketAssignment assignment : distributionAssignments) {
                allocationCounts.add(assignment.getNodeIdentifier());
            }

            // only used for logging, so tolerate an empty node set instead of throwing
            int currentMin = allocationBytes.keySet().stream().mapToInt(allocationCounts::count).min().orElse(0);
            int currentMax = allocationBytes.keySet().stream().mapToInt(allocationCounts::count).max().orElse(0);

            int bucketCount = distributionAssignments.size();
            int targetMin = (int) Math.floor((bucketCount * 1.0) / activeNodes.size());
            int targetMax = (int) Math.ceil((bucketCount * 1.0) / activeNodes.size());

            log.info("Distribution %s: Current bucket skew: min %s, max %s. Target bucket skew: min %s, max %s",
                    distribution.getId(), currentMin, currentMax, targetMin, targetMax);

            // group once (encounter order preserved) instead of re-filtering the whole distribution per source
            Map<String, List<BucketAssignment>> assignmentsBySource = distributionAssignments.stream()
                    .collect(Collectors.groupingBy(BucketAssignment::getNodeIdentifier));

            // iterate over a snapshot: nodes that only receive buckets during this pass are never sources
            for (String source : ImmutableSet.copyOf(allocationCounts)) {
                for (BucketAssignment existingAssignment : assignmentsBySource.get(source)) {
                    // An active node at or below the floor keeps the rest of its buckets. Nothing
                    // changes when no move is made, so later iterations would skip as well.
                    if (activeNodes.contains(source) && allocationCounts.count(source) <= targetMin) {
                        break;
                    }

                    String target = selectTarget(source, activeNodes, allocationCounts, allocationBytes, targetMax);
                    long bucketSize = clusterState.getDistributionBucketSize().get(distribution);

                    // Moving from a node at the ceiling to a node at the floor only swaps which node
                    // is over target, so it does not reduce imbalance. The state is unchanged, so the
                    // same target would be chosen again on the next iteration.
                    if (activeNodes.contains(source)
                            && allocationCounts.count(source) == targetMax
                            && allocationCounts.count(target) == targetMin) {
                        break;
                    }

                    allocationCounts.remove(source);
                    allocationCounts.add(target);
                    allocationBytes.merge(source, -bucketSize, Long::sum);
                    allocationBytes.merge(target, bucketSize, Long::sum);

                    sourceToAllocationChanges.put(
                            source,
                            new BucketAssignment(existingAssignment.getDistributionId(), existingAssignment.getBucketNumber(), target));
                }
            }
        }

        return sourceToAllocationChanges;
    }

    /**
     * Chooses the node that should receive a bucket moved off {@code source}.
     *
     * <p>Candidates are active nodes other than the source whose bucket count is below
     * {@code targetMax}. Among them, the one with the fewest projected assigned bytes wins. Ties
     * are broken by the fewest buckets of the current distribution, then by set iteration order.
     */
    private static String selectTarget(
            String source,
            Set<String> activeNodes,
            HashMultiset<String> allocationCounts,
            Map<String, Long> allocationBytes,
            int targetMax) {
        return activeNodes.stream()
                .filter(candidate -> !candidate.equals(source) && allocationCounts.count(candidate) < targetMax)
                .min(Comparator.<String>comparingLong(candidate -> allocationBytes.getOrDefault(candidate, 0L))
                        .thenComparingInt(allocationCounts::count))
                .orElseThrow(() -> new VerifyException("unable to find target for rebalancing"));
    }

    /**
     * Persists the computed moves.
     *
     * <p>Source nodes with the most outgoing moves are processed first.
     *
     * @return the number of buckets moved
     */
    private int updateAssignments(Multimap<String, BucketAssignment> sourceToAllocationChanges) {
        List<String> sourcesByMoveCountDescending = sourceToAllocationChanges.asMap().entrySet().stream()
                .sorted((left, right) -> Integer.compare(right.getValue().size(), left.getValue().size()))
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());

        int moves = 0;
        for (String source : sourcesByMoveCountDescending) {
            for (BucketAssignment reassignment : sourceToAllocationChanges.get(source)) {
                shardManager.updateBucketAssignment(reassignment.getDistributionId(), reassignment.getBucketNumber(), reassignment.getNodeIdentifier());
                bucketsBalanced.update(1);
                moves++;
                log.info("Distribution %s: Moved bucket %s from %s to %s",
                        reassignment.getDistributionId(), reassignment.getBucketNumber(), source, reassignment.getNodeIdentifier());
            }
        }
        return moves;
    }

    /**
     * Builds a {@link ClusterState} from the current worker nodes and shard metadata.
     *
     * <p>Every active worker starts at zero assigned bytes. Each bucket then adds its
     * distribution's per-bucket size (total size integer-divided by bucket count) to the node
     * holding it. Nodes that hold buckets but are not active workers therefore also appear in the
     * assigned-bytes map.
     */
    @VisibleForTesting
    ClusterState fetchClusterState() {
        Set<String> activeNodes = nodeSupplier.getWorkerNodes().stream()
                .map(node -> node.getNodeIdentifier())
                .collect(Collectors.toSet());

        Map<String, Long> assignedBytes = new HashMap<>();
        for (String node : activeNodes) {
            assignedBytes.put(node, 0L);
        }

        ImmutableMultimap.Builder<Distribution, BucketAssignment> distributionAssignments = ImmutableMultimap.builder();
        ImmutableMap.Builder<Distribution, Long> distributionBucketSize = ImmutableMap.builder();

        for (Distribution distribution : shardManager.getDistributions()) {
            // plain integer division; the former round trip through double lost precision above 2^53
            long bucketSize = shardManager.getDistributionSizeInBytes(distribution.getId()) / distribution.getBucketCount();
            distributionBucketSize.put(distribution, bucketSize);

            for (BucketNode bucketNode : shardManager.getBucketNodes(distribution.getId())) {
                String node = bucketNode.getNodeIdentifier();
                distributionAssignments.put(distribution, new BucketAssignment(distribution.getId(), bucketNode.getBucketNumber(), node));
                assignedBytes.merge(node, bucketSize, Math::addExact);
            }
        }

        return new ClusterState(activeNodes, assignedBytes, distributionAssignments.build(), distributionBucketSize.buildOrThrow());
    }
}
