package io.trino.operator.aggregation;

import io.trino.spi.Page;
import io.trino.spi.block.Block;
import io.trino.spi.block.ByteArrayBlock;
import io.trino.spi.block.IntArrayBlock;
import io.trino.spi.block.RunLengthEncodedBlock;
import io.trino.spi.block.ShortArrayBlock;
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Supplier;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Tests the two {@link AggregationMaskBuilder} implementations: the interpreted
 * {@link InterpretedAggregationMaskBuilder} and the builder generated by
 * {@link AggregationMaskCompiler}. Both are configured with channel 1 as the null-checked
 * argument channel, and every scenario runs against each implementation.
 */
public class TestAggregationMaskCompiler {
    /** Creates a fresh interpreted builder that null-checks channel 1. */
    private static final Supplier<AggregationMaskBuilder> INTERPRETED_MASK_BUILDER_SUPPLIER =
            () -> new InterpretedAggregationMaskBuilder(1);

    /** Creates a fresh compiled builder that null-checks channel 1. */
    private static final Supplier<AggregationMaskBuilder> COMPILED_MASK_BUILDER_SUPPLIER = () -> {
        try {
            return (AggregationMaskBuilder) AggregationMaskCompiler.generateAggregationMaskBuilder(new int[] {1})
                    .newInstance();
        } catch (ReflectiveOperationException e) {
            throw new RuntimeException(e);
        }
    };

    @Test
    public void testSupplier() {
        testSupplier(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testSupplier(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks instance isolation. Separate builders must return distinct masks and distinct
     * selected-position arrays with equal contents. A single builder must hand back the same
     * selected-position array on repeated calls.
     */
    private void testSupplier(Supplier<AggregationMaskBuilder> supplier) {
        Assertions.assertThat(supplier.get()).isNotSameAs(supplier.get());

        Page noNullsPage = buildSingleColumnPage(5);
        Assertions.assertThat(supplier.get().buildAggregationMask(noNullsPage, Optional.empty()))
                .isNotSameAs(supplier.get().buildAggregationMask(noNullsPage, Optional.empty()));

        Page partialNullsPage = buildSingleColumnPage(new boolean[] {false, true, false, true});
        Assertions.assertThat(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()))
                .isNotSameAs(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()));
        Assertions.assertThat(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions())
                .isNotSameAs(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions());
        Assertions.assertThat(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions())
                .isEqualTo(supplier.get().buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions());

        // Within one builder the selected-positions array is expected to be reused across calls
        AggregationMaskBuilder maskBuilder = supplier.get();
        Assertions.assertThat(maskBuilder.buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions())
                .isSameAs(maskBuilder.buildAggregationMask(partialNullsPage, Optional.empty()).getSelectedPositions());
    }

    @Test
    public void testUnsetNulls() {
        testUnsetNulls(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testUnsetNulls(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that rows null in channel 1 are removed from the mask when no mask block is supplied.
     * Covers flat blocks with and without a null array, and RLE blocks with null, non-null and
     * absent null flags.
     */
    private void testUnsetNulls(Supplier<AggregationMaskBuilder> supplier) {
        AggregationMaskBuilder maskBuilder = supplier.get();
        assertAggregationMaskAll(maskBuilder.buildAggregationMask(buildSingleColumnPage(0), Optional.empty()), 0);
        for (int positionCount = 7; positionCount < 10; positionCount++) {
            // An RLE null in channel 1 removes every row
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()),
                    positionCount);
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.empty()),
                    positionCount);

            // A null array with no nulls set keeps every row
            boolean[] isNull = new boolean[positionCount];
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(isNull), Optional.empty()),
                    positionCount);

