package io.trino.sql.planner.rowpattern;

import com.google.common.collect.ImmutableMap;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.rowpattern.LogicalIndexExtractor;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.SymbolReference;
import io.trino.sql.util.AstUtils;
import java.util.Map;
import java.util.Set;
import java.util.function.BiFunction;

/* loaded from: input_file:io/trino/sql/planner/rowpattern/ExpressionAndValuePointersEquivalence.class */
/**
 * Static helpers that decide whether two
 * {@link LogicalIndexExtractor.ExpressionAndValuePointers} instances are equivalent.
 *
 * <p>Two instances are considered equivalent when their layouts have the same size, the value
 * pointers at each layout position are pairwise equivalent, and their expressions are equal
 * as trees once every layout symbol of the first instance is mapped to the layout symbol at
 * the same position of the second instance.
 *
 * <p>This class only has static members and cannot be instantiated.
 */
public class ExpressionAndValuePointersEquivalence {
    private ExpressionAndValuePointersEquivalence() {
    }

    /**
     * Checks equivalence using {@link Symbol#equals(Object)} to compare input symbols.
     *
     * @param expressionAndValuePointers the first (left) instance
     * @param expressionAndValuePointers2 the second (right) instance
     * @return {@code true} if both instances are equivalent
     */
    public static boolean equivalent(LogicalIndexExtractor.ExpressionAndValuePointers expressionAndValuePointers, LogicalIndexExtractor.ExpressionAndValuePointers expressionAndValuePointers2) {
        return equivalent(expressionAndValuePointers, expressionAndValuePointers2, Symbol::equals);
    }

    /**
     * Checks equivalence using a caller-supplied notion of symbol equivalence.
     *
     * @param expressionAndValuePointers the first (left) instance
     * @param expressionAndValuePointers2 the second (right) instance
     * @param biFunction decides whether an input symbol of the left instance is equivalent to
     *        an input symbol of the right instance
     * @return {@code true} if both instances are equivalent
     * @throws UnsupportedOperationException if a value pointer is neither a
     *         {@code ScalarValuePointer} nor an {@code AggregationValuePointer}
     */
    public static boolean equivalent(LogicalIndexExtractor.ExpressionAndValuePointers expressionAndValuePointers, LogicalIndexExtractor.ExpressionAndValuePointers expressionAndValuePointers2, BiFunction<Symbol, Symbol, Boolean> biFunction) {
        int layoutSize = expressionAndValuePointers.getLayout().size();
        if (layoutSize != expressionAndValuePointers2.getLayout().size()) {
            return false;
        }

        // Value pointers are compared pairwise, position by position.
        for (int i = 0; i < layoutSize; i++) {
            ValuePointer leftPointer = expressionAndValuePointers.getValuePointers().get(i);
            ValuePointer rightPointer = expressionAndValuePointers2.getValuePointers().get(i);
            if (leftPointer.getClass() != rightPointer.getClass()) {
                return false;
            }

            // The class check above guarantees rightPointer has the same runtime type,
            // so the casts of rightPointer below cannot fail.
            if (leftPointer instanceof ScalarValuePointer) {
                if (!equivalent(
                        (ScalarValuePointer) leftPointer,
                        (ScalarValuePointer) rightPointer,
                        expressionAndValuePointers.getClassifierSymbols(),
                        expressionAndValuePointers.getMatchNumberSymbols(),
                        expressionAndValuePointers2.getClassifierSymbols(),
                        expressionAndValuePointers2.getMatchNumberSymbols(),
                        biFunction)) {
                    return false;
                }
            } else if (leftPointer instanceof AggregationValuePointer) {
                if (!equivalent((AggregationValuePointer) leftPointer, (AggregationValuePointer) rightPointer, biFunction)) {
                    return false;
                }
            } else {
                throw new UnsupportedOperationException("unexpected ValuePointer type: " + leftPointer.getClass().getSimpleName());
            }
        }

        // Map each left layout symbol to the right layout symbol at the same position, then
        // compare the expressions with symbol references translated through that mapping.
        ImmutableMap.Builder<Symbol, Symbol> layoutMapping = ImmutableMap.builder();
        for (int i = 0; i < layoutSize; i++) {
            layoutMapping.put(expressionAndValuePointers.getLayout().get(i), expressionAndValuePointers2.getLayout().get(i));
        }
        return AstUtils.treeEqual(expressionAndValuePointers.getExpression(), expressionAndValuePointers2.getExpression(), mappingComparator(layoutMapping.buildOrThrow()));
    }

