package io.trino.plugin.hive;

import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableList;
import io.trino.block.BlockAssertions;
import io.trino.plugin.hive.util.HiveBucketing;
import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.BlockBuilder;
import io.trino.spi.connector.BucketFunction;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.TypeOperators;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.IntStream;
import org.assertj.core.api.Assertions;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/**
 * Tests for {@link HivePartitionedBucketFunction}.
 *
 * <p>Each test compares the partitioned bucket function against the plain {@link HiveBucketFunction}.
 * It checks that rows sharing a Hive bucket within one partition land in the same output bucket. It
 * also checks the total number of distinct output buckets produced for a page.
 */
public class TestHivePartitionedBucketFunction {
    /**
     * Supplies every supported Hive bucketing version, so each test runs once per version.
     */
    @DataProvider(name = "hiveBucketingVersion")
    public static Object[][] hiveBucketingVersion() {
        // Must be a two-dimensional array: one row of arguments per test invocation.
        return new Object[][] {
                {HiveBucketing.BucketingVersion.BUCKETING_V1},
                {HiveBucketing.BucketingVersion.BUCKETING_V2},
        };
    }

    @Test(dataProvider = "hiveBucketingVersion")
    public void testSinglePartition(HiveBucketing.BucketingVersion bucketingVersion) {
        int positionCount = 1024;
        int hiveBucketCount = 10;

        Block bucketColumn = createLongSequenceBlockWithNull(positionCount);
        Page bucketColumnPage = new Page(new Block[] {bucketColumn});
        // Every row carries the same partition value, so all rows belong to a single partition.
        Page page = new Page(new Block[] {bucketColumn, BlockAssertions.createLongRepeatBlock(78758, positionCount)});

        // Group row positions by the bucket the plain Hive bucket function assigns to them.
        BucketFunction hiveBucketFunction = bucketFunction(bucketingVersion, hiveBucketCount, ImmutableList.of(HiveType.HIVE_LONG));
        HashMultimap<Integer, Integer> positionsByHiveBucket = HashMultimap.create();
        for (int position = 0; position < positionCount; position++) {
            positionsByHiveBucket.put(hiveBucketFunction.getBucket(bucketColumnPage, position), position);
        }

        BucketFunction partitionedFunction = partitionedBucketFunction(
                bucketingVersion,
                hiveBucketCount,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                100);

        // Rows that share a Hive bucket must also share a partitioned bucket.
        for (Collection<Integer> positions : positionsByHiveBucket.asMap().values()) {
            assertBucketCount(partitionedFunction, page, positions, 1);
        }
        // With a single partition, the number of distinct partitioned buckets equals the Hive bucket count.
        assertBucketCount(partitionedFunction, page, allPositions(positionCount), hiveBucketCount);
    }

    @Test(dataProvider = "hiveBucketingVersion")
    public void testMultiplePartitions(HiveBucketing.BucketingVersion bucketingVersion) {
        int positionCount = 1024;
        int hiveBucketCount = 10;
        int partitionCount = 8;
        int positionsPerPartition = positionCount / partitionCount;

        Block bucketColumn = createLongSequenceBlockWithNull(positionCount);
        Page bucketColumnPage = new Page(new Block[] {bucketColumn});
        BucketFunction hiveBucketFunction = bucketFunction(bucketingVersion, hiveBucketCount, ImmutableList.of(HiveType.HIVE_LONG));

        // Build (partitionCount - 1) non-null partition values plus one null partition.
        // Each partition value covers a contiguous, equal-sized run of rows.
        List<Long> partitionValues = new ArrayList<>(positionCount);
        for (int partition = 0; partition < partitionCount - 1; partition++) {
            partitionValues.addAll(Collections.nCopies(positionsPerPartition, (long) partition * 348349));
        }
        partitionValues.addAll(Collections.nCopies(positionsPerPartition, (Long) null));
        Page page = new Page(new Block[] {bucketColumn, BlockAssertions.createLongsBlock(partitionValues)});

        // For each partition value (including null; HashMap permits a null key), group row positions by Hive bucket.
        Map<Long, HashMultimap<Integer, Integer>> hiveBucketPositionsByPartition = new HashMap<>();
        for (int position = 0; position < positionCount; position++) {
            hiveBucketPositionsByPartition
                    .computeIfAbsent(partitionValues.get(position), ignored -> HashMultimap.create())
                    .put(hiveBucketFunction.getBucket(bucketColumnPage, position), position);
        }

        BucketFunction partitionedFunction = partitionedBucketFunction(
                bucketingVersion,
                hiveBucketCount,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                4000);

        // Within each partition, rows that share a Hive bucket must share a partitioned bucket.
        for (HashMultimap<Integer, Integer> positionsByHiveBucket : hiveBucketPositionsByPartition.values()) {
            for (Collection<Integer> positions : positionsByHiveBucket.asMap().values()) {
                assertBucketCount(partitionedFunction, page, positions, 1);
            }
        }
        // Each (partition, Hive bucket) pair is expected to get its own distinct partitioned bucket.
        assertBucketCount(partitionedFunction, page, allPositions(positionCount), hiveBucketCount * partitionCount);
    }