            // Only the non-null rows survive
            Arrays.fill(isNull, true);
            isNull[1] = false;
            isNull[3] = false;
            isNull[5] = false;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(isNull), Optional.empty()),
                    positionCount, 1, 3, 5);
            isNull[3] = true;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(isNull), Optional.empty()),
                    positionCount, 1, 5);
            isNull[2] = false;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(isNull), Optional.empty()),
                    positionCount, 1, 2, 5);

            // RLE variants: no null array, explicit non-null, explicit null
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.empty()), Optional.empty()),
                    positionCount);
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(false)), Optional.empty()),
                    positionCount);
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPageRle(positionCount, Optional.of(true)), Optional.empty()),
                    positionCount);
        }
    }

    @Test
    public void testApplyMask() {
        testApplyMask(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testApplyMask(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that a non-null mask block keeps exactly the rows whose mask byte is 1. The page
     * itself has no nulls in channel 1.
     */
    private void testApplyMask(Supplier<AggregationMaskBuilder> supplier) {
        AggregationMaskBuilder maskBuilder = supplier.get();
        for (int positionCount = 7; positionCount < 10; positionCount++) {
            byte[] maskValues = new byte[positionCount];
            Arrays.fill(maskValues, (byte) 1);
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, maskValues))),
                    positionCount);

            Arrays.fill(maskValues, (byte) 0);
            maskValues[1] = 1;
            maskValues[3] = 1;
            maskValues[5] = 1;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, maskValues))),
                    positionCount, 1, 3, 5);
            maskValues[3] = 0;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, maskValues))),
                    positionCount, 1, 5);
            maskValues[2] = 1;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, maskValues))),
                    positionCount, 1, 2, 5);

            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 1))),
                    positionCount);
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockRle(positionCount, (byte) 0))),
                    positionCount);
        }
    }

    @Test
    public void testApplyMaskNulls() {
        testApplyMaskNulls(INTERPRETED_MASK_BUILDER_SUPPLIER);
        testApplyMaskNulls(COMPILED_MASK_BUILDER_SUPPLIER);
    }

    /**
     * Checks that a null entry in the mask block removes its row, even though every mask value
     * is 1. Covers both flat and RLE mask blocks.
     */
    private void testApplyMaskNulls(Supplier<AggregationMaskBuilder> supplier) {
        AggregationMaskBuilder maskBuilder = supplier.get();
        for (int positionCount = 7; positionCount < 10; positionCount++) {
            byte[] maskValues = new byte[positionCount];
            Arrays.fill(maskValues, (byte) 1);
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlock(positionCount, maskValues))),
                    positionCount);

            boolean[] maskIsNull = new boolean[positionCount];
            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(maskIsNull))),
                    positionCount);

            Arrays.fill(maskIsNull, true);
            maskIsNull[1] = false;
            maskIsNull[3] = false;
            maskIsNull[5] = false;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(maskIsNull))),
                    positionCount, 1, 3, 5);
            maskIsNull[3] = true;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(maskIsNull))),
                    positionCount, 1, 5);
            maskIsNull[1] = true;
            maskIsNull[5] = true;
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNulls(maskIsNull))),
                    positionCount);

            assertAggregationMaskAll(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, false))),
                    positionCount);
            assertAggregationMaskPositions(
                    maskBuilder.buildAggregationMask(buildSingleColumnPage(positionCount), Optional.of(createMaskBlockNullsRle(positionCount, true))),
                    positionCount);
        }
    }

    /** Creates a mask block without nulls that wraps {@code maskValues} for {@code positionCount} positions. */
    private static Block createMaskBlock(int positionCount, byte[] maskValues) {
        return new ByteArrayBlock(positionCount, Optional.empty(), maskValues);
    }

    /** Creates an RLE mask block that repeats {@code maskValue} for {@code positionCount} positions. */
    private static Block createMaskBlockRle(int positionCount, byte maskValue) {
        return RunLengthEncodedBlock.create(createMaskBlock(1, new byte[] {maskValue}), positionCount);
    }

    /**
     * Creates a mask block with one position per element of {@code isNull}. Every value is 1,
     * and a position is null wherever {@code isNull} is true.
     */
    private static Block createMaskBlockNulls(boolean[] isNull) {
        int positionCount = isNull.length;
        byte[] maskValues = new byte[positionCount];
        Arrays.fill(maskValues, (byte) 1);
        // Copy so later mutation of the caller's array (the tests reuse it) cannot alias this block
        return new ByteArrayBlock(positionCount, Optional.of(isNull.clone()), maskValues);
    }

    /** Creates an RLE mask block of value 1 that is null when {@code isNull} is true. */
    private static Block createMaskBlockNullsRle(int positionCount, boolean isNull) {
        return RunLengthEncodedBlock.create(createMaskBlockNulls(new boolean[] {isNull}), positionCount);
    }

    /**
     * Builds a two-column page of {@code positionCount} rows. Column 0 is entirely null. Column 1,
     * the channel the builders null-check, has no nulls.
     */
    private static Page buildSingleColumnPage(int positionCount) {
        boolean[] allNull = new boolean[positionCount];
        Arrays.fill(allNull, true);
        return new Page(
                new ShortArrayBlock(positionCount, Optional.of(allNull), new short[positionCount]),
                new IntArrayBlock(positionCount, Optional.empty(), new int[positionCount]));
    }

    /**
     * Builds a two-column page with one row per element of {@code isNull}. Column 0 is entirely
     * null. Column 1 is null wherever {@code isNull} is true.
     */
    private static Page buildSingleColumnPage(boolean[] isNull) {
        int positionCount = isNull.length;
        boolean[] allNull = new boolean[positionCount];
        Arrays.fill(allNull, true);
        return new Page(
                new ShortArrayBlock(positionCount, Optional.of(allNull), new short[positionCount]),
                // Copy so later mutation of the caller's array (the tests reuse it) cannot alias this block
                new IntArrayBlock(positionCount, Optional.of(isNull.clone()), new int[positionCount]));
    }

    /**
     * Builds a two-column page of {@code positionCount} rows. Column 0 is entirely null. Column 1
     * is an RLE block whose single value has no null flag when {@code isNull} is empty, and
     * otherwise is null exactly when {@code isNull} holds true.
     */
    private static Page buildSingleColumnPageRle(int positionCount, Optional<Boolean> isNull) {
        Optional<boolean[]> valueIsNull = isNull.map(value -> new boolean[] {value});
        boolean[] allNull = new boolean[positionCount];
        Arrays.fill(allNull, true);
        return new Page(
                new ShortArrayBlock(positionCount, Optional.of(allNull), new short[positionCount]),
                RunLengthEncodedBlock.create(new IntArrayBlock(1, valueIsNull, new int[1]), positionCount));
    }

    /**
     * Asserts that {@code aggregationMask} selects all {@code positionCount} rows. It must also
     * refuse to expose a selected-positions array in that state.
     */
    private static void assertAggregationMaskAll(AggregationMask aggregationMask, int positionCount) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isTrue();
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(positionCount == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(positionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(positionCount);
        Assertions.assertThatThrownBy(aggregationMask::getSelectedPositions).isInstanceOf(IllegalStateException.class);
    }

    /**
     * Asserts that {@code aggregationMask} covers {@code positionCount} rows and selects exactly
     * {@code expectedPositions}, in order. Passing no positions means nothing is selected.
     */
    private static void assertAggregationMaskPositions(AggregationMask aggregationMask, int positionCount, int... expectedPositions) {
        Assertions.assertThat(aggregationMask.isSelectAll()).isFalse();
        Assertions.assertThat(aggregationMask.isSelectNone()).isEqualTo(expectedPositions.length == 0);
        Assertions.assertThat(aggregationMask.getPositionCount()).isEqualTo(positionCount);
        Assertions.assertThat(aggregationMask.getSelectedPositionCount()).isEqualTo(expectedPositions.length);
        if (expectedPositions.length > 0) {
            // The returned array may be a reused buffer (see testSupplier) that is presumably
            // longer than the selected count, so only its prefix is compared
            Assertions.assertThat(aggregationMask.getSelectedPositions()).startsWith(expectedPositions);
        }
    }
}
