package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.connector.MockConnectorColumnHandle;
import io.trino.connector.MockConnectorFactory;
import io.trino.connector.MockConnectorTableHandle;
import io.trino.metadata.TableHandle;
import io.trino.spi.TrinoException;
import io.trino.spi.connector.CatalogSchemaTableName;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ColumnMetadata;
import io.trino.spi.connector.ConnectorTableHandle;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.connector.TableScanRedirectApplicationResult;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.VarcharType;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.RuleAssert;
import io.trino.sql.planner.iterative.rule.test.RuleTester;
import io.trino.testing.LocalQueryRunner;
import io.trino.testing.TestingSession;
import io.trino.testing.TestingTransactionHandle;
import io.trino.tests.BogusType;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/**
 * Rule tests for {@link ApplyTableScanRedirection}.
 *
 * <p>Every test runs against a mock connector (see {@link #createMockFactory}) that exposes two tables:
 * the redirection source {@code test_schema.test_table} (varchar columns {@code source_col_a},
 * {@code source_col_b}) and the redirection destination {@code target_schema.target_table}
 * ({@code destination_col_a} and {@code destination_col_b} varchar, {@code destination_col_c} bigint,
 * {@code destination_col_d} of the test-only bogus type). Whether and how a scan is redirected is controlled
 * by the optional {@link MockConnectorFactory.ApplyTableScanRedirect} callback installed on the connector.
 */
