package io.trino.plugin.geospatial;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.spi.type.IntegerType;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.LongLiteral;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.RewriteSpatialPartitioningAggregation;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.iterative.rule.test.RuleBuilder;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.tree.QualifiedName;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/plugin/geospatial/TestRewriteSpatialPartitioningAggregation.class */
/**
 * Rule tests for {@link RewriteSpatialPartitioningAggregation}, executed through the
 * {@link BaseRuleTest} harness with {@link GeoPlugin} registered.
 */
public class TestRewriteSpatialPartitioningAggregation extends BaseRuleTest {
    public TestRewriteSpatialPartitioningAggregation() {
        super(new Plugin[]{new GeoPlugin()});
    }

    /**
     * A {@code spatial_partitioning} aggregation that is already called with two arguments
     * ({@code geometry} and {@code n}) must be left alone by the rule.
     */
    @Test
    public void testDoesNotFire() {
        assertRuleApplication()
                .on(p -> p.aggregation(agg -> agg
                        .globalGrouping()
                        .step(AggregationNode.Step.SINGLE)
                        .addAggregation(
                                p.symbol("sp"),
                                PlanBuilder.aggregation(
                                        "spatial_partitioning",
                                        ImmutableList.of(new SymbolReference("geometry"), new SymbolReference("n"))),
                                ImmutableList.of(GeometryType.GEOMETRY, IntegerType.INTEGER))
                        .source(p.values(new Symbol[]{p.symbol("geometry"), p.symbol("n")}))))
                .doesNotFire();
    }

    /**
     * A single-argument {@code spatial_partitioning} aggregation is expected to be rewritten
     * into the plan described by {@link #expectedRewrittenPlan()}.
     */
    @Test
    public void test() {
        // Input column is named "geometry".
        assertRuleApplication()
                .on(p -> p.aggregation(agg -> agg
                        .globalGrouping()
                        .step(AggregationNode.Step.SINGLE)
                        .addAggregation(
                                p.symbol("sp"),
                                PlanBuilder.aggregation(
                                        "spatial_partitioning",
                                        ImmutableList.of(new SymbolReference("geometry"))),
                                ImmutableList.of(GeometryType.GEOMETRY))
                        .source(p.values(new Symbol[]{p.symbol("geometry")}))))
                .matches(expectedRewrittenPlan());

        // Input column is named "envelope"; the expected pattern is the same one as above.
        // NOTE(review): presumably this guards against a clash between an existing "envelope"
        // column and the symbol the rule generates - confirm against the rule implementation.
        assertRuleApplication()
                .on(p -> p.aggregation(agg -> agg
                        .globalGrouping()
                        .step(AggregationNode.Step.SINGLE)
                        .addAggregation(
                                p.symbol("sp"),
                                PlanBuilder.aggregation(
                                        "spatial_partitioning",
                                        ImmutableList.of(new SymbolReference("envelope"))),
                                ImmutableList.of(GeometryType.GEOMETRY))
                        .source(p.values(new Symbol[]{p.symbol("envelope")}))))
                .matches(expectedRewrittenPlan());
    }

    /**
     * Builds the plan pattern both cases in {@link #test()} are matched against: an aggregation
     * {@code sp = spatial_partitioning(envelope, partition_count)} on top of a projection that
     * defines {@code partition_count} as the literal {@code 100} and {@code envelope} as
     * {@code st_envelope(geometry)}, on top of a values node with a single {@code geometry} column.
     *
     * @return a freshly constructed pattern on every call
     */
    private static PlanMatchPattern expectedRewrittenPlan() {
        return PlanMatchPattern.aggregation(
                ImmutableMap.of(
                        "sp",
                        PlanMatchPattern.aggregationFunction(
                                "spatial_partitioning",
                                ImmutableList.of("envelope", "partition_count"))),
                PlanMatchPattern.project(
                        ImmutableMap.of(
                                "partition_count",
                                PlanMatchPattern.expression(new LongLiteral(100L)),
                                "envelope",
                                PlanMatchPattern.expression(
                                        new FunctionCall(
                                                QualifiedName.of("st_envelope"),
                                                ImmutableList.of(new SymbolReference("geometry"))))),
                        PlanMatchPattern.values(new String[]{"geometry"})));
    }

    /**
     * Starts a rule assertion for a new {@link RewriteSpatialPartitioningAggregation} built from
     * the tester's planner context.
     */
    private RuleBuilder assertRuleApplication() {
        RewriteSpatialPartitioningAggregation rule =
                new RewriteSpatialPartitioningAggregation(tester().getPlannerContext());
        return tester().assertThat(rule);
    }
}