    /**
     * Compares two scalar value pointers.
     *
     * <p>The logical index pointers must be equal. The input symbols must play the same role
     * on both sides (member of the classifier set, member of the match-number set, or
     * neither). Only when the input symbols are in neither set are they compared with
     * {@code symbolEquivalence}; otherwise matching roles are sufficient.
     */
    private static boolean equivalent(
            ScalarValuePointer left,
            ScalarValuePointer right,
            Set<Symbol> leftClassifierSymbols,
            Set<Symbol> leftMatchNumberSymbols,
            Set<Symbol> rightClassifierSymbols,
            Set<Symbol> rightMatchNumberSymbols,
            BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        if (!left.getLogicalIndexPointer().equals(right.getLogicalIndexPointer())) {
            return false;
        }

        Symbol leftInput = left.getInputSymbol();
        Symbol rightInput = right.getInputSymbol();
        boolean leftIsClassifier = leftClassifierSymbols.contains(leftInput);
        boolean leftIsMatchNumber = leftMatchNumberSymbols.contains(leftInput);
        boolean rightIsClassifier = rightClassifierSymbols.contains(rightInput);
        boolean rightIsMatchNumber = rightMatchNumberSymbols.contains(rightInput);

        if (leftIsClassifier != rightIsClassifier || leftIsMatchNumber != rightIsMatchNumber) {
            return false;
        }
        if (leftIsClassifier || leftIsMatchNumber) {
            // Same role on both sides; the concrete symbols are not compared.
            return true;
        }
        return symbolEquivalence.apply(leftInput, rightInput).booleanValue();
    }

    /**
     * Compares two aggregation value pointers.
     *
     * <p>The function, the set descriptor and the number of arguments must match, and every
     * pair of arguments at the same position must be tree-equal under a comparator that
     * treats each side's classifier and match-number symbols as role placeholders.
     */
    private static boolean equivalent(AggregationValuePointer left, AggregationValuePointer right, BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        if (!left.getFunction().equals(right.getFunction())) {
            return false;
        }
        if (!left.getSetDescriptor().equals(right.getSetDescriptor())) {
            return false;
        }
        int argumentCount = left.getArguments().size();
        if (argumentCount != right.getArguments().size()) {
            return false;
        }

        BiFunction<Node, Node, Boolean> argumentComparator = subsetComparator(
                left.getClassifierSymbol(),
                left.getMatchNumberSymbol(),
                right.getClassifierSymbol(),
                right.getMatchNumberSymbol(),
                symbolEquivalence);
        for (int i = 0; i < argumentCount; i++) {
            if (!AstUtils.treeEqual(left.getArguments().get(i), right.getArguments().get(i), argumentComparator)) {
                return false;
            }
        }
        return true;
    }

    /**
     * Builds a node comparator for aggregation arguments.
     *
     * <p>For a pair of symbol references the comparator requires both symbols to play the
     * same role (classifier symbol, match-number symbol, or neither) and delegates to
     * {@code symbolEquivalence} only in the "neither" case. For any other pair of nodes it
     * returns {@code false} when the nodes are not shallow-equal and {@code null} otherwise.
     * NOTE(review): {@code null} presumably tells {@code AstUtils.treeEqual} to keep comparing
     * child nodes - confirm against {@code AstUtils}.
     */
    private static BiFunction<Node, Node, Boolean> subsetComparator(
            Symbol leftClassifierSymbol,
            Symbol leftMatchNumberSymbol,
            Symbol rightClassifierSymbol,
            Symbol rightMatchNumberSymbol,
            BiFunction<Symbol, Symbol, Boolean> symbolEquivalence) {
        return (leftNode, rightNode) -> {
            if (!(leftNode instanceof SymbolReference) || !(rightNode instanceof SymbolReference)) {
                if (!leftNode.shallowEquals(rightNode)) {
                    return false;
                }
                return null;
            }

            Symbol leftSymbol = Symbol.from((SymbolReference) leftNode);
            Symbol rightSymbol = Symbol.from((SymbolReference) rightNode);
            boolean leftIsClassifier = leftSymbol.equals(leftClassifierSymbol);
            boolean leftIsMatchNumber = leftSymbol.equals(leftMatchNumberSymbol);
            boolean rightIsClassifier = rightSymbol.equals(rightClassifierSymbol);
            boolean rightIsMatchNumber = rightSymbol.equals(rightMatchNumberSymbol);

            if (leftIsClassifier != rightIsClassifier || leftIsMatchNumber != rightIsMatchNumber) {
                return false;
            }
            if (leftIsClassifier || leftIsMatchNumber) {
                return true;
            }
            // The result of the supplied function is passed through unchanged.
            return symbolEquivalence.apply(leftSymbol, rightSymbol);
        };
    }

    /**
     * Builds a node comparator that translates left-hand symbols through {@code map}.
     *
     * <p>A pair of symbol references matches when the right symbol equals the value that
     * {@code map} holds for the left symbol (so an unmapped left symbol never matches). For
     * any other pair of nodes it returns {@code false} when the nodes are not shallow-equal
     * and {@code null} otherwise.
     * NOTE(review): {@code null} presumably tells {@code AstUtils.treeEqual} to keep comparing
     * child nodes - confirm against {@code AstUtils}.
     */
    private static BiFunction<Node, Node, Boolean> mappingComparator(Map<Symbol, Symbol> map) {
        return (leftNode, rightNode) -> {
            if ((leftNode instanceof SymbolReference) && (rightNode instanceof SymbolReference)) {
                Symbol mappedLeft = map.get(Symbol.from((SymbolReference) leftNode));
                return Boolean.valueOf(Symbol.from((SymbolReference) rightNode).equals(mappedLeft));
            }
            if (!leftNode.shallowEquals(rightNode)) {
                return false;
            }
            return null;
        };
    }
}