public class TestApplyTableScanRedirection {
    private static final String TEST_CATALOG = "test_catalog";
    private static final String TEST_SCHEMA = "test_schema";
    private static final String TEST_TABLE = "test_table";
    private static final SchemaTableName SOURCE_TABLE = new SchemaTableName(TEST_SCHEMA, TEST_TABLE);
    private static final Session MOCK_SESSION = TestingSession.testSessionBuilder().setCatalog(TEST_CATALOG).setSchema(TEST_SCHEMA).build();
    private static final String SOURCE_COLUMN_NAME_A = "source_col_a";
    private static final ColumnHandle SOURCE_COLUMN_HANDLE_A = new MockConnectorColumnHandle(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
    private static final String SOURCE_COLUMN_NAME_B = "source_col_b";
    private static final ColumnHandle SOURCE_COLUMN_HANDLE_B = new MockConnectorColumnHandle(SOURCE_COLUMN_NAME_B, VarcharType.VARCHAR);
    private static final SchemaTableName DESTINATION_TABLE = new SchemaTableName("target_schema", "target_table");
    private static final String DESTINATION_COLUMN_NAME_A = "destination_col_a";
    private static final ColumnHandle DESTINATION_COLUMN_HANDLE_A = new MockConnectorColumnHandle(DESTINATION_COLUMN_NAME_A, VarcharType.VARCHAR);
    private static final String DESTINATION_COLUMN_NAME_B = "destination_col_b";
    private static final ColumnHandle DESTINATION_COLUMN_HANDLE_B = new MockConnectorColumnHandle(DESTINATION_COLUMN_NAME_B, VarcharType.VARCHAR);
    // bigint destination column: redirecting a varchar source column here requires a CAST back to varchar
    private static final String DESTINATION_COLUMN_NAME_C = "destination_col_c";
    private static final ColumnHandle DESTINATION_COLUMN_HANDLE_C = new MockConnectorColumnHandle(DESTINATION_COLUMN_NAME_C, BigintType.BIGINT);
    // bogus-typed destination column: no cast to varchar exists, so redirecting to it must fail planning
    private static final String DESTINATION_COLUMN_NAME_D = "destination_col_d";

    /** Wraps a connector table handle in an engine {@link TableHandle} bound to the tester's current catalog. */
    private static TableHandle createTableHandle(RuleTester ruleTester, ConnectorTableHandle connectorTableHandle) {
        return new TableHandle(ruleTester.getCurrentCatalogHandle(), connectorTableHandle, TestingTransactionHandle.create());
    }

    /** Creates a rule tester whose default catalog is the mock connector with no scan-redirect callback installed. */
    private static RuleTester createRuleTester() {
        return RuleTester.builder()
                .withDefaultCatalogConnectorFactory(createMockFactory(Optional.empty()))
                .build();
    }

    /**
     * Creates a rule tester whose default catalog is the mock connector with a redirect callback
     * built by {@link #getMockApplyRedirect} from {@code redirectionMapping}.
     */
    private static RuleTester createRuleTester(Map<ColumnHandle, String> redirectionMapping) {
        return RuleTester.builder()
                .withDefaultCatalogConnectorFactory(createMockFactory(Optional.of(getMockApplyRedirect(redirectionMapping))))
                .build();
    }

    /** Without a redirect callback on the connector, the rule must leave the scan alone. */
    @Test
    public void testDoesNotFire() {
        try (RuleTester ruleTester = createRuleTester()) {
            ruleTester.assertThat(new ApplyTableScanRedirection(ruleTester.getPlannerContext()))
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE)),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A));
                    })
                    .doesNotFire();
        }
    }

    /**
     * Even with a redirect callback installed, a scan built with the trailing {@code true} flag
     * (the delete-target scan this test is named for) must not be redirected.
     */
    @Test
    public void testDoesNotFireForDeleteTableScan() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_A))) {
            ruleTester.assertThat(new ApplyTableScanRedirection(ruleTester.getPlannerContext()))
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE)),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A),
                                true);
                    })
                    .doesNotFire();
        }
    }

    /** The rule only matches table scans; a plain VALUES node must not trigger it. */
    @Test
    public void doesNotFireIfNoTableScan() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_A))) {
            ruleTester.assertThat(new ApplyTableScanRedirection(ruleTester.getPlannerContext()))
                    .withSession(MOCK_SESSION)
                    .on(p -> p.values(p.symbol("a", BigintType.BIGINT)))
                    .doesNotFire();
        }
    }

    /**
     * Redirecting a varchar source column to a bigint destination column yields a scan of the
     * destination column wrapped in a projection that casts it back to varchar.
     */
    @Test
    public void testMismatchedTypesWithCoercion() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_C))) {
            ruleTester.assertThat(new ApplyTableScanRedirection(ruleTester.getPlannerContext()))
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE)),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A));
                    })
                    .matches(
                            PlanMatchPattern.project(
                                    ImmutableMap.of("COL", PlanMatchPattern.expression("CAST(DEST_COL AS VARCHAR)")),
                                    PlanMatchPattern.tableScan(
                                            new MockConnectorTableHandle(DESTINATION_TABLE)::equals,
                                            TupleDomain.all(),
                                            ImmutableMap.of("DEST_COL", DESTINATION_COLUMN_HANDLE_C::equals))));
        }
    }

    /** Redirecting to a column whose type cannot be cast to the source type must fail planning with a TrinoException. */
    @Test
    public void testMismatchedTypesWithMissingCoercion() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_D))) {
            LocalQueryRunner queryRunner = ruleTester.getQueryRunner();
            queryRunner.inTransaction(MOCK_SESSION, transactionSession -> {
                Assertions.assertThatThrownBy(() -> queryRunner.createPlan(transactionSession, "SELECT source_col_a FROM test_table"))
                        .isInstanceOf(TrinoException.class)
                        .hasMessageMatching("Cast not possible from redirected column test_catalog.target_schema.target_table.destination_col_d with type Bogus to source column .*test_catalog.test_schema.test_table.*source_col_a.* with type: varchar");
                return null;
            });
        }
    }

    /** A same-typed redirect replaces the source scan with a scan of the mapped destination column. */
    @Test
    public void testApplyTableScanRedirection() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_A))) {
            ruleTester.assertThat(new ApplyTableScanRedirection(ruleTester.getPlannerContext()))
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE)),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A));
                    })
                    .matches(
                            PlanMatchPattern.tableScan(
                                    new MockConnectorTableHandle(DESTINATION_TABLE)::equals,
                                    TupleDomain.all(),
                                    ImmutableMap.of("DEST_COL", DESTINATION_COLUMN_HANDLE_A::equals)));
        }
    }

    /**
     * A constraint on the source scan is re-applied as a filter over the redirected scan. This holds both
     * when the constrained column is projected (first case) and when it is not (second case). In the second
     * case the destination scan also reads the constrained column, and a projection keeps only the requested one.
     */
    @Test
    public void testApplyTableScanRedirectionWithFilter() {
        try (RuleTester ruleTester = createRuleTester(ImmutableMap.of(
                SOURCE_COLUMN_HANDLE_A, DESTINATION_COLUMN_NAME_A,
                SOURCE_COLUMN_HANDLE_B, DESTINATION_COLUMN_NAME_B))) {
            ApplyTableScanRedirection applyTableScanRedirection = new ApplyTableScanRedirection(ruleTester.getPlannerContext());
            TupleDomain<ColumnHandle> constraint = TupleDomain.withColumnDomains(
                    ImmutableMap.of(SOURCE_COLUMN_HANDLE_A, Domain.singleValue(VarcharType.VARCHAR, Slices.utf8Slice("foo"))));

            // Constrained column is also the projected column.
            ruleTester.assertThat(applyTableScanRedirection)
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE, constraint, Optional.empty())),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_A),
                                constraint);
                    })
                    .matches(
                            PlanMatchPattern.filter(
                                    "DEST_COL = VARCHAR 'foo'",
                                    PlanMatchPattern.tableScan(
                                            new MockConnectorTableHandle(DESTINATION_TABLE)::equals,
                                            TupleDomain.all(),
                                            ImmutableMap.of("DEST_COL", DESTINATION_COLUMN_HANDLE_A::equals))));

            // Constrained column (A) differs from the projected column (B).
            ruleTester.assertThat(applyTableScanRedirection)
                    .withSession(MOCK_SESSION)
                    .on(p -> {
                        Symbol column = p.symbol(SOURCE_COLUMN_NAME_B, VarcharType.VARCHAR);
                        return p.tableScan(
                                createTableHandle(ruleTester, new MockConnectorTableHandle(SOURCE_TABLE, constraint, Optional.empty())),
                                ImmutableList.of(column),
                                ImmutableMap.of(column, SOURCE_COLUMN_HANDLE_B),
                                TupleDomain.all());
                    })
                    .matches(
                            PlanMatchPattern.project(
                                    ImmutableMap.of("expr", PlanMatchPattern.expression("DEST_COL_B")),
                                    PlanMatchPattern.filter(
                                            "DEST_COL_A = VARCHAR 'foo'",
                                            PlanMatchPattern.tableScan(
                                                    new MockConnectorTableHandle(DESTINATION_TABLE)::equals,
                                                    TupleDomain.all(),
                                                    ImmutableMap.of(
                                                            "DEST_COL_A", DESTINATION_COLUMN_HANDLE_A::equals,
                                                            "DEST_COL_B", DESTINATION_COLUMN_HANDLE_B::equals)))));
        }
    }

    /**
     * Builds a redirect callback that always redirects to {@code test_catalog.target_schema.target_table}.
     *
     * <p>Source columns are mapped to destination column names with {@code redirectionMapping}. The scanned
     * handle's constraint is translated into a constraint keyed by destination column name. The callback casts
     * the table handle to {@link MockConnectorTableHandle} and each constrained column to
     * {@link MockConnectorColumnHandle}, so any other handle type fails with {@link ClassCastException}.
     *
     * @param redirectionMapping source column handle to destination column name
     * @return a callback that always returns a present redirection result
     */
    private static MockConnectorFactory.ApplyTableScanRedirect getMockApplyRedirect(Map<ColumnHandle, String> redirectionMapping) {
        return (session, handle) -> {
            TupleDomain<String> destinationConstraint = ((MockConnectorTableHandle) handle).getConstraint()
                    .transformKeys(MockConnectorColumnHandle.class::cast)
                    .transformKeys(redirectionMapping::get);
            return Optional.of(new TableScanRedirectApplicationResult(
                    new CatalogSchemaTableName(TEST_CATALOG, DESTINATION_TABLE),
                    redirectionMapping,
                    destinationConstraint));
        };
    }

    /**
     * Builds the mock connector factory used by every test.
     *
     * @param applyTableScanRedirect redirect callback to install, if any
     * @return a factory whose {@code getColumns} knows only {@link #SOURCE_TABLE} and {@link #DESTINATION_TABLE};
     *         any other table fails with {@link IllegalArgumentException}
     */
    private static MockConnectorFactory createMockFactory(Optional<MockConnectorFactory.ApplyTableScanRedirect> applyTableScanRedirect) {
        MockConnectorFactory.Builder builder = MockConnectorFactory.builder()
                .withGetColumns(schemaTableName -> {
                    if (schemaTableName.equals(SOURCE_TABLE)) {
                        return ImmutableList.of(
                                new ColumnMetadata(SOURCE_COLUMN_NAME_A, VarcharType.VARCHAR),
                                new ColumnMetadata(SOURCE_COLUMN_NAME_B, VarcharType.VARCHAR));
                    }
                    if (schemaTableName.equals(DESTINATION_TABLE)) {
                        return ImmutableList.of(
                                new ColumnMetadata(DESTINATION_COLUMN_NAME_A, VarcharType.VARCHAR),
                                new ColumnMetadata(DESTINATION_COLUMN_NAME_B, VarcharType.VARCHAR),
                                new ColumnMetadata(DESTINATION_COLUMN_NAME_C, BigintType.BIGINT),
                                new ColumnMetadata(DESTINATION_COLUMN_NAME_D, BogusType.BOGUS));
                    }
                    throw new IllegalArgumentException("Unexpected table: " + schemaTableName);
                });
        // the builder's with* methods are relied on to mutate the builder in place (as the original code did)
        applyTableScanRedirect.ifPresent(builder::withApplyTableScanRedirect);
        return builder.build();
    }
}
