package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Plan;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import org.assertj.core.api.Assertions;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPreAggregateCaseAggregations.class */
public class TestPreAggregateCaseAggregations extends BasePlanTest {
    private static final SchemaTableName TABLE = new SchemaTableName("default", "t");

    /**
     * Creates the query runner used by the plan tests in this class.
     *
     * <p>The session targets catalog {@code local} and schema {@code default} and sets the
     * {@code optimize_hash_generation}, {@code prefer_partial_aggregation} and
     * {@code task_concurrency} system properties. The {@code local} catalog is a mock connector
     * that returns a table handle for any requested name but only supplies column metadata for
     * {@link #TABLE}; asking for the columns of any other table throws
     * {@link IllegalArgumentException}.
     */
    @Override
    protected LocalQueryRunner createLocalQueryRunner() {
        LocalQueryRunner queryRunner = LocalQueryRunner.create(
                TestingSession.testSessionBuilder()
                        .setCatalog("local")
                        .setSchema("default")
                        .setSystemProperty("optimize_hash_generation", "false")
                        .setSystemProperty("prefer_partial_aggregation", "false")
                        .setSystemProperty("task_concurrency", "1")
                        .build());
        queryRunner.createCatalog(
                "local",
                MockConnectorFactory.builder()
                        .withGetTableHandle((session, tableName) -> new MockConnectorTableHandle(tableName))
                        .withGetColumns(tableName -> {
                            // Only the single test table has a column definition.
                            if (!tableName.equals(TABLE)) {
                                throw new IllegalArgumentException();
                            }
                            return ImmutableList.of(
                                    new ColumnMetadata("col_varchar", VarcharType.VARCHAR),
                                    new ColumnMetadata("col_bigint", BigintType.BIGINT),
                                    new ColumnMetadata("col_tinyint", TinyintType.TINYINT),
                                    new ColumnMetadata("col_decimal", DecimalType.createDecimalType(2, 1)),
                                    new ColumnMetadata("col_long_decimal", DecimalType.createDecimalType(19, 18)),
                                    new ColumnMetadata("col_double", DoubleType.DOUBLE));
                        })
                        .build(),
                ImmutableMap.of());
        return queryRunner;
    }

