package io.trino.sql.planner.iterative.rule.test;

import com.google.common.collect.ImmutableList;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.plugin.tpch.TpchTableHandle;
import io.trino.spi.connector.TestingColumnHandle;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.Expression;
import io.trino.testing.TestingTransactionHandle;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import org.assertj.core.api.Assertions;
import org.testng.annotations.Test;

/**
 * Tests for the assertion messages {@link RuleTester} reports when a rule under test either
 * produces a plan that does not match the expected pattern, or does not fire at all.
 */
public class TestRuleTester {

    /**
     * Functional form of {@code Rule.apply}, so that the tests below can declare a rule's body as
     * a lambda and wrap it with {@code rule(name, pattern, implementation)}.
     *
     * @param <T> type of the object matched by the rule's pattern
     */
    @FunctionalInterface
    public interface RuleApplyImplementation<T> {
        Rule.Result apply(T t, Captures captures, Rule.Context context);
    }

    @Test
    public void testReportWrongMatch() {
        // try-with-resources closes the tester exactly once; a failure in close() is either
        // thrown or attached as suppressed to the test's own failure, never retried.
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            // The rule fires but returns the node with its own children (project over values),
            // so the result cannot match the values-only pattern expected below.
            RuleAssert ruleAssert = tester.assertThat(rule(
                            "testReportWrongMatch rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.ofPlanNode(node.replaceChildren(node.getSources()))))
                    .on(p -> {
                        // Typed locals give the generic factories a target type; casting the
                        // factory results instead does not type-check for the nested lists.
                        Assignments assignments = Assignments.of(p.symbol("y"), PlanBuilder.expression("x"));
                        List<Symbol> columns = ImmutableList.of(p.symbol("x"));
                        List<List<Expression>> rows = ImmutableList.of(ImmutableList.of(PlanBuilder.expression("1")));
                        return p.project(assignments, p.values(columns, rows));
                    });

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("different"), List.of());
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("(?s)Plan does not match, expected .* but found .*");
        }
    }

    @Test
    public void testReportNoFire() {
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            // The rule never produces a result, so the tester must report that it did not fire,
            // prefixed by the rule's toString().
            RuleAssert ruleAssert = tester.assertThat(rule(
                            "testReportNoFire rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.empty()))
                    .on(p -> p.values(List.of(p.symbol("x")), List.of(List.of(PlanBuilder.expression("1")))));

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("whatever"), List.of());
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("testReportNoFire rule did not fire for:(?s:.*)");
        }
    }

    @Test
    public void testReportNoFireWithTableScan() {
        try (RuleTester tester = RuleTester.defaultRuleTester()) {
            RuleAssert ruleAssert = tester.assertThat(rule(
                            "testReportNoFireWithTableScan rule",
                            Pattern.typeOf(PlanNode.class),
                            (node, captures, context) -> Rule.Result.empty()))
                    .on(p -> {
                        TableHandle nation = new TableHandle(
                                tester.getCurrentConnectorId(),
                                new TpchTableHandle("sf1", "nation", 1.0d),
                                TestingTransactionHandle.create());
                        Symbol x = p.symbol("x");
                        return p.tableScan(nation, List.of(x), Map.of(x, new TestingColumnHandle("column")));
                    });

            PlanMatchPattern expected = PlanMatchPattern.values(List.of("whatever"), List.of());
            // Besides the "did not fire" header, the report must include the scan's estimates line.
            Assertions.assertThatThrownBy(() -> ruleAssert.matches(expected))
                    .isInstanceOf(AssertionError.class)
                    .hasMessageMatching("testReportNoFireWithTableScan rule did not fire for:\n(?s:.*)\\QEstimates: {rows: 25 (225B), cpu: 225, memory: 0B, network: 0B}\\E\n(?s:.*)");
        }
    }

    /**
     * Builds an ad-hoc {@link Rule} for these tests.
     *
     * @param name returned by the rule's {@code toString()}; the tests above expect it to prefix
     *     the tester's "did not fire" message
     * @param pattern the pattern the rule matches
     * @param implementation invoked for every match, supplying the rule's result
     * @param <T> type of the object matched by {@code pattern}
     * @return a rule delegating {@code getPattern}, {@code apply} and {@code toString}
     * @throws NullPointerException if any argument is null
     */
    private static <T> Rule<T> rule(String name, Pattern<T> pattern, RuleApplyImplementation<T> implementation) {
        Objects.requireNonNull(name, "name is null");
        Objects.requireNonNull(pattern, "pattern is null");
        Objects.requireNonNull(implementation, "apply is null");
        return new Rule<T>() {
            @Override
            public Pattern<T> getPattern() {
                return pattern;
            }

            @Override
            public Rule.Result apply(T node, Captures captures, Rule.Context context) {
                return implementation.apply(node, captures, context);
            }

            @Override
            public String toString() {
                return name;
            }
        };
    }
}
