package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import io.trino.spi.Plugin;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.assertions.PlanMatchPattern;
import io.trino.sql.planner.iterative.rule.test.BaseRuleTest;
import io.trino.sql.planner.iterative.rule.test.PlanBuilder;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.tree.SortItem;
import java.util.Optional;
import org.testng.annotations.Test;

/**
 * Rule tests for {@link PushLimitThroughSemiJoin}.
 *
 * <p>The positive cases build a limit directly on top of a semi-join
 * ({@code leftKey IN rightKey -> match}) and expect the plan to be rewritten so that the
 * semi-join is the root and the limit sits on its source ({@code leftKey}) input. The negative
 * cases check that the rule leaves the plan alone when there is no limit above the semi-join,
 * or when the limit's ordering depends on the semi-join output symbol {@code match}.
 */
public class TestPushLimitThroughSemiJoin extends BaseRuleTest {
    public TestPushLimitThroughSemiJoin() {
        // No additional plugins are registered for this rule test.
        super(new Plugin[0]);
    }

    /** A plain limit above the semi-join is moved onto the semi-join's source side. */
    @Test
    public void test() {
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> builder.limit(1L, buildSemiJoin(builder)))
                .matches(expectedSemiJoin(
                        PlanMatchPattern.limit(1L, PlanMatchPattern.values("leftKey"))));
    }

    /** A limit WITH TIES ordered by a source-side symbol is pushed along with its tie ordering. */
    @Test
    public void testPushLimitWithTies() {
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> {
                    Symbol leftKey = builder.symbol("leftKey");
                    return builder.limit(1L, ImmutableList.of(leftKey), buildSemiJoin(builder));
                })
                .matches(expectedSemiJoin(
                        PlanMatchPattern.limit(
                                1L,
                                ImmutableList.of(PlanMatchPattern.sort("leftKey", SortItem.Ordering.ASCENDING, SortItem.NullOrdering.FIRST)),
                                PlanMatchPattern.values("leftKey"))));
    }

    /** A limit with pre-sorted inputs on a source-side symbol is pushed with those inputs intact. */
    @Test
    public void testPushLimitWithPreSortedInputs() {
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> {
                    Symbol leftKey = builder.symbol("leftKey");
                    return builder.limit(1L, false, ImmutableList.of(leftKey), buildSemiJoin(builder));
                })
                .matches(expectedSemiJoin(
                        PlanMatchPattern.limit(
                                1L,
                                ImmutableList.of(),
                                false,
                                ImmutableList.of("leftKey"),
                                PlanMatchPattern.values("leftKey"))));
    }

    /** Cases where the rule must not fire. */
    @Test
    public void testDoesNotFire() {
        // The root is the semi-join itself; the only limit is on the filtering source.
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> {
                    Symbol leftKey = builder.symbol("leftKey");
                    Symbol rightKey = builder.symbol("rightKey");
                    Symbol output = builder.symbol("output");
                    return builder.semiJoin(
                            leftKey,
                            rightKey,
                            output,
                            Optional.empty(),
                            Optional.empty(),
                            builder.values(leftKey),
                            builder.limit(1L, builder.values(rightKey)));
                })
                .doesNotFire();

        // The tie ordering of the limit refers to the semi-join output symbol.
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> {
                    Symbol match = builder.symbol("match");
                    return builder.limit(1L, ImmutableList.of(match), buildSemiJoin(builder));
                })
                .doesNotFire();

        // The pre-sorted inputs of the limit refer to the semi-join output symbol.
        tester().assertThat(new PushLimitThroughSemiJoin())
                .on(builder -> {
                    Symbol match = builder.symbol("match");
                    return builder.limit(1L, false, ImmutableList.of(match), buildSemiJoin(builder));
                })
                .doesNotFire();
    }

    /**
     * Expected shape after the rewrite: a semi-join on {@code leftKey}/{@code rightKey} producing
     * {@code match}, whose source input matches {@code sourcePattern} and whose filtering input
     * is a plain {@code values(rightKey)}.
     */
    private static PlanMatchPattern expectedSemiJoin(PlanMatchPattern sourcePattern) {
        return PlanMatchPattern.semiJoin(
                "leftKey",
                "rightKey",
                "match",
                sourcePattern,
                PlanMatchPattern.values("rightKey"));
    }

    /**
     * Builds {@code values(leftKey) SEMI JOIN values(rightKey)} with output symbol {@code match}
     * and no hash symbols on either side.
     */
    private static PlanNode buildSemiJoin(PlanBuilder builder) {
        Symbol leftKey = builder.symbol("leftKey");
        Symbol rightKey = builder.symbol("rightKey");
        Symbol match = builder.symbol("match");
        PlanNode source = builder.values(leftKey);
        PlanNode filteringSource = builder.values(rightKey);
        return builder.semiJoin(leftKey, rightKey, match, Optional.empty(), Optional.empty(), source, filteringSource);
    }
}
