package io.trino.plugin.geospatial;

import com.google.common.base.Preconditions;
import com.google.common.base.Strings;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.DynamicSliceOutput;
import io.trino.Session;
import io.trino.block.BlockSerdeUtil;
import io.trino.geospatial.KdbTree;
import io.trino.geospatial.KdbTreeUtils;
import io.trino.geospatial.Rectangle;
import io.trino.plugin.memory.MemoryConnectorFactory;
import io.trino.plugin.tpch.TpchConnectorFactory;
import io.trino.spi.StandardErrorCode;
import io.trino.spi.TrinoException;
import io.trino.spi.block.Block;
import io.trino.spi.block.TestingBlockEncodingSerde;
import io.trino.spi.predicate.Utils;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureProvider;
import io.trino.sql.ir.Cast;
import io.trino.sql.ir.ComparisonExpression;
import io.trino.sql.ir.DoubleLiteral;
import io.trino.sql.ir.Expression;
import io.trino.sql.ir.FunctionCall;
import io.trino.sql.ir.GenericLiteral;
import io.trino.sql.ir.IsNullPredicate;
import io.trino.sql.ir.LogicalExpression;
import io.trino.sql.ir.NotExpression;
import io.trino.sql.ir.SearchedCaseExpression;
import io.trino.sql.ir.StringLiteral;
import io.trino.sql.ir.SymbolReference;
import io.trino.sql.ir.WhenClause;
import io.trino.sql.planner.assertions.BasePlanTest;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinType;
import io.trino.sql.tree.QualifiedName;
import io.trino.testing.PlanTester;
import io.trino.testing.TestingSession;
import java.util.Base64;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;

@TestInstance(TestInstance.Lifecycle.PER_CLASS)
/* loaded from: input_file:io/trino/plugin/geospatial/TestSpatialJoinPlanning.class */
public class TestSpatialJoinPlanning extends BasePlanTest {
    // JSON form (via KdbTreeUtils.toJson) of a KdbTree whose root is a single leaf created with
    // KdbTree.Node.newLeaf(new Rectangle(0, 0, 10, 10), 0). It is stored in the kdb_tree table and
    // is the value the distributed spatial-join patterns expect to see.
    private static final String KDB_TREE_JSON = KdbTreeUtils.toJson(new KdbTree(KdbTree.Node.newLeaf(new Rectangle(0.0d, 0.0d, 10.0d, 10.0d), 0)));
    // Expected first argument of every spatial_partitions(...) call in the distributed plans;
    // assigned once in setUp().
    private Expression kdbTreeLiteral;

    /**
     * Builds the plan tester used by every test in this class.
     *
     * <p>The default session points at catalog {@code memory}, schema {@code default}. The
     * {@link GeoPlugin} is installed, the {@code tpch} and {@code memory} catalogs are created,
     * and three fixture tables are populated: {@code kdb_tree} (one row holding
     * {@link #KDB_TREE_JSON}), {@code points} and {@code polygons}.
     *
     * @return the configured plan tester
     */
    protected PlanTester createPlanTester() {
        Session defaultSession = TestingSession.testSessionBuilder()
                .setCatalog("memory")
                .setSchema("default")
                .build();
        PlanTester tester = PlanTester.create(defaultSession);

        tester.installPlugin(new GeoPlugin());
        tester.createCatalog("tpch", new TpchConnectorFactory(1), ImmutableMap.of());
        tester.createCatalog("memory", new MemoryConnectorFactory(), ImmutableMap.of());

        // Fixture tables referenced by the queries under test.
        String createKdbTree = String.format("CREATE TABLE kdb_tree AS SELECT '%s' AS v", KDB_TREE_JSON);
        tester.executeStatement(createKdbTree);
        tester.executeStatement("CREATE TABLE points (lng, lat, name) AS (VALUES (2.1e0, 2.1e0, 'x'))");
        tester.executeStatement("CREATE TABLE polygons (wkt, name) AS (VALUES ('POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))', 'a'))");
        return tester;
    }

