package io.trino.execution;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import io.trino.spi.WarningCode;
import io.trino.spi.connector.StandardWarningCode;
import io.trino.testing.QueryRunner;
import io.trino.tests.tpch.TpchQueryRunnerBuilder;
import java.util.List;
import java.util.Set;
import org.assertj.core.api.Fail;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;

@Execution(ExecutionMode.CONCURRENT)
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
public class TestWarnings {
    /**
     * Value passed to the {@code query.stage-count-warning-threshold} property of the shared query runner.
     * The test below expects a query with more UNION branches than this to report TOO_MANY_STAGES.
     */
    private static final int STAGE_COUNT_WARNING_THRESHOLD = 20;

    /** Shared runner for all tests in this class (one instance per class via PER_CLASS lifecycle). */
    private QueryRunner queryRunner;

    /** Builds a TPC-H query runner configured with the stage-count warning threshold above. */
    @BeforeAll
    public void setUp() throws Exception {
        queryRunner = ((TpchQueryRunnerBuilder) TpchQueryRunnerBuilder.builder()
                .addExtraProperty("query.stage-count-warning-threshold", String.valueOf(STAGE_COUNT_WARNING_THRESHOLD)))
                .build();
    }

    /** Closes the shared runner and releases the reference. */
    @AfterAll
    public void tearDown() {
        // setUp may have failed before assigning the runner; don't mask that failure with an NPE here.
        if (queryRunner != null) {
            queryRunner.close();
            queryRunner = null;
        }
    }

    /**
     * Checks that a query made of {@code STAGE_COUNT_WARNING_THRESHOLD + 1} UNIONed SELECTs reports
     * TOO_MANY_STAGES, and that the single-SELECT query reports no warnings.
     */
    @Test
    public void testStageCountWarningThreshold() {
        StringBuilder queryBuilder = new StringBuilder("SELECT name FROM nation WHERE nationkey = 0");
        // Snapshot the single-SELECT query before any UNION branches are appended.
        String noWarningsQuery = queryBuilder.toString();
        for (int branch = 1; branch <= STAGE_COUNT_WARNING_THRESHOLD; branch++) {
            queryBuilder.append("  UNION")
                    .append("  SELECT name FROM nation WHERE nationkey = ")
                    .append(branch);
        }
        String manyStagesQuery = queryBuilder.toString();

        assertWarnings(queryRunner, manyStagesQuery, ImmutableList.of(StandardWarningCode.TOO_MANY_STAGES.toWarningCode()));
        assertWarnings(queryRunner, noWarningsQuery, ImmutableList.of());
    }

    /**
     * Executes {@code sql} and verifies the warnings it produced.
     *
     * <p>Each code in {@code expectedWarnings} must appear among the reported warning codes. An empty
     * {@code expectedWarnings} list means the query must report no warnings at all.
     *
     * @param queryRunner runner used to execute the query
     * @param sql the query to execute
     * @param expectedWarnings warning codes that must be reported; empty to require no warnings
     */
    private static void assertWarnings(QueryRunner queryRunner, @Language("SQL") String sql, List<WarningCode> expectedWarnings) {
        Set<Integer> actualCodes = queryRunner.execute(sql).getWarnings().stream()
                .map(warning -> warning.getWarningCode())
                .map(WarningCode::getCode)
                .collect(ImmutableSet.toImmutableSet());

        // Without this branch an empty expectation would make the check below a no-op.
        if (expectedWarnings.isEmpty()) {
            if (!actualCodes.isEmpty()) {
                Fail.fail("Expected no warnings, but got warning codes: " + actualCodes);
            }
            return;
        }

        for (WarningCode expected : expectedWarnings) {
            if (!actualCodes.contains(expected.getCode())) {
                Fail.fail("Expected warning: " + expected);
            }
        }
    }
}
