package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.operator.RetryPolicy;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.DynamicFilterSourceNode;
import io.trino.sql.planner.plan.ExchangeNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.SemiJoinNode;
import java.util.Collection;
import java.util.Set;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/AddDynamicFilterSource.class */
public final class AddDynamicFilterSource {

    /**
     * Moves a {@link DynamicFilterSourceNode} one level further down the plan, or removes it.
     * <p>
     * Matches any node whose single source is a {@code DynamicFilterSourceNode}. Three outcomes are possible:
     * <ul>
     *     <li>If the matched node is a remote {@link ExchangeNode}, the plan is left unchanged, so the dynamic
     *     filter source stays directly below the remote exchange.</li>
     *     <li>Otherwise, if the dynamic filter source's child accepts it (see {@code canAddDynamicFilterSource}),
     *     the dynamic filter source is re-inserted below that child.</li>
     *     <li>Otherwise, the dynamic filter source is dropped and its child is attached directly to the matched
     *     node.</li>
     * </ul>
     */
    private static class PushOrRemoveDynamicFilterSource implements Rule<PlanNode> {
        private static final Capture<DynamicFilterSourceNode> DYNAMIC_FILTER_SOURCE = Capture.newCapture();
        private static final Pattern<PlanNode> PATTERN = Pattern.typeOf(PlanNode.class)
                .with(Patterns.source().matching(Patterns.dynamicFilterSource().capturedAs(DYNAMIC_FILTER_SOURCE)));

        private PushOrRemoveDynamicFilterSource() {
        }

        @Override
        public Pattern<PlanNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return isTaskRetryPolicy(session);
        }

