package io.trino.sql.planner.iterative.rule;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Streams;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolsExtractor;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.PatternRecognitionNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.rowpattern.ExpressionAndValuePointersEquivalence;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.planner.rowpattern.ir.IrLabel;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Stream;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/MergePatternRecognitionNodes.class */
/**
 * Rules that fold a {@link PatternRecognitionNode} into a {@link PatternRecognitionNode} beneath it.
 * The child may be the parent's direct source, or it may sit below an intermediate {@link ProjectNode}.
 * The pair is replaced by a single node that computes the window functions and measures of both.
 *
 * <p>Merging happens only when:
 * <ul>
 *   <li>both nodes share the same matching configuration (see {@link #patternRecognitionSpecificationsMatch}), and</li>
 *   <li>the parent's window functions and measures do not read any symbol created by the child.</li>
 * </ul>
 * The merged node reads from the child's source, so the child's created symbols would not be available
 * as its inputs.
 */
public final class MergePatternRecognitionNodes {

    /**
     * Merges {@code PatternRecognition -> Project -> PatternRecognition}.
     *
     * <p>Non-identity project assignments that the parent reads are pushed below the merged node.
     * The remaining assignments are kept in a project above it.
     */
    public static final class MergePatternRecognitionNodesWithProject implements Rule<PatternRecognitionNode> {
        private static final Capture<ProjectNode> PROJECT = Capture.newCapture();
        private static final Capture<PatternRecognitionNode> CHILD = Capture.newCapture();
        private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition()
                .with(Patterns.source().matching(Patterns.project().capturedAs(PROJECT)
                        .with(Patterns.source().matching(Patterns.patternRecognition().capturedAs(CHILD)))));

        @Override
        public Pattern<PatternRecognitionNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
            ProjectNode project = captures.get(PROJECT);
            PatternRecognitionNode child = captures.get(CHILD);

            if (!patternRecognitionSpecificationsMatch(node, child)) {
                return Rule.Result.empty();
            }
            // The parent reads the project's outputs. If any of those is computed from a symbol created by
            // the child, it cannot be evaluated below the merged node.
            if (dependsOnSourceCreatedOutputs(node, project, child)) {
                return Rule.Result.empty();
            }

            PatternRecognitionNode merged = merge(node, child);
            Assignments prerequisites = extractPrerequisites(node, project);

