package io.trino.sql.planner.optimizations;

import com.google.common.collect.ImmutableList;
import io.trino.Session;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.AbstractMockMetadata;
import io.trino.metadata.FunctionManager;
import io.trino.metadata.TableHandle;
import io.trino.metadata.TableMetadata;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorTableMetadata;
import io.trino.spi.connector.SchemaTableName;
import io.trino.spi.type.BigintType;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.testing.TestingSession;
import java.util.List;
import java.util.function.Function;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/**
 * Tests for {@link BeginTableWrite} applied to delete and update plans.
 *
 * <p>Valid plans (a {@code tableDelete}/{@code tableUpdate} directly over an updatable table
 * scan) must optimize without error. Plans whose source contains a non-updatable table scan
 * or an unsupported intermediate node ({@code DistinctLimitNode}) must be rejected with an
 * {@link IllegalArgumentException} carrying a specific message.
 */
public class TestBeginTableWrite {

    /**
     * Stub metadata used by these tests.
     *
     * <p>{@code beginDelete} and {@code beginUpdate} hand back the table handle unchanged, and
     * {@code getTableMetadata} reports every table as {@code sch.tab} with no columns.
     */
    private static class MockMetadata extends AbstractMockMetadata {
        private MockMetadata() {
        }

        @Override
        public TableHandle beginDelete(Session session, TableHandle tableHandle) {
            return tableHandle;
        }

        @Override
        public TableHandle beginUpdate(Session session, TableHandle tableHandle, List<ColumnHandle> list) {
            return tableHandle;
        }

        @Override
        public TableMetadata getTableMetadata(Session session, TableHandle tableHandle) {
            return new TableMetadata(
                    tableHandle.getCatalogName(),
                    new ConnectorTableMetadata(new SchemaTableName("sch", "tab"), ImmutableList.of()));
        }
    }

    /** A delete sourced directly from an updatable table scan is accepted. */
    @Test
    public void testValidDelete() {
        Assertions.assertThatCode(() -> applyOptimization(planBuilder -> planBuilder.tableDelete(
                new SchemaTableName("sch", "tab"),
                planBuilder.tableScan(ImmutableList.of(planBuilder.symbol("rowId")), true),
                planBuilder.symbol("rowId", BigintType.BIGINT))))
                .doesNotThrowAnyException();
    }

    /** An update sourced directly from an updatable table scan is accepted. */
    @Test
    public void testValidUpdate() {
        Assertions.assertThatCode(() -> applyOptimization(planBuilder -> planBuilder.tableUpdate(
                new SchemaTableName("sch", "tab"),
                planBuilder.tableScan(ImmutableList.of(planBuilder.symbol("columnToBeUpdated")), true),
                planBuilder.symbol("rowId", BigintType.BIGINT),
                ImmutableList.of(planBuilder.symbol("columnToBeUpdated")))))
                .doesNotThrowAnyException();
    }

    /**
     * A delete whose source joins in a scan created with {@code updateTarget = false} is rejected,
     * even though the other join side contains an updatable scan.
     */
    @Test
    public void testDeleteWithNonDeletableTableScan() {
        Assertions.assertThatThrownBy(() -> applyOptimization(planBuilder -> planBuilder.tableDelete(
                new SchemaTableName("sch", "tab"),
                planBuilder.join(
                        JoinNode.Type.INNER,
                        // Passed in invocation context so ImmutableList.of() infers List<Symbol>;
                        // casting it instead would infer ImmutableList<Object> and fail to compile.
                        planBuilder.tableScan(ImmutableList.of(), false),
                        planBuilder.limit(1L, planBuilder.tableScan(ImmutableList.of(planBuilder.symbol("rowId")), true))),
                planBuilder.symbol("rowId", BigintType.BIGINT))))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("TableScanNode should be an updatable target");
    }

    /** Same as the delete case above, but for an update node. */
    @Test
    public void testUpdateWithNonUpdatableTableScan() {
        Assertions.assertThatThrownBy(() -> applyOptimization(planBuilder -> planBuilder.tableUpdate(
                new SchemaTableName("sch", "tab"),
                planBuilder.join(
                        JoinNode.Type.INNER,
                        planBuilder.tableScan(ImmutableList.of(), false),
                        planBuilder.limit(1L, planBuilder.tableScan(
                                ImmutableList.of(planBuilder.symbol("columnToBeUpdated"), planBuilder.symbol("rowId")),
                                true))),
                planBuilder.symbol("rowId", BigintType.BIGINT),
                ImmutableList.of(planBuilder.symbol("columnToBeUpdated")))))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("TableScanNode should be an updatable target");
    }

    /** A {@code DistinctLimitNode} between the delete and its table scan is rejected. */
    @Test
    public void testDeleteWithInvalidNode() {
        Assertions.assertThatThrownBy(() -> applyOptimization(planBuilder -> planBuilder.tableDelete(
                new SchemaTableName("sch", "tab"),
                planBuilder.distinctLimit(
                        10L,
                        ImmutableList.of(planBuilder.symbol("rowId")),
                        planBuilder.tableScan(ImmutableList.of(planBuilder.symbol("a")), true)),
                planBuilder.symbol("rowId", BigintType.BIGINT))))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode");
    }

    /** A {@code DistinctLimitNode} between the update and its table scan is rejected. */
    @Test
    public void testUpdateWithInvalidNode() {
        Assertions.assertThatThrownBy(() -> applyOptimization(planBuilder -> planBuilder.tableUpdate(
                new SchemaTableName("sch", "tab"),
                planBuilder.distinctLimit(
                        10L,
                        ImmutableList.of(planBuilder.symbol("a"), planBuilder.symbol("rowId")),
                        planBuilder.tableScan(ImmutableList.of(planBuilder.symbol("a")), true)),
                planBuilder.symbol("rowId", BigintType.BIGINT),
                ImmutableList.of(planBuilder.symbol("columnToBeUpdated")))))
                .isInstanceOf(IllegalArgumentException.class)
                .hasMessage("Invalid descendant for DeleteNode or UpdateNode: io.trino.sql.planner.plan.DistinctLimitNode");
    }

    /**
     * Builds a plan with {@code planProvider} and runs {@link BeginTableWrite} over it.
     *
     * <p>Any exception thrown by the optimizer propagates to the caller, which is how the
     * negative tests observe rejection.
     *
     * @param planProvider creates the plan to optimize from a {@link PlanBuilder} backed by
     *     {@link MockMetadata}
     */
    private void applyOptimization(Function<PlanBuilder, PlanNode> planProvider) {
        MockMetadata metadata = new MockMetadata();
        // One session is shared by plan construction and optimization so both see the same context.
        Session session = TestingSession.testSessionBuilder().build();
        PlanNode plan = planProvider.apply(new PlanBuilder(new PlanNodeIdAllocator(), metadata, session));
        new BeginTableWrite(metadata, FunctionManager.createTestingFunctionManager())
                .optimize(
                        plan,
                        session,
                        TypeProvider.empty(),
                        new SymbolAllocator(),
                        new PlanNodeIdAllocator(),
                        WarningCollector.NOOP);
    }
}