    /**
     * Prepares {@link #kdbTreeLiteral}: the KDB tree parsed from {@link #KDB_TREE_JSON} is
     * turned into a block, serialized, Base64-encoded, and wrapped as
     * {@code $literal$(from_base64('<encoded block>'))}.
     */
    @BeforeAll
    public void setUp() {
        Block kdbTreeBlock = Utils.nativeValueToBlock(KdbTreeType.KDB_TREE, KdbTreeUtils.fromJson(KDB_TREE_JSON));

        // Serialize the block and encode the bytes as text.
        DynamicSliceOutput serialized = new DynamicSliceOutput(0);
        BlockSerdeUtil.writeBlock(new TestingBlockEncodingSerde(), serialized, kdbTreeBlock);
        String encodedBlock = Base64.getEncoder().encodeToString(serialized.slice().getBytes());

        FunctionCall decodeCall = new FunctionCall(
                QualifiedName.of("from_base64"),
                ImmutableList.of(new StringLiteral(encodedBlock)));
        this.kdbTreeLiteral = new FunctionCall(QualifiedName.of("$literal$"), ImmutableList.of(decodeCall));
    }

    /**
     * Verifies the plans for an {@code ST_Contains} cross join between {@code points} and
     * {@code polygons}: two queries checked with {@code assertPlan}, then the first query
     * checked with {@code assertDistributedPlan} under a session that names {@code kdb_tree}
     * as the spatial partitioning table.
     */
    @Test
    public void testSpatialJoinContains() {
        String aliasedQuery = "SELECT b.name, a.name FROM points a, polygons b WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))";
        FunctionCall joinFilter = new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point")));
        FunctionCall pointCall = new FunctionCall(
                QualifiedName.of("st_point"),
                ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")));
        FunctionCall geometryCall = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt"), VarcharType.VARCHAR)));

        // Aliased tables, plan asserted with assertPlan.
        PlanMatchPattern aliasedPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern aliasedPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                aliasedQuery,
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        aliasedPoints,
                        PlanMatchPattern.anyTree(aliasedPolygons))));

        // Sub-queries that additionally project length(name) on each side.
        FunctionCall pointsNameLength = new FunctionCall(
                QualifiedName.of("length"),
                ImmutableList.of(new SymbolReference("name")));
        FunctionCall polygonsNameLength = new FunctionCall(
                QualifiedName.of("length"),
                ImmutableList.of(new SymbolReference("name_2")));
        PlanMatchPattern pointsWithLength = PlanMatchPattern.project(
                ImmutableMap.of(
                        "st_point", PlanMatchPattern.expression(pointCall),
                        "length", PlanMatchPattern.expression(pointsNameLength)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name")));
        PlanMatchPattern polygonsWithLength = PlanMatchPattern.project(
                ImmutableMap.of(
                        "st_geometryfromtext", PlanMatchPattern.expression(geometryCall),
                        "length_2", PlanMatchPattern.expression(polygonsNameLength)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name")));
        assertPlan(
                "SELECT * FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        pointsWithLength,
                        PlanMatchPattern.anyTree(polygonsWithLength))));

        // Distributed plan: each side computes spatial_partitions(...) and is unnested.
        FunctionCall pointPartitions = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("st_point")));
        FunctionCall geometryPartitions = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("st_geometryfromtext")));
        PlanMatchPattern partitionedPoints = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("partitions_a", PlanMatchPattern.expression(pointPartitions)),
                PlanMatchPattern.project(
                        ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                        PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name")))));
        PlanMatchPattern partitionedPolygons = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("partitions_b", PlanMatchPattern.expression(geometryPartitions)),
                PlanMatchPattern.project(
                        ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                        PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name")))));
        assertDistributedPlan(
                aliasedQuery,
                withSpatialPartitioning("kdb_tree"),
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        Optional.of(KDB_TREE_JSON),
                        Optional.empty(),
                        PlanMatchPattern.anyTree(partitionedPoints),
                        PlanMatchPattern.anyTree(partitionedPolygons))));
    }

    /**
     * Verifies the plans for an {@code ST_Within} cross join between {@code points} and
     * {@code polygons}: two queries checked with {@code assertPlan} and one checked with
     * {@code assertDistributedPlan} under a session that names {@code kdb_tree} as the
     * spatial partitioning table.
     */
    @Test
    public void testSpatialJoinWithin() {
        FunctionCall joinFilter = new FunctionCall(
                QualifiedName.of("st_within"),
                ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point")));
        FunctionCall pointCall = new FunctionCall(
                QualifiedName.of("st_point"),
                ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")));
        FunctionCall geometryCall = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt"), VarcharType.VARCHAR)));

        // Tables referenced by name, plan asserted with assertPlan.
        PlanMatchPattern namedPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern namedPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                "SELECT points.name, polygons.name FROM points, polygons WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        namedPoints,
                        PlanMatchPattern.anyTree(namedPolygons))));

        // Sub-queries that additionally project length(name) on each side.
        FunctionCall pointsNameLength = new FunctionCall(
                QualifiedName.of("length"),
                ImmutableList.of(new SymbolReference("name")));
        FunctionCall polygonsNameLength = new FunctionCall(
                QualifiedName.of("length"),
                ImmutableList.of(new SymbolReference("name_2")));
        PlanMatchPattern pointsWithLength = PlanMatchPattern.project(
                ImmutableMap.of(
                        "st_point", PlanMatchPattern.expression(pointCall),
                        "length", PlanMatchPattern.expression(pointsNameLength)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name", "name")));
        PlanMatchPattern polygonsWithLength = PlanMatchPattern.project(
                ImmutableMap.of(
                        "st_geometryfromtext", PlanMatchPattern.expression(geometryCall),
                        "length_2", PlanMatchPattern.expression(polygonsNameLength)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_2", "name")));
        assertPlan(
                "SELECT * FROM (SELECT length(name), * FROM points), (SELECT length(name), * FROM polygons) WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        pointsWithLength,
                        PlanMatchPattern.anyTree(polygonsWithLength))));

        // Distributed plan: each side computes spatial_partitions(...) and is unnested.
        FunctionCall pointPartitions = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("st_point")));
        FunctionCall geometryPartitions = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("st_geometryfromtext")));
        PlanMatchPattern partitionedPoints = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("partitions_a", PlanMatchPattern.expression(pointPartitions)),
                PlanMatchPattern.project(
                        ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                        PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")))));
        PlanMatchPattern partitionedPolygons = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("partitions_b", PlanMatchPattern.expression(geometryPartitions)),
                PlanMatchPattern.project(
                        ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                        PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")))));
        assertDistributedPlan(
                "SELECT b.name, a.name FROM points a, polygons b WHERE ST_Within(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                withSpatialPartitioning("kdb_tree"),
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        Optional.of(KDB_TREE_JSON),
                        Optional.empty(),
                        PlanMatchPattern.anyTree(partitionedPoints),
                        PlanMatchPattern.anyTree(partitionedPolygons))));
    }

    /**
     * Verifies that planning a spatial join fails with {@code INVALID_SPATIAL_PARTITIONING}
     * when the configured partitioning table is missing, empty, holds invalid JSON, has more
     * than one row, or has more than one column.
     */
    @Test
    public void testInvalidKdbTree() {
        String query = "SELECT b.name, a.name FROM points a, polygons b WHERE ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))";

        // Table does not exist.
        assertInvalidSpatialPartitioning(
                withSpatialPartitioning("non_existent_table"),
                query,
                "Table not found: memory.default.non_existent_table");

        // Table has no rows.
        getPlanTester().executeStatement("CREATE TABLE empty_table AS SELECT 'a' AS v WHERE false");
        assertInvalidSpatialPartitioning(
                withSpatialPartitioning("empty_table"),
                query,
                "Expected exactly one row for table memory.default.empty_table, but got none");

        // Single row that is not a valid KDB tree JSON document.
        getPlanTester().executeStatement("CREATE TABLE invalid_kdb_tree AS SELECT 'invalid-json' AS v");
        assertInvalidSpatialPartitioning(
                withSpatialPartitioning("invalid_kdb_tree"),
                query,
                "Invalid JSON string for KDB tree: .*");

        // More than one row.
        getPlanTester().executeStatement(String.format("CREATE TABLE too_many_rows AS SELECT * FROM (VALUES '%s', '%s') AS t(v)", KDB_TREE_JSON, KDB_TREE_JSON));
        assertInvalidSpatialPartitioning(
                withSpatialPartitioning("too_many_rows"),
                query,
                "Expected exactly one row for table memory.default.too_many_rows, but found 2 rows");

        // More than one column. The '%s' placeholder is now substituted with the KDB tree JSON
        // (it used to be stored verbatim), so the extra column is the only thing wrong with
        // this table.
        getPlanTester().executeStatement(String.format("CREATE TABLE too_many_columns AS SELECT '%s' as c1, 100 as c2", KDB_TREE_JSON));
        assertInvalidSpatialPartitioning(
                withSpatialPartitioning("too_many_columns"),
                query,
                "Expected single column for table memory.default.too_many_columns, but found 2 columns");
    }

    /**
     * Plans {@code str} inside a transaction under {@code session} and requires planning to
     * fail with a {@link TrinoException} whose error code is
     * {@code INVALID_SPATIAL_PARTITIONING} and whose message matches {@code str2}.
     *
     * @param session session to plan with
     * @param str SQL text to plan
     * @param str2 regular expression the whole exception message must match
     * @throws AssertionError if planning succeeds or the message does not match
     */
    private void assertInvalidSpatialPartitioning(Session session, String str, String str2) {
        PlanTester tester = getPlanTester();

        TrinoException failure = null;
        try {
            tester.inTransaction(session, transactionSession -> {
                return tester.createPlan(transactionSession, str);
            });
        } catch (TrinoException e) {
            failure = e;
        }

        if (failure == null) {
            throw new AssertionError(String.format("Expected query to fail: %s", str));
        }

        Assertions.assertThat(failure.getErrorCode()).isEqualTo(StandardErrorCode.INVALID_SPATIAL_PARTITIONING.toErrorCode());

        // A null message is compared as the empty string.
        String actualMessage = Strings.nullToEmpty(failure.getMessage());
        if (actualMessage.matches(str2)) {
            return;
        }
        throw new AssertionError(String.format("Expected exception message '%s' to match '%s' for query: %s", failure.getMessage(), str2, str), failure);
    }

    /**
     * Verifies the plans for an {@code ST_Intersects} self join of {@code polygons}: one
     * checked with {@code assertPlan} and one checked with {@code assertDistributedPlan}
     * under a session that names {@code default.kdb_tree} as the spatial partitioning table.
     */
    @Test
    public void testSpatialJoinIntersects() {
        String query = "SELECT b.name, a.name FROM polygons a, polygons b WHERE ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))";
        FunctionCall joinFilter = new FunctionCall(
                QualifiedName.of("st_intersects"),
                ImmutableList.of(new SymbolReference("geometry_a"), new SymbolReference("geometry_b")));
        FunctionCall geometryA = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VarcharType.VARCHAR)));
        FunctionCall geometryB = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VarcharType.VARCHAR)));

        // Plan asserted with assertPlan.
        PlanMatchPattern leftPolygons = PlanMatchPattern.project(
                ImmutableMap.of("geometry_a", PlanMatchPattern.expression(geometryA)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")));
        PlanMatchPattern rightPolygons = PlanMatchPattern.project(
                ImmutableMap.of("geometry_b", PlanMatchPattern.expression(geometryB)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name")));
        assertPlan(
                query,
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        leftPolygons,
                        PlanMatchPattern.anyTree(rightPolygons))));

        // Distributed plan with spatial partitioning.
        FunctionCall partitionsA = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("geometry_a")));
        FunctionCall partitionsB = new FunctionCall(
                QualifiedName.of("spatial_partitions"),
                ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("geometry_b")));
        PlanMatchPattern partitionedLeft = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("partitions_a", PlanMatchPattern.expression(partitionsA)),
                PlanMatchPattern.project(
                        ImmutableMap.of("geometry_a", PlanMatchPattern.expression(geometryA)),
                        PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name")))));
        // NOTE(review): unlike the left side, no unnest pattern is listed on the right; the
        // enclosing anyTree is presumably what lets this still match - confirm.
        PlanMatchPattern partitionedRight = PlanMatchPattern.project(
                ImmutableMap.of("partitions_b", PlanMatchPattern.expression(partitionsB)),
                PlanMatchPattern.project(
                        ImmutableMap.of("geometry_b", PlanMatchPattern.expression(geometryB)),
                        PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name"))));
        assertDistributedPlan(
                query,
                withSpatialPartitioning("default.kdb_tree"),
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        joinFilter,
                        Optional.of(KDB_TREE_JSON),
                        Optional.empty(),
                        PlanMatchPattern.anyTree(partitionedLeft),
                        PlanMatchPattern.anyTree(partitionedRight))));
    }

    /**
     * Verifies that a negated {@code ST_Contains} predicate is planned as a filter over an
     * inner join of the two table scans rather than as a spatial join.
     */
    @Test
    public void testNotContains() {
        FunctionCall geometryCall = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt"), VarcharType.VARCHAR)));
        FunctionCall pointCall = new FunctionCall(
                QualifiedName.of("st_point"),
                ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")));
        NotExpression notContains = new NotExpression(new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(geometryCall, pointCall)));

        PlanMatchPattern innerJoin = PlanMatchPattern.join(JoinType.INNER, builder -> {
            PlanMatchPattern pointsScan = PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"));
            PlanMatchPattern polygonsScan = PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"));
            builder.left(pointsScan).right(PlanMatchPattern.anyTree(polygonsScan));
        });

        assertPlan(
                "SELECT b.name, a.name FROM points a, polygons b WHERE NOT ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(PlanMatchPattern.filter(notContains, innerJoin)));
    }

    /**
     * Verifies that a negated {@code ST_Intersects} predicate over two single-row sub-queries
     * is planned as a filter over an inner join rather than as a spatial join.
     */
    @Test
    public void testNotIntersects() {
        String polygonOrNull = "IF(rand() >= 0, 'POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))')";
        // Plain concatenation. The text has no format specifiers, so it is no longer passed
        // through String.format (which also received an unused plan pattern argument).
        String query = "SELECT b.name, a.name FROM " + singleRow(polygonOrNull, "'a'") + " AS a (wkt, name), " + singleRow(polygonOrNull, "'a'") + " AS b (wkt, name)            WHERE NOT ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))";

        NotExpression notIntersects = new NotExpression(functionCall(
                "ST_Intersects",
                ImmutableList.of(GeometryType.GEOMETRY, GeometryType.GEOMETRY),
                ImmutableList.of(
                        functionCall("ST_GeometryFromText", ImmutableList.of(VarcharType.VARCHAR), ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VarcharType.VARCHAR))),
                        functionCall("ST_GeometryFromText", ImmutableList.of(VarcharType.VARCHAR), ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VarcharType.VARCHAR))))));

        // Expected form of the IF(...) projection: a searched CASE with one WHEN clause and no default.
        SearchedCaseExpression polygonCase = new SearchedCaseExpression(
                ImmutableList.of(new WhenClause(
                        new ComparisonExpression(
                                ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL,
                                new FunctionCall(QualifiedName.of("random"), ImmutableList.of()),
                                new DoubleLiteral(0.0d)),
                        new StringLiteral("POLYGON ((30 10, 40 40, 20 40, 10 20, 30 10))"))),
                Optional.empty());

        PlanMatchPattern innerJoin = PlanMatchPattern.join(JoinType.INNER, builder -> {
            PlanMatchPattern leftRow = PlanMatchPattern.project(
                    ImmutableMap.of(
                            "wkt_a", PlanMatchPattern.expression(polygonCase),
                            "name_a", PlanMatchPattern.expression(new StringLiteral("a"))),
                    singleRow());
            PlanMatchPattern rightRow = PlanMatchPattern.project(
                    ImmutableMap.of(
                            "wkt_b", PlanMatchPattern.expression(polygonCase),
                            "name_b", PlanMatchPattern.expression(new StringLiteral("a"))),
                    singleRow());
            builder.left(leftRow).right(PlanMatchPattern.any(rightRow));
        });

        assertPlan(query, PlanMatchPattern.anyTree(PlanMatchPattern.filter(notIntersects, innerJoin)));
    }

    /**
     * Verifies that {@code ST_Contains} combined with an equality on {@code name} is planned
     * as an inner join with equi-criteria {@code name_a = name_b} and the spatial predicate as
     * the join filter.
     */
    @Test
    public void testContainsWithEquiClause() {
        FunctionCall geometryCall = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt"), VarcharType.VARCHAR)));
        FunctionCall pointCall = new FunctionCall(
                QualifiedName.of("st_point"),
                ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")));
        FunctionCall containsFilter = new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(geometryCall, pointCall));

        PlanMatchPattern equiJoin = PlanMatchPattern.join(JoinType.INNER, builder -> {
            PlanMatchPattern pointsScan = PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name"));
            PlanMatchPattern polygonsScan = PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name"));
            builder.equiCriteria("name_a", "name_b")
                    .filter(containsFilter)
                    .left(PlanMatchPattern.anyTree(pointsScan))
                    .right(PlanMatchPattern.anyTree(polygonsScan));
        });

        assertPlan(
                "SELECT b.name, a.name FROM points a, polygons b WHERE a.name = b.name AND ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(equiJoin));
    }

    /**
     * Verifies that {@code ST_Intersects} combined with an equality on {@code name} is planned
     * as an inner join with equi-criteria {@code name_a = name_b} and the spatial predicate as
     * the join filter.
     */
    @Test
    public void testIntersectsWithEquiClause() {
        FunctionCall geometryA = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt_a"), VarcharType.VARCHAR)));
        FunctionCall geometryB = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt_b"), VarcharType.VARCHAR)));
        FunctionCall intersectsFilter = new FunctionCall(
                QualifiedName.of("st_intersects"),
                ImmutableList.of(geometryA, geometryB));

        PlanMatchPattern equiJoin = PlanMatchPattern.join(JoinType.INNER, builder -> {
            PlanMatchPattern leftScan = PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_a", "wkt", "name_a", "name"));
            PlanMatchPattern rightScan = PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt_b", "wkt", "name_b", "name"));
            builder.equiCriteria("name_a", "name_b")
                    .filter(intersectsFilter)
                    .left(PlanMatchPattern.anyTree(leftScan))
                    .right(PlanMatchPattern.anyTree(rightScan));
        });

        assertPlan(
                "SELECT b.name, a.name FROM polygons a, polygons b WHERE a.name = b.name AND ST_Intersects(ST_GeometryFromText(a.wkt), ST_GeometryFromText(b.wkt))",
                PlanMatchPattern.anyTree(equiJoin));
    }

    /**
     * Verifies the plans for {@code LEFT JOIN}s whose {@code ON} clause contains
     * {@code ST_Contains}: alone, combined with a {@code <>} comparison, combined with a
     * {@code rand()} comparison, and followed by a {@code WHERE} filter over both sides.
     */
    @Test
    public void testSpatialLeftJoins() {
        FunctionCall containsFilter = new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(new SymbolReference("st_geometryfromtext"), new SymbolReference("st_point")));
        FunctionCall pointCall = new FunctionCall(
                QualifiedName.of("st_point"),
                ImmutableList.of(new SymbolReference("lng"), new SymbolReference("lat")));
        FunctionCall geometryCall = new FunctionCall(
                QualifiedName.of("st_geometryfromtext"),
                ImmutableList.of(new Cast(new SymbolReference("wkt"), VarcharType.VARCHAR)));

        // Spatial predicate only.
        PlanMatchPattern plainPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern plainPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                "SELECT b.name, a.name FROM points a LEFT JOIN polygons b ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat))",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialLeftJoin(
                        containsFilter,
                        plainPoints,
                        PlanMatchPattern.anyTree(plainPolygons))));

        // Spatial predicate AND a comparison between columns of both sides.
        LogicalExpression containsAndNamesDiffer = new LogicalExpression(
                LogicalExpression.Operator.AND,
                ImmutableList.of(
                        containsFilter,
                        new ComparisonExpression(ComparisonExpression.Operator.NOT_EQUAL, new SymbolReference("name_a"), new SymbolReference("name_b"))));
        PlanMatchPattern namesDifferPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern namesDifferPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                "SELECT b.name, a.name FROM points a LEFT JOIN polygons b ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND a.name <> b.name",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialLeftJoin(
                        containsAndNamesDiffer,
                        namesDifferPoints,
                        PlanMatchPattern.anyTree(namesDifferPolygons))));

        // Spatial predicate AND a rand() comparison.
        LogicalExpression containsAndRandom = new LogicalExpression(
                LogicalExpression.Operator.AND,
                ImmutableList.of(
                        containsFilter,
                        new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, new FunctionCall(QualifiedName.of("random"), ImmutableList.of()), new DoubleLiteral(0.5d))));
        PlanMatchPattern randomPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern randomPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                "SELECT b.name, a.name FROM points a LEFT JOIN polygons b ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) AND rand() < 0.5",
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialLeftJoin(
                        containsAndRandom,
                        randomPoints,
                        PlanMatchPattern.anyTree(randomPolygons))));

        // WHERE filter over both sides, expected above the spatial left join.
        IsNullPredicate concatIsNull = new IsNullPredicate(new FunctionCall(
                QualifiedName.of("concat"),
                ImmutableList.of(
                        new Cast(new SymbolReference("name_a"), VarcharType.VARCHAR),
                        new Cast(new SymbolReference("name_b"), VarcharType.VARCHAR))));
        PlanMatchPattern filteredPoints = PlanMatchPattern.project(
                ImmutableMap.of("st_point", PlanMatchPattern.expression(pointCall)),
                PlanMatchPattern.tableScan("points", ImmutableMap.of("lng", "lng", "lat", "lat", "name_a", "name")));
        PlanMatchPattern filteredPolygons = PlanMatchPattern.project(
                ImmutableMap.of("st_geometryfromtext", PlanMatchPattern.expression(geometryCall)),
                PlanMatchPattern.tableScan("polygons", ImmutableMap.of("wkt", "wkt", "name_b", "name")));
        assertPlan(
                "SELECT b.name, a.name FROM points a LEFT JOIN polygons b    ON ST_Contains(ST_GeometryFromText(wkt), ST_Point(lng, lat)) WHERE concat(a.name, b.name) is null",
                PlanMatchPattern.anyTree(PlanMatchPattern.filter(
                        concatIsNull,
                        PlanMatchPattern.spatialLeftJoin(
                                containsFilter,
                                filteredPoints,
                                PlanMatchPattern.anyTree(filteredPolygons)))));
    }

    /**
     * Verifies the distributed plan for an {@code ST_Contains} join where one side is a
     * {@code UNION ALL} of {@code tpch.tiny.region} and {@code tpch.tiny.nation} and the other
     * is {@code tpch.tiny.customer}; the union is tried on the left and then on the right.
     * Both plans are asserted under a session naming {@code kdb_tree} as the partitioning table.
     */
    @Test
    public void testDistributedSpatialJoinOverUnion() {
        // Union on the left side of the join.
        PlanMatchPattern leftRegionBranch = PlanMatchPattern.project(
                ImmutableMap.of("p1", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g1"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g1", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a1"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("region", ImmutableMap.of("name_a1", "name"))));
        PlanMatchPattern leftNationBranch = PlanMatchPattern.project(
                ImmutableMap.of("p2", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g2"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g2", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a2"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("nation", ImmutableMap.of("name_a2", "name"))));
        PlanMatchPattern leftUnion = PlanMatchPattern.unnest(PlanMatchPattern.exchange(
                ExchangeNode.Scope.LOCAL,
                ExchangeNode.Type.REPARTITION,
                new PlanMatchPattern[]{leftRegionBranch, leftNationBranch}));
        PlanMatchPattern rightCustomer = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("p3", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g3"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g3", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("customer", ImmutableMap.of("name_b", "name")))));
        FunctionCall unionContainsCustomer = new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g3")));
        assertDistributedPlan(
                "SELECT a.name, b.name FROM (SELECT name FROM tpch.tiny.region UNION ALL SELECT name FROM tpch.tiny.nation) a, tpch.tiny.customer b WHERE ST_Contains(ST_GeometryFromText(a.name), ST_GeometryFromText(b.name))",
                withSpatialPartitioning("kdb_tree"),
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        unionContainsCustomer,
                        Optional.of(KDB_TREE_JSON),
                        Optional.empty(),
                        PlanMatchPattern.anyTree(leftUnion),
                        PlanMatchPattern.anyTree(rightCustomer))));

        // Union on the right side of the join.
        PlanMatchPattern leftCustomer = PlanMatchPattern.unnest(PlanMatchPattern.project(
                ImmutableMap.of("p1", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g1"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g1", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_a"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("customer", ImmutableMap.of("name_a", "name")))));
        PlanMatchPattern rightRegionBranch = PlanMatchPattern.project(
                ImmutableMap.of("p2", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g2"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g2", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b1"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("region", ImmutableMap.of("name_b1", "name"))));
        PlanMatchPattern rightNationBranch = PlanMatchPattern.project(
                ImmutableMap.of("p3", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("spatial_partitions"), ImmutableList.of(this.kdbTreeLiteral, new SymbolReference("g3"))))),
                PlanMatchPattern.project(
                        ImmutableMap.of("g3", PlanMatchPattern.expression(new FunctionCall(QualifiedName.of("st_geometryfromtext"), ImmutableList.of(new Cast(new SymbolReference("name_b2"), VarcharType.VARCHAR))))),
                        PlanMatchPattern.tableScan("nation", ImmutableMap.of("name_b2", "name"))));
        PlanMatchPattern rightUnion = PlanMatchPattern.unnest(PlanMatchPattern.exchange(
                ExchangeNode.Scope.LOCAL,
                ExchangeNode.Type.REPARTITION,
                new PlanMatchPattern[]{rightRegionBranch, rightNationBranch}));
        FunctionCall customerContainsUnion = new FunctionCall(
                QualifiedName.of("st_contains"),
                ImmutableList.of(new SymbolReference("g1"), new SymbolReference("g2")));
        assertDistributedPlan(
                "SELECT a.name, b.name FROM tpch.tiny.customer a, (SELECT name FROM tpch.tiny.region UNION ALL SELECT name FROM tpch.tiny.nation) b WHERE ST_Contains(ST_GeometryFromText(a.name), ST_GeometryFromText(b.name))",
                withSpatialPartitioning("kdb_tree"),
                PlanMatchPattern.anyTree(PlanMatchPattern.spatialJoin(
                        customerContainsUnion,
                        Optional.of(KDB_TREE_JSON),
                        Optional.empty(),
                        PlanMatchPattern.anyTree(leftCustomer),
                        PlanMatchPattern.anyTree(rightUnion))));
    }

    /**
     * Builds a parenthesized sub-query that selects the given expressions from
     * {@code tpch.tiny.region} restricted to {@code regionkey = 1}.
     *
     * @param strArr SQL expressions for the select list; they are joined with {@code ", "}
     * @return the sub-query text, including the surrounding parentheses
     */
    private String singleRow(String... strArr) {
        String selectList = String.join(", ", strArr);
        return String.format("(SELECT %s FROM tpch.tiny.region WHERE regionkey = 1)", selectList);
    }

    /**
     * Builds the plan pattern matching the single-row sub-query: a filter
     * {@code regionkey = BIGINT '1'} over a scan of the {@code region} table.
     *
     * @return a fresh pattern instance
     */
    private PlanMatchPattern singleRow() {
        ComparisonExpression regionKeyIsOne = new ComparisonExpression(
                ComparisonExpression.Operator.EQUAL,
                new SymbolReference("regionkey"),
                new GenericLiteral(BigintType.BIGINT, "1"));
        PlanMatchPattern regionScan = PlanMatchPattern.tableScan("region", ImmutableMap.of("regionkey", "regionkey"));
        return PlanMatchPattern.filter(regionKeyIsOne, regionScan);
    }

    /**
     * Derives a session from the plan tester's default session with the
     * {@code spatial_partitioning_table_name} system property set.
     *
     * @param str value for the {@code spatial_partitioning_table_name} property
     * @return the derived session
     */
    private Session withSpatialPartitioning(String str) {
        Session defaults = getPlanTester().getDefaultSession();
        return Session.builder(defaults)
                .setSystemProperty("spatial_partitioning_table_name", str)
                .build();
    }

    /**
     * Formats a finite double in scientific notation with 16 fractional digits, for example
     * {@code 1.0000000000000000E+00}.
     *
     * <p>Formatting uses {@link Locale#ROOT} so the decimal separator is always {@code '.'},
     * whatever the JVM's default locale is.
     *
     * @param d the value to format; must be finite
     * @return the formatted text
     * @throws IllegalArgumentException if {@code d} is NaN or infinite
     */
    private static String doubleLiteral(double d) {
        if (!Double.isFinite(d)) {
            throw new IllegalArgumentException("Expected a finite value but got: " + d);
        }
        return String.format(Locale.ROOT, "%.16E", d);
    }

    /**
     * Creates a {@link FunctionCall} whose name is taken from resolving the built-in function
     * {@code str} for the given argument types through the plan tester's metadata.
     *
     * @param str function name to resolve
     * @param list argument types used for resolution
     * @param list2 argument expressions of the resulting call
     * @return the function call expression
     */
    private FunctionCall functionCall(String str, List<Type> list, List<Expression> list2) {
        QualifiedName resolvedName = getPlanTester()
                .getPlannerContext()
                .getMetadata()
                .resolveBuiltinFunction(str, TypeSignatureProvider.fromTypes(list))
                .toQualifiedName();
        return new FunctionCall(resolvedName, list2);
    }
}
