package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.spi.Plugin;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.GroupReference;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.optimizations.joins.JoinGraph;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.ValuesNode;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.SymbolReference;
import io.trino.testing.TestingSession;
import java.util.Arrays;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests for {@link EliminateCrossJoins}.
 *
 * <p>The rule-level tests run the rule through {@code tester()} with the
 * {@code join_reordering_strategy} session property set to {@code ELIMINATE_CROSS_JOINS}.
 * The remaining tests call {@link EliminateCrossJoins#getJoinOrder} and
 * {@link EliminateCrossJoins#isOriginalOrder} directly on hand-built plans.
 *
 * <p>NOTE(review): {@code singleThreaded} is presumably required because every plan-building
 * helper draws ids from the shared, mutable {@link #idAllocator} field — confirm.
 */
@Test(singleThreaded = true)
public class TestEliminateCrossJoins extends BaseRuleTest {
    private static final String JOIN_REORDERING_STRATEGY = "join_reordering_strategy";
    private static final String ELIMINATE_CROSS_JOINS = "ELIMINATE_CROSS_JOINS";

    // Shared by projectNode/joinNode/values so hand-built plans get distinct node ids.
    private final PlanNodeIdAllocator idAllocator;

    public TestEliminateCrossJoins() {
        // No additional plugins are needed for these tests.
        super(new Plugin[0]);
        this.idAllocator = new PlanNodeIdAllocator();
    }

    @Test
    public void testEliminateCrossJoin() {
        tester().assertThat(eliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.INNER))
                .matches(
                        PlanMatchPattern.join(
                                JoinNode.Type.INNER,
                                ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("cySymbol"), new Symbol("bySymbol"))),
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("axSymbol"), new Symbol("cxSymbol"))),
                                        PlanMatchPattern.any(),
                                        PlanMatchPattern.any()),
                                PlanMatchPattern.any()));
    }

    @Test
    public void testRetainOutgoingGroupReferences() {
        tester().assertThat(eliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.INNER))
                .matches(
                        PlanMatchPattern.node(
                                JoinNode.class,
                                PlanMatchPattern.node(
                                        JoinNode.class,
                                        PlanMatchPattern.node(GroupReference.class),
                                        PlanMatchPattern.node(GroupReference.class)),
                                PlanMatchPattern.node(GroupReference.class)));
    }

    @Test
    public void testDoNotReorderOuterJoin() {
        tester().assertThat(eliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(crossJoinAndJoin(JoinNode.Type.LEFT))
                .doesNotFire();
    }

    @Test
    public void testIsOriginalOrder() {
        Assert.assertTrue(EliminateCrossJoins.isOriginalOrder(ImmutableList.of(0, 1, 2, 3, 4)));
        Assert.assertFalse(EliminateCrossJoins.isOriginalOrder(ImmutableList.of(0, 2, 1, 3, 4)));
    }

    @Test
    public void testJoinOrder() {
        PlanNode plan = joinNode(
                joinNode(values("a"), values("b")),
                values("c"),
                "a", "c",
                "b", "c");
        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(buildJoinGraph(plan)), ImmutableList.of(0, 2, 1));
    }

    @Test
    public void testJoinOrderWithRealCrossJoin() {
        PlanNode leftSubtree = joinNode(
                joinNode(values("a"), values("b")),
                values("c"),
                "a", "c",
                "b", "c");
        PlanNode rightSubtree = joinNode(
                joinNode(values("x"), values("y")),
                values("z"),
                "x", "z",
                "y", "z");
        PlanNode plan = joinNode(leftSubtree, rightSubtree);
        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(buildJoinGraph(plan)), ImmutableList.of(0, 2, 1, 3, 5, 4));
    }

    @Test
    public void testJoinOrderWithMultipleEdgesBetweenNodes() {
        PlanNode plan = joinNode(
                joinNode(values("a"), values("b1", "b2")),
                values("c1", "c2"),
                "a", "c1",
                "b1", "c1",
                "b2", "c2");
        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(buildJoinGraph(plan)), ImmutableList.of(0, 2, 1));
    }

    @Test
    public void testDoesNotChangeOrderWithoutCrossJoin() {
        PlanNode plan = joinNode(
                joinNode(values("a"), values("b"), "a", "b"),
                values("c"),
                "b", "c");
        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(buildJoinGraph(plan)), ImmutableList.of(0, 1, 2));
    }

    @Test
    public void testDoNotReorderCrossJoins() {
        PlanNode plan = joinNode(
                joinNode(values("a"), values("b")),
                values("c"),
                "b", "c");
        Assert.assertEquals(EliminateCrossJoins.getJoinOrder(buildJoinGraph(plan)), ImmutableList.of(0, 1, 2));
    }

    @Test
    public void testEliminateCrossJoinWithNonIdentityProjections() {
        tester().assertThat(eliminateCrossJoins())
                .setSystemProperty(JOIN_REORDERING_STRATEGY, ELIMINATE_CROSS_JOINS)
                .on(planBuilder -> {
                    Symbol a1 = planBuilder.symbol("a1");
                    Symbol a2 = planBuilder.symbol("a2");
                    Symbol b = planBuilder.symbol("b");
                    Symbol c = planBuilder.symbol("c");
                    Symbol d = planBuilder.symbol("d");
                    Symbol e = planBuilder.symbol("e");
                    Symbol f = planBuilder.symbol("f");
                    return planBuilder.join(
                            JoinNode.Type.INNER,
                            planBuilder.project(
                                    Assignments.of(
                                            a2, new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("a1")),
                                            f, new SymbolReference("f")),
                                    planBuilder.join(
                                            JoinNode.Type.INNER,
                                            planBuilder.project(
                                                    Assignments.of(
                                                            a1, new SymbolReference("a1"),
                                                            f, new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("b"))),
                                                    // Innermost cross join (no equi-join criteria).
                                                    planBuilder.join(JoinNode.Type.INNER, planBuilder.values(a1), planBuilder.values(b))),
                                            planBuilder.values(e),
                                            new JoinNode.EquiJoinClause(a1, e))),
                            planBuilder.values(c, d),
                            new JoinNode.EquiJoinClause(a2, c),
                            new JoinNode.EquiJoinClause(f, d));
                })
                .matches(
                        PlanMatchPattern.node(
                                ProjectNode.class,
                                PlanMatchPattern.join(
                                        JoinNode.Type.INNER,
                                        ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("d"), new Symbol("f"))),
                                        PlanMatchPattern.join(
                                                JoinNode.Type.INNER,
                                                ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("a2"), new Symbol("c"))),
                                                PlanMatchPattern.join(
                                                        JoinNode.Type.INNER,
                                                        ImmutableList.of(aliases -> new JoinNode.EquiJoinClause(new Symbol("a1"), new Symbol("e"))),
                                                        PlanMatchPattern.strictProject(
                                                                ImmutableMap.of("a2", PlanMatchPattern.expression("-a1"), "a1", PlanMatchPattern.expression("a1")),
                                                                PlanMatchPattern.values("a1")),
                                                        PlanMatchPattern.strictProject(
                                                                ImmutableMap.of("e", PlanMatchPattern.expression("e")),
                                                                PlanMatchPattern.values("e"))),
                                                PlanMatchPattern.any()),
                                        PlanMatchPattern.strictProject(
                                                ImmutableMap.of("f", PlanMatchPattern.expression("-b")),
                                                PlanMatchPattern.values("b")))));
    }

    @Test
    public void testGiveUpOnComplexProjections() {
        PlanNode plan = joinNode(
                projectNode(
                        joinNode(values("a1"), values("b")),
                        "a2", new ArithmeticBinaryExpression(ArithmeticBinaryExpression.Operator.ADD, new SymbolReference("a1"), new SymbolReference("b")),
                        "b", new SymbolReference("b")),
                values("c"),
                "a2", "c",
                "b", "c");
        Assert.assertEquals(buildJoinGraph(plan).size(), 2);
    }

    /**
     * Creates the rule under test, wired to the tester's planner context and type analyzer.
     */
    private EliminateCrossJoins eliminateCrossJoins() {
        return new EliminateCrossJoins(tester().getPlannerContext(), tester().getTypeAnalyzer());
    }

    /**
     * Builds a {@link JoinGraph} for {@code plan} using no lookup, a fresh id allocator,
     * a default testing session and an empty type provider.
     */
    private JoinGraph buildJoinGraph(PlanNode plan) {
        return JoinGraph.buildFrom(
                tester().getPlannerContext(),
                plan,
                Lookup.noLookup(),
                new PlanNodeIdAllocator(),
                TestingSession.testSessionBuilder().build(),
                TypeAnalyzer.createTestingTypeAnalyzer(tester().getPlannerContext()),
                TypeProvider.empty());
    }

    /**
     * Returns a plan factory for an INNER join with two equi-join clauses.
     * Its left input is a criteria-less join of the given {@code type}.
     */
    private Function<PlanBuilder, PlanNode> crossJoinAndJoin(JoinNode.Type type) {
        return planBuilder -> {
            Symbol axSymbol = planBuilder.symbol("axSymbol");
            Symbol bySymbol = planBuilder.symbol("bySymbol");
            Symbol cxSymbol = planBuilder.symbol("cxSymbol");
            Symbol cySymbol = planBuilder.symbol("cySymbol");
            return planBuilder.join(
                    JoinNode.Type.INNER,
                    planBuilder.join(type, planBuilder.values(axSymbol), planBuilder.values(bySymbol)),
                    planBuilder.values(cxSymbol, cySymbol),
                    new JoinNode.EquiJoinClause(axSymbol, cxSymbol),
                    new JoinNode.EquiJoinClause(bySymbol, cySymbol));
        };
    }

    /**
     * Builds a {@link ProjectNode} over {@code source} with exactly two assignments.
     */
    private PlanNode projectNode(PlanNode source, String symbol1, Expression expression1, String symbol2, Expression expression2) {
        return new ProjectNode(
                this.idAllocator.getNextId(),
                source,
                Assignments.of(new Symbol(symbol1), expression1, new Symbol(symbol2), expression2));
    }

    /**
     * Builds an INNER {@link JoinNode} whose outputs are all outputs of {@code left} followed by
     * all outputs of {@code right}.
     *
     * @param symbolPairs flattened (left, right) symbol-name pairs, one pair per equi-join clause;
     *        pass none for a cross join
     * @throws IllegalArgumentException if an odd number of symbol names is given
     */
    private JoinNode joinNode(PlanNode left, PlanNode right, String... symbolPairs) {
        Preconditions.checkArgument(symbolPairs.length % 2 == 0, "expected an even number of symbol names, got %s", symbolPairs.length);
        ImmutableList.Builder<JoinNode.EquiJoinClause> criteria = ImmutableList.builder();
        for (int i = 0; i < symbolPairs.length; i += 2) {
            criteria.add(new JoinNode.EquiJoinClause(new Symbol(symbolPairs[i]), new Symbol(symbolPairs[i + 1])));
        }
        return new JoinNode(
                this.idAllocator.getNextId(),
                JoinNode.Type.INNER,
                left,
                right,
                criteria.build(),
                left.getOutputSymbols(),
                right.getOutputSymbols(),
                false,
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                Optional.empty(),
                ImmutableMap.of(),
                Optional.empty());
    }

    /**
     * Builds a zero-row {@link ValuesNode} that outputs one symbol per given name.
     */
    private ValuesNode values(String... symbolNames) {
        List<Symbol> symbols = Arrays.stream(symbolNames)
                .map(Symbol::new)
                .collect(ImmutableList.toImmutableList());
        return new ValuesNode(this.idAllocator.getNextId(), symbols, ImmutableList.of());
    }
}