    @Test(dataProvider = "hiveBucketingVersion")
    public void testConsecutiveBucketsWithinPartition(HiveBucketing.BucketingVersion bucketingVersion) {
        int positionCount = 100;
        int hiveBucketCount = 10;

        BlockBuilder bucketColumnBuilder = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        BlockBuilder partitionColumnBuilder = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        for (int position = 0; position < positionCount; position++) {
            BigintType.BIGINT.writeLong(bucketColumnBuilder, position);
            // A constant partition value puts every row in the same partition.
            BigintType.BIGINT.writeLong(partitionColumnBuilder, 42L);
        }
        Page page = new Page(new Block[] {bucketColumnBuilder.build(), partitionColumnBuilder.build()});

        BucketFunction partitionedFunction = partitionedBucketFunction(
                bucketingVersion,
                hiveBucketCount,
                ImmutableList.of(HiveType.HIVE_LONG),
                ImmutableList.of(BigintType.BIGINT),
                4000);

        List<Integer> buckets = new ArrayList<>(positionCount);
        for (int position = 0; position < positionCount; position++) {
            buckets.add(partitionedFunction.getBucket(page, position));
        }
        // The buckets assigned to one partition's rows must span exactly hiveBucketCount consecutive values.
        Assertions.assertThat(Collections.max(buckets) - Collections.min(buckets) + 1).isEqualTo(hiveBucketCount);
    }

    /**
     * Asserts that the given positions of {@code page} map to exactly {@code expectedBucketCount}
     * distinct buckets under {@code bucketFunction}.
     */
    private static void assertBucketCount(BucketFunction bucketFunction, Page page, Collection<Integer> positions, int expectedBucketCount) {
        long distinctBuckets = positions.stream()
                .mapToInt(position -> bucketFunction.getBucket(page, position))
                .distinct()
                .count();
        Assertions.assertThat(distinctBuckets).isEqualTo(expectedBucketCount);
    }

    /**
     * Returns the positions {@code 0 .. positionCount - 1} in ascending order.
     */
    private static List<Integer> allPositions(int positionCount) {
        return IntStream.range(0, positionCount).boxed().collect(ImmutableList.toImmutableList());
    }

    /**
     * Builds a BIGINT block of {@code positionCount} positions. The first {@code positionCount - 1}
     * positions hold consecutive values starting at 923402935, and the last position is null.
     *
     * @throws IllegalArgumentException if {@code positionCount < 1}; the trailing null needs one slot
     */
    private static Block createLongSequenceBlockWithNull(int positionCount) {
        if (positionCount < 1) {
            throw new IllegalArgumentException("positionCount must be at least 1, got " + positionCount);
        }
        BlockBuilder builder = BigintType.BIGINT.createFixedSizeBlockBuilder(positionCount);
        // Use long arithmetic so the sequence cannot overflow int for large position counts.
        long start = 923402935L;
        for (int offset = 0; offset < positionCount - 1; offset++) {
            BigintType.BIGINT.writeLong(builder, start + offset);
        }
        builder.appendNull();
        return builder.build();
    }

    /**
     * Creates the function under test.
     *
     * @param bucketCount last constructor argument of {@link HivePartitionedBucketFunction}.
     *     NOTE(review): presumably the total number of output buckets; confirm against the constructor.
     */
    private static BucketFunction partitionedBucketFunction(
            HiveBucketing.BucketingVersion bucketingVersion,
            int hiveBucketCount,
            List<HiveType> hiveBucketTypes,
            List<Type> partitionColumnTypes,
            int bucketCount) {
        return new HivePartitionedBucketFunction(bucketingVersion, hiveBucketCount, hiveBucketTypes, partitionColumnTypes, new TypeOperators(), bucketCount);
    }

    /**
     * Creates the plain Hive bucket function used as the reference grouping.
     */
    private static BucketFunction bucketFunction(HiveBucketing.BucketingVersion bucketingVersion, int hiveBucketCount, List<HiveType> hiveBucketTypes) {
        return new HiveBucketFunction(bucketingVersion, hiveBucketCount, hiveBucketTypes);
    }
}
