package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.GenericLiteral;
import io.trino.sql.tree.LogicalExpression;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/TestPushInequalityFilterExpressionBelowJoinRuleSet.class */
/**
 * Tests for {@link PushInequalityFilterExpressionBelowJoinRuleSet}.
 *
 * <p>The class exercises both rules of the set:
 * <ul>
 *   <li>the rule that rewrites an inner join's own filter;</li>
 *   <li>the rule that rewrites a filter sitting directly above an inner join.</li>
 * </ul>
 *
 * <p>Every scenario joins {@code VALUES (a)} (left) with {@code VALUES (b)} (right). The expected patterns
 * check two things:
 * <ul>
 *   <li>An arithmetic operand of an inequality that references only {@code b} is computed by a projection
 *   over the right source.</li>
 *   <li>Operands referencing {@code a} stay in the filter unchanged.</li>
 * </ul>
 */
public class TestPushInequalityFilterExpressionBelowJoinRuleSet extends BaseRuleTest {
    private PushInequalityFilterExpressionBelowJoinRuleSet ruleSet;

    public TestPushInequalityFilterExpressionBelowJoinRuleSet() {
        // No extra plugins are needed for these rules.
        super(new Plugin[0]);
    }

    /**
     * Builds the rule set under test once for the class, from the shared tester's metadata and type analyzer.
     *
     * <p>NOTE(review): a non-static {@code @BeforeAll} method requires a per-class test instance lifecycle.
     * This is presumably inherited from {@link BaseRuleTest}; confirm if the base class changes.
     */
    @BeforeAll
    public void setUpBeforeClass() {
        ruleSet = new PushInequalityFilterExpressionBelowJoinRuleSet(tester().getMetadata(), tester().getTypeAnalyzer());
    }