    @Test
    public void testPreAggregatesCaseAggregations() {
        assertPlan("SELECT (col_varchar || 'a'), sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) FROM t GROUP BY (col_varchar || 'a')", PlanMatchPattern.anyTree(PlanMatchPattern.project(ImmutableMap.of("SUM_2_CAST", PlanMatchPattern.expression("CAST(SUM_2 AS VARCHAR(10))")), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("KEY"), ImmutableMap.builder().put(Optional.of("SUM_1"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_1_INPUT"))).put(Optional.of("SUM_2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_2_INPUT"))).put(Optional.of("SUM_3"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_3_INPUT"))).put(Optional.of("MIN_1"), PlanMatchPattern.functionCall("min", ImmutableList.of("MIN_1_INPUT"))).put(Optional.of("SUM_4"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_4_INPUT"))).put(Optional.of("SUM_5"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_5_INPUT"))).buildOrThrow(), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.project(ImmutableMap.builder().put("SUM_1_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_BIGINT ELSE BIGINT '0' END")).put("SUM_2_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_INT_CAST ELSE BIGINT '0' END")).put("SUM_3_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '2' THEN SUM_BIGINT END")).put("MIN_1_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT % BIGINT '2' > BIGINT '1' THEN MIN_BIGINT END")).put("SUM_4_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '3' THEN 
SUM_DECIMAL END")).put("SUM_5_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '4' THEN SUM_DECIMAL_CAST END")).buildOrThrow(), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("KEY", "COL_BIGINT"), ImmutableMap.of(Optional.of("SUM_BIGINT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_BIGINT")), Optional.of("SUM_INT_CAST"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_INT_CAST")), Optional.of("MIN_BIGINT"), PlanMatchPattern.functionCall("min", ImmutableList.of("VALUE_BIGINT")), Optional.of("SUM_DECIMAL"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_DECIMAL")), Optional.of("SUM_DECIMAL_CAST"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_DECIMAL_CAST"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.exchange(PlanMatchPattern.project(ImmutableMap.of("KEY", PlanMatchPattern.expression("CONCAT(COL_VARCHAR, VARCHAR 'a')"), "VALUE_BIGINT", PlanMatchPattern.expression("COL_BIGINT * BIGINT '2'"), "VALUE_INT_CAST", PlanMatchPattern.expression("CAST(CAST(COL_BIGINT * BIGINT '2' AS INTEGER) AS BIGINT)"), "VALUE_DECIMAL_CAST", PlanMatchPattern.expression("CAST(COL_DECIMAL * CAST(DECIMAL '2' AS DECIMAL(10, 0)) AS BIGINT)")), PlanMatchPattern.tableScan("t", ImmutableMap.of("COL_VARCHAR", "col_varchar", "COL_BIGINT", "col_bigint", "COL_DECIMAL", "col_decimal"))))))))));
    }

    /**
     * Verifies the plan shape for the same six CASE-based aggregations as the grouped variant,
     * but without a GROUP BY clause.
     *
     * <p>The expected plan has a global SINGLE-step aggregation on top of a projection with the
     * rewritten CASE expressions, fed by a SINGLE-step aggregation grouped only by
     * {@code COL_BIGINT}, an exchange, an input projection, and the scan of table {@code t}.
     */
    @Test
    public void testGlobalPreAggregatesCaseAggregations() {
        String sql = "SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint * 2 ELSE 0 END), CAST(sum(CASE WHEN col_bigint = 1 THEN CAST(col_bigint * 2 AS INTEGER) ELSE CAST(0 AS INTEGER) END) AS VARCHAR(10)), sum(CASE WHEN col_bigint = 2 THEN col_bigint * 2 ELSE null END), min(CASE WHEN col_bigint % 2 > 1.23 THEN col_bigint * 2 END), sum(CASE WHEN col_bigint = 3 THEN col_decimal END), sum(CAST(CASE WHEN col_bigint = 4 THEN col_decimal * 2 END AS BIGINT)) FROM t";
        PlanMatchPattern expectedPlan = PlanMatchPattern.anyTree(
                PlanMatchPattern.project(
                        ImmutableMap.of("SUM_2_CAST", PlanMatchPattern.expression("CAST(SUM_2 AS VARCHAR(10))")),
                        PlanMatchPattern.aggregation(
                                PlanMatchPattern.globalAggregation(),
                                ImmutableMap.builder()
                                        .put(Optional.of("SUM_1"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_1_INPUT")))
                                        .put(Optional.of("SUM_2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_2_INPUT")))
                                        .put(Optional.of("SUM_3"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_3_INPUT")))
                                        .put(Optional.of("MIN_1"), PlanMatchPattern.functionCall("min", ImmutableList.of("MIN_1_INPUT")))
                                        .put(Optional.of("SUM_4"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_4_INPUT")))
                                        .put(Optional.of("SUM_5"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_5_INPUT")))
                                        .buildOrThrow(),
                                Optional.empty(),
                                AggregationNode.Step.SINGLE,
                                PlanMatchPattern.project(
                                        ImmutableMap.builder()
                                                .put("SUM_1_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_BIGINT ELSE BIGINT '0' END"))
                                                .put("SUM_2_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_INT_CAST ELSE BIGINT '0' END"))
                                                .put("SUM_3_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '2' THEN SUM_BIGINT END"))
                                                .put("MIN_1_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT % BIGINT '2' > BIGINT '1' THEN MIN_BIGINT END"))
                                                .put("SUM_4_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '3' THEN SUM_DECIMAL END"))
                                                .put("SUM_5_INPUT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '4' THEN SUM_DECIMAL_CAST END"))
                                                .buildOrThrow(),
                                        PlanMatchPattern.aggregation(
                                                PlanMatchPattern.singleGroupingSet("COL_BIGINT"),
                                                ImmutableMap.of(
                                                        Optional.of("SUM_BIGINT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_BIGINT")),
                                                        Optional.of("SUM_INT_CAST"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_INT_CAST")),
                                                        Optional.of("MIN_BIGINT"), PlanMatchPattern.functionCall("min", ImmutableList.of("VALUE_BIGINT")),
                                                        Optional.of("SUM_DECIMAL"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_DECIMAL")),
                                                        Optional.of("SUM_DECIMAL_CAST"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_DECIMAL_CAST"))),
                                                Optional.empty(),
                                                AggregationNode.Step.SINGLE,
                                                PlanMatchPattern.exchange(
                                                        PlanMatchPattern.project(
                                                                ImmutableMap.of(
                                                                        "VALUE_BIGINT", PlanMatchPattern.expression("COL_BIGINT * BIGINT '2'"),
                                                                        "VALUE_INT_CAST", PlanMatchPattern.expression("CAST(CAST(COL_BIGINT * BIGINT '2' AS INTEGER) AS BIGINT)"),
                                                                        "VALUE_DECIMAL_CAST", PlanMatchPattern.expression("CAST(COL_DECIMAL * CAST(DECIMAL '2' AS DECIMAL(10, 0)) AS BIGINT)")),
                                                                PlanMatchPattern.tableScan(
                                                                        "t",
                                                                        ImmutableMap.of(
                                                                                "COL_BIGINT", "col_bigint",
                                                                                "COL_DECIMAL", "col_decimal")))))))));
        assertPlan(sql, expectedPlan);
    }

    @Test
    public void testPreAggregatesWithDefaultValues() {
        assertPlan("SELECT sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), sum(CASE WHEN col_bigint = 1 THEN col_bigint END), sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) ELSE CAST(0 AS INTEGER) END), sum(CASE WHEN col_bigint = 2 THEN CAST(col_bigint AS INTEGER) END), sum(CASE WHEN col_bigint = 3 THEN col_tinyint ELSE TINYINT '0' END), sum(CASE WHEN col_bigint = 3 THEN col_tinyint END), sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE CAST(0 AS DECIMAL(2, 1)) END), sum(CASE WHEN col_bigint = 4 THEN col_decimal END), sum(CASE WHEN col_bigint = 5 THEN col_long_decimal ELSE CAST(0 AS DECIMAL(19, 18)) END), sum(CASE WHEN col_bigint = 5 THEN col_long_decimal END), sum(CASE WHEN col_bigint = 6 THEN col_double ELSE DOUBLE '0' END), sum(CASE WHEN col_bigint = 6 THEN col_double END) FROM t", PlanMatchPattern.anyTree(PlanMatchPattern.aggregation(PlanMatchPattern.globalAggregation(), ImmutableMap.builder().put(Optional.of("SUM_1"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_BIGINT_FINAL"))).put(Optional.of("SUM_1_DEFAULT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_BIGINT_FINAL_DEFAULT"))).put(Optional.of("SUM_2"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_INT_CAST_FINAL"))).put(Optional.of("SUM_2_DEFAULT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_INT_CAST_FINAL_DEFAULT"))).put(Optional.of("SUM_3"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_TINYINT_FINAL"))).put(Optional.of("SUM_3_DEFAULT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_TINYINT_FINAL_DEFAULT"))).put(Optional.of("SUM_4"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_DECIMAL_FINAL"))).put(Optional.of("SUM_4_DEFAULT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_DECIMAL_FINAL_DEFAULT"))).put(Optional.of("SUM_5"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_LONG_DECIMAL_FINAL"))).put(Optional.of("SUM_5_DEFAULT"), 
PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_LONG_DECIMAL_FINAL_DEFAULT"))).put(Optional.of("SUM_6"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_DOUBLE_FINAL"))).put(Optional.of("SUM_6_DEFAULT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("SUM_DOUBLE_FINAL_DEFAULT"))).buildOrThrow(), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.project(ImmutableMap.builder().put("SUM_BIGINT_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_BIGINT END")).put("SUM_BIGINT_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '1' THEN SUM_BIGINT ELSE BIGINT '0' END")).put("SUM_INT_CAST_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '2' THEN SUM_INT_CAST END")).put("SUM_INT_CAST_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '2' THEN SUM_INT_CAST ELSE BIGINT '0' END")).put("SUM_TINYINT_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '3' THEN SUM_TINYINT END")).put("SUM_TINYINT_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '3' THEN SUM_TINYINT ELSE BIGINT '0' END")).put("SUM_DECIMAL_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '4' THEN SUM_DECIMAL END")).put("SUM_DECIMAL_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '4' THEN SUM_DECIMAL ELSE CAST(DECIMAL '0.0' AS decimal(38, 1)) END")).put("SUM_LONG_DECIMAL_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '5' THEN SUM_LONG_DECIMAL END")).put("SUM_LONG_DECIMAL_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '5' THEN SUM_LONG_DECIMAL ELSE CAST(DECIMAL '0.000000000000000000' AS decimal(38, 18)) END")).put("SUM_DOUBLE_FINAL", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '6' THEN SUM_DOUBLE END")).put("SUM_DOUBLE_FINAL_DEFAULT", PlanMatchPattern.expression("CASE WHEN COL_BIGINT = BIGINT '6' THEN SUM_DOUBLE ELSE 0E0 
END")).buildOrThrow(), PlanMatchPattern.aggregation(PlanMatchPattern.singleGroupingSet("COL_BIGINT"), ImmutableMap.of(Optional.of("SUM_BIGINT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_BIGINT")), Optional.of("SUM_INT_CAST"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_INT_CAST")), Optional.of("SUM_TINYINT"), PlanMatchPattern.functionCall("sum", ImmutableList.of("VALUE_TINYINT_CAST")), Optional.of("SUM_DECIMAL"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_DECIMAL")), Optional.of("SUM_LONG_DECIMAL"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_LONG_DECIMAL")), Optional.of("SUM_DOUBLE"), PlanMatchPattern.functionCall("sum", ImmutableList.of("COL_DOUBLE"))), Optional.empty(), AggregationNode.Step.SINGLE, PlanMatchPattern.exchange(PlanMatchPattern.project(ImmutableMap.of("VALUE_INT_CAST", PlanMatchPattern.expression("CAST(CAST(COL_BIGINT AS INTEGER) AS BIGINT)"), "VALUE_TINYINT_CAST", PlanMatchPattern.expression("CAST(COL_TINYINT AS BIGINT)")), PlanMatchPattern.tableScan("t", ImmutableMap.of("COL_BIGINT", "col_bigint", "COL_TINYINT", "col_tinyint", "COL_DECIMAL", "col_decimal", "COL_LONG_DECIMAL", "col_long_decimal", "COL_DOUBLE", "col_double")))))))));
    }

    /**
     * Five {@code sum(CASE ...)} aggregations grouped by {@code col_varchar}, each with an
     * explicit zero ELSE literal of the summed column's type; checked with
     * {@code assertFires}.
     */
    @Test
    public void testPreAggregatesSumAggregationsWithZeroDefault() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_bigint ELSE BIGINT '0' END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_tinyint ELSE TINYINT '0' END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_double ELSE DOUBLE '0' END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE DECIMAL '0.0' END), "
                + "sum(CASE WHEN col_bigint = 5 THEN col_long_decimal ELSE DECIMAL '0.000000000000000000' END) "
                + "FROM t GROUP BY col_varchar";
        assertFires(sql);
    }

    /**
     * Four {@code sum(CASE ...)} aggregations whose CASE conditions test {@code col_bigint},
     * which is also the GROUP BY column; checked with {@code assertFires}.
     */
    @Test
    public void testPreAggregatesWithoutNewExtraGroupingKeys() {
        final String sql = "SELECT col_bigint, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY col_bigint";
        assertFires(sql);
    }

    /**
     * Four {@code sum(CASE ...)} aggregations under {@code GROUP BY GROUPING SETS}; checked
     * with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireWithGroupingSets() {
        final String sql = "SELECT col_varchar, col_bigint, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY GROUPING SETS ((col_varchar), (col_bigint))";
        assertThatDoesNotFire(sql);
    }

    /**
     * Only three {@code sum(CASE ...)} aggregations grouped by {@code col_varchar}; checked
     * with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireWithoutEnoughAggregations() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Four {@code sum(CASE ...)} aggregations where three conditions test {@code col_bigint}
     * and the fourth tests {@code col_decimal}; checked with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireWithMultipleExtraGroupingKeys() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_decimal = DECIMAL '4.1' THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Four single-WHEN {@code sum(CASE ...)} aggregations plus a fifth whose CASE has two WHEN
     * clauses; checked with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireForSearchedCaseExpressionWithMultipleWithClauses() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 5 THEN col_decimal WHEN col_bigint = 6 THEN col_decimal * 2 END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Three {@code sum(CASE ...)} aggregations plus one {@code count(CASE ...)}; checked with
     * {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireForNonCumulativeAggregation() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "count(CASE WHEN col_bigint = 4 THEN col_decimal END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Four {@code sum(CASE ...)} aggregations, the last with {@code ELSE 1}; checked with
     * {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireForSumAggregationWithNonZeroDefaultValue() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 1 END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Three {@code sum(CASE ...)} aggregations plus a {@code min(CASE ...)} with
     * {@code ELSE 0}; checked with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireForMinAggregationWithNonNullDefaultValue() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "min(CASE WHEN col_bigint = 4 THEN col_decimal ELSE 0 END) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Four {@code sum(CASE ...)} aggregations plus a plain {@code sum(col_decimal)}; checked
     * with {@code assertThatDoesNotFire}.
     */
    @Test
    public void testDoesNotFireForNonCaseAggregation() {
        final String sql = "SELECT col_varchar, "
                + "sum(CASE WHEN col_bigint = 1 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 2 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 3 THEN col_decimal END), "
                + "sum(CASE WHEN col_bigint = 4 THEN col_decimal END), "
                + "sum(col_decimal) "
                + "FROM t GROUP BY col_varchar";
        assertThatDoesNotFire(sql);
    }

    /**
     * Plans the given query and asserts that the resulting plan contains exactly two
     * {@link AggregationNode}s.
     *
     * @param str the SQL query to plan
     */
    private void assertFires(@Language("SQL") String str) {
        Plan plan = plan(str);
        // The previous lambda referenced an undefined identifier; a method reference on the
        // class literal expresses the intended instance check directly.
        Assertions.assertThat(countOfMatchingNodes(plan, AggregationNode.class::isInstance)).isEqualTo(2);
    }

    /**
     * Plans the given query and asserts that the resulting plan contains exactly one
     * {@link AggregationNode}.
     *
     * @param str the SQL query to plan
     */
    private void assertThatDoesNotFire(@Language("SQL") String str) {
        Plan plan = plan(str);
        // The previous lambda referenced an undefined identifier; a method reference on the
        // class literal expresses the intended instance check directly.
        Assertions.assertThat(countOfMatchingNodes(plan, AggregationNode.class::isInstance)).isEqualTo(1);
    }

    /**
     * Counts the nodes that satisfy {@code predicate}, searching from the root of {@code plan}.
     *
     * @param plan the plan whose node tree is searched
     * @param predicate the condition a node must satisfy to be counted
     * @return the number of matching nodes reported by {@link PlanNodeSearcher}
     */
    private static int countOfMatchingNodes(Plan plan, Predicate<PlanNode> predicate) {
        PlanNode root = plan.getRoot();
        return PlanNodeSearcher.searchFrom(root)
                .where(predicate)
                .count();
    }
}
