package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Streams;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Lookup;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.PlanNodeSearcher;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.AggregationNode;
import io.trino.sql.planner.plan.AssignUniqueId;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.CorrelatedJoinNode;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.PlanNodeId;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.UnnestNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.IsNotNullPredicate;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/DecorrelateInnerUnnestWithGlobalAggregation.class */
/**
 * Decorrelates a correlated join whose subquery is a global aggregation computed over an UNNEST of correlation
 * symbols.
 *
 * <p>Matches an INNER or LEFT {@link CorrelatedJoinNode} with a non-empty correlation list and a {@code TRUE} join
 * filter. The subquery must start with a chain of projections and global (single-step, single empty grouping set)
 * aggregations. Below the lowest of those, projections and grouped aggregations must lead down to an unnest accepted
 * by {@link #isSupportedUnnest}.
 *
 * <p>The rewrite:
 * <ol>
 *   <li>tags every input row with a unique id ({@link AssignUniqueId});</li>
 *   <li>re-applies the unnest's source projection, if there is one, on top of the input;</li>
 *   <li>replaces the correlated join with a LEFT unnest that always carries an ordinality symbol;</li>
 *   <li>adds a boolean mask {@code ordinality IS NOT NULL}, intended to tell unnested rows apart from rows the LEFT
 *       unnest emits with a null ordinality;</li>
 *   <li>rebuilds the subquery's node chain on top, grouping every aggregation additionally by the input symbols and
 *       masking the aggregations of the lowest ("reducing") global aggregation with the mask.</li>
 * </ol>
 * The result is finally restricted to the output symbols of the original correlated join.
 */
public class DecorrelateInnerUnnestWithGlobalAggregation implements Rule<CorrelatedJoinNode> {
    private static final Pattern<CorrelatedJoinNode> PATTERN = Patterns.correlatedJoin()
            .with(Pattern.nonEmpty(Patterns.CorrelatedJoin.correlation()))
            .with(Patterns.CorrelatedJoin.filter().equalTo(BooleanLiteral.TRUE_LITERAL))
            .matching(node -> node.getType() == CorrelatedJoinNode.Type.INNER || node.getType() == CorrelatedJoinNode.Type.LEFT);

    @Override
    public Pattern<CorrelatedJoinNode> getPattern() {
        return PATTERN;
    }

    @Override
    public Rule.Result apply(CorrelatedJoinNode correlatedJoinNode, Captures captures, Rule.Context context) {
        Lookup lookup = context.getLookup();

        // Global aggregations reachable from the subquery root through projections and other global aggregations.
        List<PlanNode> globalAggregations = PlanNodeSearcher.searchFrom(correlatedJoinNode.getSubquery(), lookup)
                .where(DecorrelateInnerUnnestWithGlobalAggregation::isGlobalAggregation)
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGlobalAggregation(node))
                .findAll();
        if (globalAggregations.isEmpty()) {
            return Rule.Result.empty();
        }

        // The last match is treated as the "reducing" aggregation, i.e. the one that gets masked.
        // NOTE(review): this assumes findAll() lists matches top-down, so the last one is the closest to the source.
        AggregationNode reducingAggregation = (AggregationNode) globalAggregations.get(globalAggregations.size() - 1);

        // A supported unnest below the reducing aggregation, reachable through projections and grouped aggregations.
        Optional<PlanNode> unnestMatch = PlanNodeSearcher.searchFrom(reducingAggregation.getSource(), lookup)
                .where(node -> isSupportedUnnest(node, correlatedJoinNode.getCorrelation(), lookup))
                .recurseOnlyWhen(node -> node instanceof ProjectNode || isGroupedAggregation(node))
                .findFirst();
        if (unnestMatch.isEmpty()) {
            return Rule.Result.empty();
        }
        UnnestNode correlatedUnnest = (UnnestNode) unnestMatch.get();

        // Tag each input row with a unique id; it becomes part of the grouping keys of every rebuilt aggregation.
        PlanNode input = new AssignUniqueId(
                context.getIdAllocator().getNextId(),
                correlatedJoinNode.getInput(),
                context.getSymbolAllocator().newSymbol("unique", BigintType.BIGINT));

        // If the correlated unnest read from a projection, re-apply that projection (keeping its plan node id)
        // on top of the input, so the symbols it produced for the unnest remain available.
        PlanNode unnestSource = lookup.resolve(correlatedUnnest.getSource());
        if (unnestSource instanceof ProjectNode) {
            ProjectNode sourceProjection = (ProjectNode) unnestSource;
            input = new ProjectNode(
                    sourceProjection.getId(),
                    input,
                    Assignments.builder()
                            .putIdentities(input.getOutputSymbols())
                            .putAll(sourceProjection.getAssignments())
                            .build());
        }

        // The mask below is derived from the ordinality symbol, so allocate one if the unnest did not have it.
        Symbol ordinality = correlatedUnnest.getOrdinalitySymbol()
                .orElseGet(() -> context.getSymbolAllocator().newSymbol("ordinality", BigintType.BIGINT));