    @Test
    public void testExpressionNotPushedDownToLeftJoinSource() {
        // The arithmetic operand (a + 1) references only the left source, so no rewrite is expected.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(a, 1), b.toSymbolReference()));
                })
                .doesNotFire();
    }

    @Test
    public void testJoinFilterExpressionPushedDownToRightJoinSource() {
        // b + 1 < a: the right-only operand becomes a projected symbol on the right source.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()));
                })
                .matches(PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                    builder.filter("expr < a")
                            .left(PlanMatchPattern.values("a"))
                            .right(PlanMatchPattern.project(
                                    ImmutableMap.of("expr", PlanMatchPattern.expression("b + BIGINT '1'")),
                                    PlanMatchPattern.values("b")));
                }));
    }

    @Test
    public void testManyJoinFilterExpressionsPushedDownToRightJoinSource() {
        // Each conjunct gets its own projected symbol on the right source.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            LogicalExpression.and(
                                    comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                                    comparison(ComparisonExpression.Operator.GREATER_THAN, add(b, 10), a.toSymbolReference())));
                })
                .matches(PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                    builder.filter("expr_less < a and expr_greater > a")
                            .left(PlanMatchPattern.values("a"))
                            .right(PlanMatchPattern.project(
                                    ImmutableMap.of(
                                            "expr_less", PlanMatchPattern.expression("b + BIGINT '1'"),
                                            "expr_greater", PlanMatchPattern.expression("b + BIGINT '10'")),
                                    PlanMatchPattern.values("b")));
                }));
    }

    @Test
    public void testOnlyRightJoinFilterExpressionPushedDownToRightJoinSource() {
        // b + 1 < a + 2: only the right-side operand is projected; a + 2 stays in the join filter.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), add(a, 2)));
                })
                .matches(PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                    builder.filter("expr < a + BIGINT '2'")
                            .left(PlanMatchPattern.values("a"))
                            .right(PlanMatchPattern.project(
                                    ImmutableMap.of("expr", PlanMatchPattern.expression("b + BIGINT '1'")),
                                    PlanMatchPattern.values("b")));
                }));
    }

    @Test
    public void testParentFilterExpressionNotPushedDownToLeftJoinSource() {
        // A parent filter whose arithmetic operand references only the left source is left alone.
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.filter(
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(a, 1), b.toSymbolReference()),
                            p.join(JoinNode.Type.INNER, p.values(a), p.values(b)));
                })
                .doesNotFire();
    }

    @Test
    public void testParentFilterExpressionPushedDownToRightJoinSource() {
        // The rewritten filter stays above the join, and a projection is expected on top of it.
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.filter(
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                            p.join(JoinNode.Type.INNER, p.values(a), p.values(b)));
                })
                .matches(PlanMatchPattern.project(PlanMatchPattern.filter(
                        "expr < a",
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.left(PlanMatchPattern.values("a"))
                                    .right(PlanMatchPattern.project(
                                            ImmutableMap.of("expr", PlanMatchPattern.expression("b + BIGINT '1'")),
                                            PlanMatchPattern.values("b")));
                        }))));
    }

    @Test
    public void testManyParentFilterExpressionsPushedDownToRightJoinSource() {
        // Each parent conjunct gets its own projected symbol on the right source.
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.filter(
                            LogicalExpression.and(
                                    comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                                    comparison(ComparisonExpression.Operator.GREATER_THAN, add(b, 10), a.toSymbolReference())),
                            p.join(JoinNode.Type.INNER, p.values(a), p.values(b)));
                })
                .matches(PlanMatchPattern.project(PlanMatchPattern.filter(
                        "expr_less < a and expr_greater > a",
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.left(PlanMatchPattern.values("a"))
                                    .right(PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "expr_less", PlanMatchPattern.expression("b + BIGINT '1'"),
                                                    "expr_greater", PlanMatchPattern.expression("b + BIGINT '10'")),
                                            PlanMatchPattern.values("b")));
                        }))));
    }

    @Test
    public void testOnlyParentFilterExpressionExposedInaJoin() {
        // Both the join filter operand (b + 2) and the parent filter operand (b + 1) are projected on the
        // right. Only the parent's projected symbol is expected among the join outputs.
        tester().assertThat(ruleSet.pushParentInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.filter(
                            comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 1), a.toSymbolReference()),
                            p.join(
                                    JoinNode.Type.INNER,
                                    p.values(a),
                                    p.values(b),
                                    comparison(ComparisonExpression.Operator.LESS_THAN, add(b, 2), a.toSymbolReference())));
                })
                .matches(PlanMatchPattern.project(PlanMatchPattern.filter(
                        "parent_expression < a",
                        PlanMatchPattern.join(JoinNode.Type.INNER, builder -> {
                            builder.filter("join_expression < a")
                                    .left(PlanMatchPattern.values("a"))
                                    .right(PlanMatchPattern.project(
                                            ImmutableMap.of(
                                                    "join_expression", PlanMatchPattern.expression("b + BIGINT '2'"),
                                                    "parent_expression", PlanMatchPattern.expression("b + BIGINT '1'")),
                                            PlanMatchPattern.values("b")));
                        }).withExactOutputs("a", "b", "parent_expression"))));
    }

    @Test
    public void testNoExpression() {
        // a < b compares plain symbols, so there is no operand to project.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.LESS_THAN, a.toSymbolReference(), b.toSymbolReference()));
                })
                .doesNotFire();
    }

    @Test
    public void testNotSupportedExpression() {
        // IS DISTINCT FROM between plain symbols.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.IS_DISTINCT_FROM, a.toSymbolReference(), b.toSymbolReference()));
                })
                .doesNotFire();

        // The case above would not fire even for a supported operator, because plain symbols leave nothing
        // to push (see testNoExpression). This case gives the rule a right-only operand, so not firing
        // really depends on the operator being unsupported.
        tester().assertThat(ruleSet.pushJoinInequalityFilterExpressionBelowJoinRule())
                .on(p -> {
                    Symbol a = p.symbol("a");
                    Symbol b = p.symbol("b");
                    return p.join(
                            JoinNode.Type.INNER,
                            p.values(a),
                            p.values(b),
                            comparison(ComparisonExpression.Operator.IS_DISTINCT_FROM, add(b, 1), a.toSymbolReference()));
                })
                .doesNotFire();
    }

    /**
     * Shorthand for {@code new ComparisonExpression(operator, left, right)}.
     */
    private static ComparisonExpression comparison(ComparisonExpression.Operator operator, Expression left, Expression right) {
        return new ComparisonExpression(operator, left, right);
    }

    /**
     * Builds {@code symbol + BIGINT 'value'}. The literal is a {@code BIGINT} generic literal, which is why
     * the expected projections are written as {@code b + BIGINT '1'}.
     */
    private static ArithmeticBinaryExpression add(Symbol symbol, long value) {
        return new ArithmeticBinaryExpression(
                ArithmeticBinaryExpression.Operator.ADD,
                symbol.toSymbolReference(),
                new GenericLiteral("BIGINT", String.valueOf(value)));
    }
}