            ProjectNode result;
            if (prerequisites.isEmpty()) {
                // Every symbol the parent reads is passed through by the project unchanged.
                // The whole project can therefore sit on top of the merged node.
                result = new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        merged,
                        Assignments.builder()
                                .putIdentities(merged.getOutputSymbols())
                                .putAll(project.getAssignments())
                                .build());
            } else {
                // The parent reads some computed (non-identity) assignments. Those must be evaluated below
                // the merged node; everything else from the project stays above it.
                Assignments remaining = project.getAssignments()
                        .filter(symbol -> !prerequisites.getSymbols().contains(symbol));
                ProjectNode computePrerequisites = new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        merged.getSource(),
                        Assignments.builder()
                                .putIdentities(merged.getSource().getOutputSymbols())
                                .putAll(prerequisites)
                                .build());
                PatternRecognitionNode mergedWithPrerequisites =
                        (PatternRecognitionNode) merged.replaceChildren(ImmutableList.of(computePrerequisites));
                result = new ProjectNode(
                        context.getIdAllocator().getNextId(),
                        mergedWithPrerequisites,
                        Assignments.builder()
                                .putIdentities(mergedWithPrerequisites.getOutputSymbols())
                                .putAll(remaining)
                                .build());
            }

            // Expose exactly the symbols the original parent produced.
            return Rule.Result.ofPlanNode(
                    Util.restrictOutputs(context.getIdAllocator(), result, ImmutableSet.copyOf(node.getOutputSymbols()))
                            .orElse(result));
        }
    }

    /**
     * Merges a {@link PatternRecognitionNode} whose direct source is another {@link PatternRecognitionNode}.
     */
    public static final class MergePatternRecognitionNodesWithoutProject implements Rule<PatternRecognitionNode> {
        private static final Capture<PatternRecognitionNode> CHILD = Capture.newCapture();
        private static final Pattern<PatternRecognitionNode> PATTERN = Patterns.patternRecognition()
                .with(Patterns.source().matching(Patterns.patternRecognition().capturedAs(CHILD)));

        @Override
        public Pattern<PatternRecognitionNode> getPattern() {
            return PATTERN;
        }

        @Override
        public Rule.Result apply(PatternRecognitionNode node, Captures captures, Rule.Context context) {
            PatternRecognitionNode child = captures.get(CHILD);

            if (!patternRecognitionSpecificationsMatch(node, child) || dependsOnSourceCreatedOutputs(node, child)) {
                return Rule.Result.empty();
            }

            PatternRecognitionNode merged = merge(node, child);
            // Expose exactly the symbols the original parent produced.
            return Rule.Result.ofPlanNode(
                    Util.restrictOutputs(context.getIdAllocator(), merged, ImmutableSet.copyOf(node.getOutputSymbols()))
                            .orElse(merged));
        }
    }

    private MergePatternRecognitionNodes() {
    }

    /**
     * Returns both merge rules: the variant without an intermediate project and the variant with one.
     */
    public static Set<Rule<?>> rules() {
        return ImmutableSet.of(new MergePatternRecognitionNodesWithoutProject(), new MergePatternRecognitionNodesWithProject());
    }

    /**
     * Returns {@code true} when the two nodes agree on everything except their window functions and measures.
     * The compared properties are:
     * <ul>
     *   <li>the specification and the common base frame;</li>
     *   <li>the rows-per-match mode;</li>
     *   <li>the skip-to label and skip-to position;</li>
     *   <li>the {@code isInitial()} flag;</li>
     *   <li>the pattern and the subsets;</li>
     *   <li>the variable definitions, compared with {@link ExpressionAndValuePointersEquivalence}.</li>
     * </ul>
     */
    private static boolean patternRecognitionSpecificationsMatch(PatternRecognitionNode parent, PatternRecognitionNode child) {
        return parent.getSpecification().equals(child.getSpecification())
                && parent.getCommonBaseFrame().equals(child.getCommonBaseFrame())
                && parent.getRowsPerMatch() == child.getRowsPerMatch()
                && parent.getSkipToLabel().equals(child.getSkipToLabel())
                && parent.getSkipToPosition() == child.getSkipToPosition()
                && parent.isInitial() == child.isInitial()
                && parent.getPattern().equals(child.getPattern())
                && parent.getSubsets().equals(child.getSubsets())
                && equivalent(parent.getVariableDefinitions(), child.getVariableDefinitions());
    }

    /**
     * Returns {@code true} when both maps define the same labels. Each label's definitions must also be
     * equivalent according to {@link ExpressionAndValuePointersEquivalence#equivalent}.
     */
    private static boolean equivalent(
            Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> parentDefinitions,
            Map<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> childDefinitions) {
        if (!parentDefinitions.keySet().equals(childDefinitions.keySet())) {
            return false;
        }
        for (Map.Entry<IrLabel, LogicalIndexExtractor.ExpressionAndValuePointers> entry : parentDefinitions.entrySet()) {
            if (!ExpressionAndValuePointersEquivalence.equivalent(entry.getValue(), childDefinitions.get(entry.getKey()))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Returns the symbols that {@code node}'s window functions and measures read.
     * The stream may contain duplicates.
     */
    private static Stream<Symbol> referencedInputs(PatternRecognitionNode node) {
        return Streams.concat(
                node.getWindowFunctions().values().stream()
                        .flatMap(function -> SymbolsExtractor.extractAll(function).stream()),
                node.getMeasures().values().stream()
                        .flatMap(measure -> measure.getExpressionAndValuePointers().getInputSymbols().stream()));
    }

    /**
     * Returns {@code true} if any window function or measure of {@code parent} directly reads a symbol
     * created by {@code child}.
     */
    private static boolean dependsOnSourceCreatedOutputs(PatternRecognitionNode parent, PatternRecognitionNode child) {
        Set<Symbol> childCreatedSymbols = child.getCreatedSymbols();
        return referencedInputs(parent).anyMatch(childCreatedSymbols::contains);
    }

    /**
     * Returns {@code true} if the parent depends on a symbol created by {@code child}.
     * Each symbol the parent reads is resolved through {@code project}'s assignments. The check succeeds when
     * any resulting expression references such a symbol.
     *
     * <p>Assumes every symbol the parent reads is an output of {@code project}. An unassigned symbol would
     * make the assignment lookup yield no expression.
     */
    private static boolean dependsOnSourceCreatedOutputs(PatternRecognitionNode parent, ProjectNode project, PatternRecognitionNode child) {
        Set<Symbol> childCreatedSymbols = child.getCreatedSymbols();
        Assignments assignments = project.getAssignments();
        return referencedInputs(parent)
                .distinct()
                .map(assignments::get)
                .flatMap(expression -> SymbolsExtractor.extractAll(expression).stream())
                .anyMatch(childCreatedSymbols::contains);
    }

    /**
     * Returns the non-identity assignments of {@code project} that the parent's window functions or measures read.
     * These must be computed below the merged node.
     */
    private static Assignments extractPrerequisites(PatternRecognitionNode parent, ProjectNode project) {
        Assignments assignments = project.getAssignments();
        Set<Symbol> parentInputs = referencedInputs(parent).collect(ImmutableSet.toImmutableSet());
        return assignments
                .filter(symbol -> !assignments.isIdentity(symbol))
                .filter(parentInputs::contains);
    }

    /**
     * Builds the node that replaces {@code parent}. It reads from {@code child}'s source and keeps the parent's
     * id, specification and matching configuration. It computes the union of both nodes' window functions and
     * measures.
     */
    private static PatternRecognitionNode merge(PatternRecognitionNode parent, PatternRecognitionNode child) {
        return new PatternRecognitionNode(
                parent.getId(),
                child.getSource(),
                parent.getSpecification(),
                parent.getHashSymbol(),
                parent.getPrePartitionedInputs(),
                parent.getPreSortedOrderPrefix(),
                union(parent.getWindowFunctions(), child.getWindowFunctions()),
                union(parent.getMeasures(), child.getMeasures()),
                parent.getCommonBaseFrame(),
                parent.getRowsPerMatch(),
                parent.getSkipToLabel(),
                parent.getSkipToPosition(),
                parent.isInitial(),
                parent.getPattern(),
                parent.getSubsets(),
                parent.getVariableDefinitions());
    }

    /**
     * Concatenates two maps into an immutable map. {@code first}'s entries come first.
     * Throws {@link IllegalArgumentException} if a key appears in both maps.
     */
    private static <K, V> ImmutableMap<K, V> union(Map<K, V> first, Map<K, V> second) {
        return ImmutableMap.<K, V>builder()
                .putAll(first)
                .putAll(second)
                .build();
    }
}