        // Replace the correlated join with a LEFT unnest that replicates every input symbol.
        UnnestNode rewrittenUnnest = new UnnestNode(
                context.getIdAllocator().getNextId(),
                input,
                input.getOutputSymbols(),
                correlatedUnnest.getMappings(),
                Optional.of(ordinality),
                JoinNode.Type.LEFT,
                Optional.empty());

        // mask := ordinality IS NOT NULL, used to exclude rows without an ordinality from the reducing aggregation.
        Symbol mask = context.getSymbolAllocator().newSymbol("mask", BooleanType.BOOLEAN);
        PlanNode sourceWithMask = new ProjectNode(
                context.getIdAllocator().getNextId(),
                rewrittenUnnest,
                Assignments.builder()
                        .putIdentities(rewrittenUnnest.getOutputSymbols())
                        .put(mask, new IsNotNullPredicate(ordinality.toSymbolReference()))
                        .build());

        // Rebuild the subquery's projections and aggregations on top of the masked unnest.
        PlanNode result = rewriteNodeSequence(
                lookup.resolve(correlatedJoinNode.getSubquery()),
                input.getOutputSymbols(),
                mask,
                sourceWithMask,
                reducingAggregation.getId(),
                correlatedUnnest.getId(),
                context.getSymbolAllocator(),
                context.getIdAllocator(),
                lookup);