        @Override
        public Rule.Result apply(PlanNode planNode, Captures captures, Rule.Context context) {
            // The dynamic filter source already sits directly below a remote exchange: leave it there.
            if (isRemoteExchange(planNode)) {
                return Rule.Result.empty();
            }

            DynamicFilterSourceNode dynamicFilterSource = captures.get(DYNAMIC_FILTER_SOURCE);
            PlanNode child = context.getLookup().resolve(dynamicFilterSource.getSource());

            if (!canAddDynamicFilterSource(child, dynamicFilterSource.getDynamicFilters().values())) {
                // Cannot be pushed any further: drop the dynamic filter source node entirely.
                return Rule.Result.ofPlanNode(planNode.replaceChildren(ImmutableList.of(child)));
            }

            // Swap the dynamic filter source with its child:
            //   planNode -> dynamicFilterSource -> child -> grandchildren
            // becomes
            //   planNode -> child -> dynamicFilterSource -> grandchildren
            PlanNode pushedDown = dynamicFilterSource.replaceChildren(child.getSources());
            PlanNode newChild = child.replaceChildren(ImmutableList.of(pushedDown));
            return Rule.Result.ofPlanNode(planNode.replaceChildren(ImmutableList.of(newChild)));
        }
    }

    /**
     * Rewrites a {@link JoinNode} that declares dynamic filters.
     * <p>
     * The dynamic filters are always removed from the join. When the build side (right child) is an
     * identity-projecting {@link ProjectNode} or a single-source {@link ExchangeNode}, a new
     * {@link DynamicFilterSourceNode} carrying the join's dynamic filters is inserted between the build-side node
     * and its only child. Otherwise the join's dynamic filters are simply dropped.
     */
    private static class RewriteJoinDynamicFilter implements Rule<JoinNode> {
        private static final Capture<PlanNode> BUILD_SIDE_NODE = Capture.newCapture();
        private static final Pattern<JoinNode> PATTERN = Patterns.join()
                .matching(joinNode -> !joinNode.getDynamicFilters().isEmpty())
                .with(Patterns.Join.right().capturedAs(BUILD_SIDE_NODE));

        private RewriteJoinDynamicFilter() {
        }

        @Override
        public Pattern<JoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return isTaskRetryPolicy(session);
        }

        @Override
        public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
            PlanNode buildSide = captures.get(BUILD_SIDE_NODE);
            if (!canAddDynamicFilterSource(buildSide, joinNode.getDynamicFilters().values())) {
                return Rule.Result.ofPlanNode(joinNode.withoutDynamicFilters());
            }

            DynamicFilterSourceNode dynamicFilterSource = new DynamicFilterSourceNode(
                    context.getIdAllocator().getNextId(),
                    Iterables.getOnlyElement(buildSide.getSources()),
                    joinNode.getDynamicFilters());
            PlanNode newBuildSide = buildSide.replaceChildren(ImmutableList.of(dynamicFilterSource));
            return Rule.Result.ofPlanNode(
                    joinNode.withoutDynamicFilters().replaceChildren(ImmutableList.of(joinNode.getLeft(), newBuildSide)));
        }
    }

    /**
     * Rewrites a {@link SemiJoinNode} that declares a dynamic filter.
     * <p>
     * The dynamic filter is always removed from the semi-join. When the filtering source is an
     * identity-projecting {@link ProjectNode} (for the filtering-source join symbol) or a single-source
     * {@link ExchangeNode}, a new {@link DynamicFilterSourceNode} mapping the semi-join's dynamic filter id to its
     * filtering-source join symbol is inserted between the filtering-source node and its only child. Otherwise the
     * dynamic filter is simply dropped.
     */
    private static class RewriteSemiJoinDynamicFilter implements Rule<SemiJoinNode> {
        private static final Capture<PlanNode> FILTERING_SOURCE_NODE = Capture.newCapture();
        private static final Pattern<SemiJoinNode> PATTERN = Patterns.semiJoin()
                .matching(semiJoinNode -> semiJoinNode.getDynamicFilterId().isPresent())
                .with(Patterns.SemiJoin.getFilteringSource().capturedAs(FILTERING_SOURCE_NODE));

        private RewriteSemiJoinDynamicFilter() {
        }

        @Override
        public Pattern<SemiJoinNode> getPattern() {
            return PATTERN;
        }

        @Override
        public boolean isEnabled(Session session) {
            return isTaskRetryPolicy(session);
        }

        @Override
        public Rule.Result apply(SemiJoinNode semiJoinNode, Captures captures, Rule.Context context) {
            PlanNode filteringSource = captures.get(FILTERING_SOURCE_NODE);
            Symbol filteringSymbol = semiJoinNode.getFilteringSourceJoinSymbol();
            if (!canAddDynamicFilterSource(filteringSource, ImmutableList.of(filteringSymbol))) {
                return Rule.Result.ofPlanNode(semiJoinNode.withoutDynamicFilter());
            }

            // The pattern only matches semi-joins with a dynamic filter id, so orElseThrow cannot fail here.
            DynamicFilterSourceNode dynamicFilterSource = new DynamicFilterSourceNode(
                    context.getIdAllocator().getNextId(),
                    Iterables.getOnlyElement(filteringSource.getSources()),
                    ImmutableMap.of(semiJoinNode.getDynamicFilterId().orElseThrow(), filteringSymbol));
            PlanNode newFilteringSource = filteringSource.replaceChildren(ImmutableList.of(dynamicFilterSource));
            return Rule.Result.ofPlanNode(
                    semiJoinNode.withoutDynamicFilter().replaceChildren(ImmutableList.of(semiJoinNode.getSource(), newFilteringSource)));
        }
    }

    private AddDynamicFilterSource() {
    }

    /**
     * Returns the rules that move dynamic filters from join / semi-join nodes into explicit
     * {@link DynamicFilterSourceNode}s and push those nodes down the plan.
     *
     * @return an immutable set of the three rules in this class
     */
    public static Set<Rule<?>> rules() {
        return ImmutableSet.of(new RewriteJoinDynamicFilter(), new RewriteSemiJoinDynamicFilter(), new PushOrRemoveDynamicFilterSource());
    }

    /**
     * Shared {@code isEnabled} predicate for every rule in this class: active only when the session's retry
     * policy is {@link RetryPolicy#TASK}.
     */
    private static boolean isTaskRetryPolicy(Session session) {
        return SystemSessionProperties.getRetryPolicy(session) == RetryPolicy.TASK;
    }

    /**
     * Decides whether a {@link DynamicFilterSourceNode} may be placed directly below {@code planNode}.
     *
     * @param planNode the node that would become the parent of the dynamic filter source
     * @param symbols  the symbols the dynamic filters are collected on
     * @return {@code true} if {@code planNode} is a {@link ProjectNode} whose assignments are identities for every
     *         symbol in {@code symbols}, or an {@link ExchangeNode} with exactly one source; {@code false} otherwise
     */
    private static boolean canAddDynamicFilterSource(PlanNode planNode, Collection<Symbol> symbols) {
        if (planNode instanceof ProjectNode) {
            ProjectNode projectNode = (ProjectNode) planNode;
            // The filter symbols must pass through the projection unchanged to be valid below it.
            return symbols.stream().allMatch(symbol -> projectNode.getAssignments().isIdentity(symbol));
        }
        return planNode instanceof ExchangeNode && planNode.getSources().size() == 1;
    }

    /**
     * Returns {@code true} if {@code planNode} is an {@link ExchangeNode} with {@link ExchangeNode.Scope#REMOTE}
     * scope.
     */
    private static boolean isRemoteExchange(PlanNode planNode) {
        return planNode instanceof ExchangeNode && ((ExchangeNode) planNode).getScope() == ExchangeNode.Scope.REMOTE;
    }
}