        return Rule.Result.ofPlanNode(
                Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(correlatedJoinNode.getOutputSymbols()))
                        .orElse(result));
    }

    /**
     * Returns whether {@code planNode} is a single-step aggregation with exactly one grouping set, which is empty.
     */
    private static boolean isGlobalAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) planNode;
        return aggregationNode.hasEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns whether {@code planNode} is a single-step aggregation with exactly one grouping set, which is non-empty.
     */
    private static boolean isGroupedAggregation(PlanNode planNode) {
        if (!(planNode instanceof AggregationNode)) {
            return false;
        }
        AggregationNode aggregationNode = (AggregationNode) planNode;
        return aggregationNode.hasNonEmptyGroupingSet()
                && aggregationNode.getGroupingSetCount() == 1
                && aggregationNode.getStep() == AggregationNode.Step.SINGLE;
    }

    /**
     * Returns whether {@code planNode} is an unnest this rule can decorrelate.
     *
     * <p>That requires an INNER unnest with no replicate symbols, a {@code TRUE} or absent filter, and a scalar source.
     * In addition, either every unnested input symbol is a correlation symbol, or the unnest's source is a projection
     * whose expressions reference only correlation symbols.
     *
     * <p>NOTE(review): the decompiler reported this method as originally {@code private}; it is kept {@code public}
     * so that no existing caller breaks.
     *
     * @param planNode candidate node
     * @param correlation correlation symbols of the enclosing correlated join
     * @param lookup lookup used to resolve group references
     */
    public static boolean isSupportedUnnest(PlanNode planNode, List<Symbol> correlation, Lookup lookup) {
        if (!(planNode instanceof UnnestNode)) {
            return false;
        }
        UnnestNode unnestNode = (UnnestNode) planNode;

        List<Symbol> unnestInputs = unnestNode.getMappings().stream()
                .map(UnnestNode.Mapping::getInput)
                .collect(ImmutableList.toImmutableList());
        PlanNode unnestSource = lookup.resolve(unnestNode.getSource());
        ImmutableSet<Symbol> correlationSymbols = ImmutableSet.copyOf(correlation);
        boolean basedOnCorrelation = correlationSymbols.containsAll(unnestInputs)
                || (unnestSource instanceof ProjectNode
                        && correlationSymbols.containsAll(SymbolsExtractor.extractUnique(((ProjectNode) unnestSource).getAssignments().getExpressions())));

        return QueryCardinalityUtil.isScalar(unnestNode.getSource(), lookup)
                && unnestNode.getReplicateSymbols().isEmpty()
                && basedOnCorrelation
                && unnestNode.getJoinType() == JoinNode.Type.INNER
                && (unnestNode.getFilter().isEmpty() || unnestNode.getFilter().get().equals(BooleanLiteral.TRUE_LITERAL));
    }

    /**
     * Rebuilds the chain of nodes from {@code node} down to the correlated unnest on top of {@code sequenceSource}.
     *
     * <p>Recurses through each node's only source until reaching the node whose id is {@code correlatedUnnestId},
     * which is replaced by {@code sequenceSource}. On the way back up:
     * <ul>
     *   <li>the global aggregation with id {@code reducingAggregationId} is grouped by {@code leftOutputs} and its
     *       aggregations are masked with {@code mask};</li>
     *   <li>other global aggregations are grouped by {@code leftOutputs} plus their own keys;</li>
     *   <li>grouped aggregations are grouped by {@code leftOutputs}, {@code mask}, and their own keys;</li>
     *   <li>projections additionally pass through {@code leftOutputs}, plus {@code mask} when the rebuilt source
     *       still produces it.</li>
     * </ul>
     *
     * @throws IllegalStateException if a node on the chain is none of the above
     */
    private static PlanNode rewriteNodeSequence(
            PlanNode node,
            List<Symbol> leftOutputs,
            Symbol mask,
            PlanNode sequenceSource,
            PlanNodeId reducingAggregationId,
            PlanNodeId correlatedUnnestId,
            SymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator,
            Lookup lookup) {
        // Bottom of the chain: the correlated unnest is swapped for the rewritten, masked source.
        if (node.getId().equals(correlatedUnnestId)) {
            return sequenceSource;
        }

        // Every node on the chain is expected to have exactly one source (getOnlyElement enforces this).
        PlanNode source = rewriteNodeSequence(
                lookup.resolve(Iterables.getOnlyElement(node.getSources())),
                leftOutputs,
                mask,
                sequenceSource,
                reducingAggregationId,
                correlatedUnnestId,
                symbolAllocator,
                idAllocator,
                lookup);

        if (isGlobalAggregation(node)) {
            AggregationNode aggregationNode = (AggregationNode) node;
            if (aggregationNode.getId().equals(reducingAggregationId)) {
                return withGroupingAndMask(aggregationNode, leftOutputs, mask, source, symbolAllocator, idAllocator);
            }
            return withGrouping(aggregationNode, leftOutputs, source);
        }

        if (isGroupedAggregation(node)) {
            // Explicit type witness: without it the builder infers ImmutableList<Object>, which does not compile here.
            List<Symbol> groupingSymbols = ImmutableList.<Symbol>builder()
                    .addAll(leftOutputs)
                    .add(mask)
                    .build();
            return withGrouping((AggregationNode) node, groupingSymbols, source);
        }

        if (!(node instanceof ProjectNode)) {
            throw new IllegalStateException("unexpected node: " + node);
        }
        ProjectNode projectNode = (ProjectNode) node;
        // The mask is gone above the reducing aggregation, so only forward it where the source still has it.
        List<Symbol> maskPassThrough = source.getOutputSymbols().contains(mask) ? ImmutableList.of(mask) : ImmutableList.of();
        return new ProjectNode(
                projectNode.getId(),
                source,
                Assignments.builder()
                        .putAll(projectNode.getAssignments())
                        .putIdentities(leftOutputs)
                        .putIdentities(maskPassThrough)
                        .build());
    }

    /**
     * Rebuilds {@code aggregationNode} over {@code source}, grouped only by {@code groupingSymbols} and with every
     * aggregation masked by {@code mask}.
     *
     * <p>Aggregations with no mask get {@code mask} directly. Aggregations that already have a mask get a new boolean
     * symbol computed as {@code existingMask AND mask}. Those new symbols are produced by a projection inserted over
     * {@code source}, which is added only when at least one such symbol is needed.
     */
    private static AggregationNode withGroupingAndMask(
            AggregationNode aggregationNode,
            List<Symbol> groupingSymbols,
            Symbol mask,
            PlanNode source,
            SymbolAllocator symbolAllocator,
            PlanNodeIdAllocator idAllocator) {
        ImmutableMap.Builder<Symbol, Symbol> masks = ImmutableMap.builder();
        Assignments.Builder maskProjections = Assignments.builder();
        for (Map.Entry<Symbol, AggregationNode.Aggregation> entry : aggregationNode.getAggregations().entrySet()) {
            AggregationNode.Aggregation aggregation = entry.getValue();
            if (aggregation.getMask().isPresent()) {
                Symbol combinedMask = symbolAllocator.newSymbol("mask", BooleanType.BOOLEAN);
                maskProjections.put(combinedMask, ExpressionUtils.and(aggregation.getMask().get().toSymbolReference(), mask.toSymbolReference()));
                masks.put(entry.getKey(), combinedMask);
            } else {
                masks.put(entry.getKey(), mask);
            }
        }

        Assignments combinedMasks = maskProjections.build();
        PlanNode aggregationSource = source;
        if (!combinedMasks.isEmpty()) {
            aggregationSource = new ProjectNode(
                    idAllocator.getNextId(),
                    source,
                    Assignments.builder()
                            .putIdentities(source.getOutputSymbols())
                            .putAll(combinedMasks)
                            .build());
        }

        return AggregationNode.singleAggregation(
                aggregationNode.getId(),
                aggregationSource,
                AggregationDecorrelation.rewriteWithMasks(aggregationNode.getAggregations(), masks.buildOrThrow()),
                AggregationNode.singleGroupingSet(groupingSymbols));
    }

    /**
     * Rebuilds {@code aggregationNode} over {@code source}. The new grouping keys are {@code groupingSymbols} followed
     * by the node's own keys, with duplicates removed.
     */
    private static AggregationNode withGrouping(AggregationNode aggregationNode, List<Symbol> groupingSymbols, PlanNode source) {
        List<Symbol> groupingKeys = Streams.concat(groupingSymbols.stream(), aggregationNode.getGroupingKeys().stream())
                .distinct()
                .collect(ImmutableList.toImmutableList());

        return AggregationNode.singleAggregation(
                aggregationNode.getId(),
                source,
                aggregationNode.getAggregations(),
                AggregationNode.singleGroupingSet(groupingKeys));
    }
}
