package io.trino.cost;

import com.google.common.base.Preconditions;
import com.google.common.base.VerifyException;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.execution.warnings.WarningCollector;
import io.trino.metadata.Metadata;
import io.trino.security.AccessControl;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.statistics.StatsUtil;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.DynamicFilters;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.analyzer.ExpressionAnalyzer;
import io.trino.sql.analyzer.Scope;
import io.trino.sql.planner.ExpressionInterpreter;
import io.trino.sql.planner.LiteralEncoder;
import io.trino.sql.planner.LiteralInterpreter;
import io.trino.sql.planner.NoOpSymbolResolver;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.tree.AstVisitor;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.InListExpression;
import io.trino.sql.tree.InPredicate;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.Literal;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.Node;
import io.trino.sql.tree.NodeRef;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.Parameter;
import io.trino.sql.tree.SymbolReference;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.OptionalDouble;
import java.util.function.Function;
import java.util.function.Supplier;
import javax.annotation.Nullable;
import javax.inject.Inject;

/**
 * Estimates the statistics of the rows that survive a filter predicate, given the statistics of the
 * filter's input.
 *
 * <p>The predicate is first simplified with {@link ExpressionInterpreter}. The simplified expression is
 * then walked by {@link FilterExpressionStatsCalculatingVisitor}, which has a dedicated estimator for
 * boolean literals, NOT, AND/OR, IS [NOT] NULL, BETWEEN, IN lists, comparisons and dynamic-filter calls.
 * It returns {@link PlanNodeStatsEstimate#unknown()} for every other expression shape.
 */
public class FilterStatsCalculator {
    /**
     * Row-count multiplier applied to the most selective estimable conjunct of an AND when at least
     * one of the other conjuncts cannot be estimated.
     */
    static final double UNKNOWN_FILTER_COEFFICIENT = 0.9d;

    private final Metadata metadata;
    private final ScalarStatsCalculator scalarStatsCalculator;
    private final StatsNormalizer normalizer;

    @Inject
    public FilterStatsCalculator(Metadata metadata, ScalarStatsCalculator scalarStatsCalculator, StatsNormalizer statsNormalizer) {
        this.metadata = Objects.requireNonNull(metadata, "metadata is null");
        this.scalarStatsCalculator = Objects.requireNonNull(scalarStatsCalculator, "scalarStatsCalculator is null");
        this.normalizer = Objects.requireNonNull(statsNormalizer, "normalizer is null");
    }

    /**
     * Estimates the statistics of {@code statsEstimate} after applying {@code predicate}.
     *
     * @param statsEstimate statistics of the rows entering the filter
     * @param predicate the filter predicate; a predicate that simplifies to SQL NULL is treated as FALSE
     * @param session the current session
     * @param types symbol types used to analyze the predicate
     * @return the normalized estimate of the filtered rows, or an unknown estimate when the predicate
     *         cannot be estimated
     */
    public PlanNodeStatsEstimate filterStats(PlanNodeStatsEstimate statsEstimate, Expression predicate, Session session, TypeProvider types) {
        Expression simplified = simplifyExpression(session, predicate, types);
        return new FilterExpressionStatsCalculatingVisitor(statsEstimate, session, types).process(simplified);
    }

    /**
     * Constant-folds {@code predicate} and re-encodes the result as a BOOLEAN expression.
     */
    private Expression simplifyExpression(Session session, Expression predicate, TypeProvider types) {
        Map<NodeRef<Expression>, Type> expressionTypes = getExpressionTypes(session, predicate, types);
        Object value = new ExpressionInterpreter(predicate, metadata, session, expressionTypes).optimize(NoOpSymbolResolver.INSTANCE);
        if (value == null) {
            // The predicate folded to SQL NULL, which a filter treats like FALSE. This holds only for a
            // top-level predicate (not one nested under NOT), which is how filterStats calls this.
            value = false;
        }
        return new LiteralEncoder(session, metadata).toExpression(value, BooleanType.BOOLEAN);
    }

    /**
     * Analyzes {@code expression} (with no parameters and all access allowed) and returns the type of
     * every sub-expression. Any node the analyzer reports as unexpected raises an IllegalStateException.
     */
    private Map<NodeRef<Expression>, Type> getExpressionTypes(Session session, Expression expression, TypeProvider types) {
        AccessControl accessControl = new AllowAllAccessControl();
        Map<NodeRef<Parameter>, Expression> noParameters = Collections.emptyMap();
        Function<? super Node, ? extends RuntimeException> onUnexpectedNode = node -> new IllegalStateException("Unexpected node: " + node);
        ExpressionAnalyzer analyzer = ExpressionAnalyzer.createWithoutSubqueries(metadata, accessControl, session, types, noParameters, onUnexpectedNode, WarningCollector.NOOP, false);
        analyzer.analyze(expression, Scope.create());
        return analyzer.getExpressionTypes();
    }

    /**
     * Visitor that estimates the effect of one predicate against a fixed {@code input} estimate.
     * Every result that goes through {@link #process(Node, Void)} is normalized with the enclosing
     * calculator's {@link StatsNormalizer}.
     */
    private class FilterExpressionStatsCalculatingVisitor extends AstVisitor<PlanNodeStatsEstimate, Void> {
        private final PlanNodeStatsEstimate input;
        private final Session session;
        private final TypeProvider types;

        FilterExpressionStatsCalculatingVisitor(PlanNodeStatsEstimate input, Session session, TypeProvider types) {
            this.input = input;
            this.session = session;
            this.types = types;
        }

        @Override
        public PlanNodeStatsEstimate process(Node node, @Nullable Void context) {
            return normalizer.normalize(super.process(node, context), types);
        }

        /** Fallback for any expression shape without a dedicated estimator. */
        @Override
        protected PlanNodeStatsEstimate visitExpression(Expression node, Void context) {
            return PlanNodeStatsEstimate.unknown();
        }

        /**
         * NOT (x IS NULL) is estimated directly as x IS NOT NULL. Any other NOT is estimated as the
         * input minus the rows the negated expression would keep.
         */
        @Override
        protected PlanNodeStatsEstimate visitNotExpression(NotExpression node, Void context) {
            if (node.getValue() instanceof IsNullPredicate) {
                Expression operand = ((IsNullPredicate) node.getValue()).getValue();
                return process(new IsNotNullPredicate(operand));
            }
            return PlanNodeStatsEstimateMath.subtractSubsetStats(input, process(node.getValue()));
        }

        @Override
        protected PlanNodeStatsEstimate visitLogicalExpression(LogicalExpression node, Void context) {
            switch (node.getOperator()) {
                case AND:
                    return estimateLogicalAnd(node.getTerms());
                case OR:
                    return estimateLogicalOr(node.getTerms());
                default:
                    throw new IllegalArgumentException("Unexpected binary operator: " + node.getOperator());
            }
        }

        /**
         * Estimates a conjunction by applying the terms one after another, left to right, each to the
         * previous term's output. If any step is unknown, falls back to the most selective term that
         * can be estimated on its own, scaled by {@link #UNKNOWN_FILTER_COEFFICIENT}.
         */
        private PlanNodeStatsEstimate estimateLogicalAnd(List<Expression> terms) {
            PlanNodeStatsEstimate estimate = process(terms.get(0));
            for (int i = 1; i < terms.size() && !estimate.isOutputRowCountUnknown(); i++) {
                estimate = new FilterExpressionStatsCalculatingVisitor(estimate, session, types).process(terms.get(i));
            }
            if (!estimate.isOutputRowCountUnknown()) {
                return estimate;
            }

            // Some term could not be estimated: use the smallest independent estimate and discount it
            // to account for the terms whose effect is unknown.
            Optional<PlanNodeStatsEstimate> smallest = terms.stream()
                    .map(term -> process(term))
                    .filter(termEstimate -> !termEstimate.isOutputRowCountUnknown())
                    .min(Comparator.comparingDouble(PlanNodeStatsEstimate::getOutputRowCount));
            return smallest
                    .map(termEstimate -> termEstimate.mapOutputRowCount(rowCount -> rowCount * UNKNOWN_FILTER_COEFFICIENT))
                    .orElseGet(PlanNodeStatsEstimate::unknown);
        }

        /**
         * Estimates a disjunction term by term via inclusion-exclusion:
         * |A or B| = |A| + |B| - |A and B|, capped by the input. Returns unknown as soon as any
         * partial estimate is unknown.
         */
        private PlanNodeStatsEstimate estimateLogicalOr(List<Expression> terms) {
            PlanNodeStatsEstimate previous = process(terms.get(0));
            if (previous.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }
            for (int i = 1; i < terms.size(); i++) {
                Expression term = terms.get(i);
                PlanNodeStatsEstimate current = process(term);
                if (current.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                // Estimate of (previous AND term): apply the term to the running result.
                PlanNodeStatsEstimate overlap = new FilterExpressionStatsCalculatingVisitor(previous, session, types).process(term);
                if (overlap.isOutputRowCountUnknown()) {
                    return PlanNodeStatsEstimate.unknown();
                }
                previous = PlanNodeStatsEstimateMath.capStats(
                        PlanNodeStatsEstimateMath.subtractSubsetStats(
                                PlanNodeStatsEstimateMath.addStatsAndSumDistinctValues(previous, current),
                                overlap),
                        input);
            }
            return previous;
        }

        /**
         * TRUE keeps the input unchanged. FALSE keeps zero rows, with zero stats for every symbol
         * that had known statistics.
         */
        @Override
        protected PlanNodeStatsEstimate visitBooleanLiteral(BooleanLiteral node, Void context) {
            if (node.getValue()) {
                return input;
            }
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.builder();
            result.setOutputRowCount(0.0d);
            input.getSymbolsWithKnownStatistics().forEach(symbol -> result.addSymbolStatistics(symbol, SymbolStatsEstimate.zero()));
            return result.build();
        }

        /**
         * For a symbol operand, keeps the non-null fraction of the rows and sets the symbol's nulls
         * fraction to 0. Other operands are not estimated.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNotNullPredicate(IsNotNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * (1.0d - symbolStats.getNullsFraction()));
            result.addSymbolStatistics(symbol, symbolStats.mapNullsFraction(nullsFraction -> 0.0d));
            return result.build();
        }

        /**
         * For a symbol operand, keeps the null fraction of the rows. The symbol's stats are replaced
         * by "all null" stats: nulls fraction 1, no range, zero distinct values. Other operands are
         * not estimated.
         */
        @Override
        protected PlanNodeStatsEstimate visitIsNullPredicate(IsNullPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)) {
                return PlanNodeStatsEstimate.unknown();
            }
            Symbol symbol = Symbol.from(node.getValue());
            SymbolStatsEstimate symbolStats = input.getSymbolStatistics(symbol);
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(input.getOutputRowCount() * symbolStats.getNullsFraction());
            result.addSymbolStatistics(symbol, SymbolStatsEstimate.builder()
                    .setNullsFraction(1.0d)
                    .setLowValue(Double.NaN)
                    .setHighValue(Double.NaN)
                    .setDistinctValuesCount(0.0d)
                    .build());
            return result.build();
        }

        /**
         * Rewrites {@code v BETWEEN lo AND hi} as {@code v >= lo AND v <= hi}. This is done only when
         * {@code v} is a symbol and both bounds have single-value stats; otherwise the predicate is
         * not estimated.
         */
        @Override
        protected PlanNodeStatsEstimate visitBetweenPredicate(BetweenPredicate node, Void context) {
            if (!(node.getValue() instanceof SymbolReference)
                    || !getExpressionStats(node.getMin()).isSingleValue()
                    || !getExpressionStats(node.getMax()).isSingleValue()) {
                return PlanNodeStatsEstimate.unknown();
            }

            SymbolStatsEstimate valueStats = input.getSymbolStatistics(Symbol.from(node.getValue()));
            Expression lowerBound = new ComparisonExpression(ComparisonExpression.Operator.GREATER_THAN_OR_EQUAL, node.getValue(), node.getMin());
            Expression upperBound = new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN_OR_EQUAL, node.getValue(), node.getMax());

            // estimateLogicalAnd applies AND terms left to right, so this order decides which bound
            // is applied first. When the value's low end is infinite, the lower bound goes first,
            // presumably to turn the unbounded range into a bounded one as early as possible.
            Expression conjunction = Double.isInfinite(valueStats.getLowValue())
                    ? ExpressionUtils.and(lowerBound, upperBound)
                    : ExpressionUtils.and(upperBound, lowerBound);
            return process(conjunction);
        }

        /**
         * Estimates {@code v IN (a, b, ...)} as the sum of the per-value equality estimates. The row
         * count is capped by the non-null rows of {@code v}. For a symbol operand, its distinct-value
         * count is capped by the input's.
         */
        @Override
        protected PlanNodeStatsEstimate visitInPredicate(InPredicate node, Void context) {
            if (!(node.getValueList() instanceof InListExpression)) {
                return PlanNodeStatsEstimate.unknown();
            }

            InListExpression valueList = (InListExpression) node.getValueList();
            ImmutableList<PlanNodeStatsEstimate> equalityEstimates = valueList.getValues().stream()
                    .map(value -> process(new ComparisonExpression(ComparisonExpression.Operator.EQUAL, node.getValue(), value)))
                    .collect(ImmutableList.toImmutableList());
            if (equalityEstimates.stream().anyMatch(PlanNodeStatsEstimate::isOutputRowCountUnknown)) {
                return PlanNodeStatsEstimate.unknown();
            }

            PlanNodeStatsEstimate inEstimate = equalityEstimates.stream()
                    .reduce(PlanNodeStatsEstimateMath::addStatsAndSumDistinctValues)
                    .orElse(PlanNodeStatsEstimate.unknown());
            if (inEstimate.isOutputRowCountUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }

            SymbolStatsEstimate valueStats = getExpressionStats(node.getValue());
            if (valueStats.isUnknown()) {
                return PlanNodeStatsEstimate.unknown();
            }

            double notNullValuesBeforeIn = input.getOutputRowCount() * (1.0d - valueStats.getNullsFraction());
            PlanNodeStatsEstimate.Builder result = PlanNodeStatsEstimate.buildFrom(input);
            result.setOutputRowCount(Double.min(inEstimate.getOutputRowCount(), notNullValuesBeforeIn));

            if (node.getValue() instanceof SymbolReference) {
                Symbol valueSymbol = Symbol.from(node.getValue());
                result.addSymbolStatistics(valueSymbol, inEstimate.getSymbolStatistics(valueSymbol)
                        .mapDistinctValuesCount(distinctValues -> Double.min(distinctValues, valueStats.getDistinctValuesCount())));
            }
            return result.build();
        }

        /**
         * Normalizes the comparison so that a symbol (or else the non-literal side) is on the left.
         * Then it delegates to {@link ComparisonStatsCalculator}: as an expression-to-literal
         * comparison when the right side is a literal or has single-value stats, and as an
         * expression-to-expression comparison otherwise.
         */
        @Override
        protected PlanNodeStatsEstimate visitComparisonExpression(ComparisonExpression node, Void context) {
            ComparisonExpression.Operator operator = node.getOperator();
            Expression left = node.getLeft();
            Expression right = node.getRight();

            Preconditions.checkArgument(!(left instanceof Literal && right instanceof Literal), "Literal-to-literal not supported here, should be eliminated earlier");

            if (!(left instanceof SymbolReference) && right instanceof SymbolReference) {
                // Normalize so that the symbol is on the left.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof Literal && !(right instanceof Literal)) {
                // Normalize so that the literal is on the right.
                return process(new ComparisonExpression(operator.flip(), right, left));
            }
            if (left instanceof SymbolReference && left.equals(right)) {
                // NOTE(review): this applies for every operator, so strict comparisons such as x < x
                // (never true) are also estimated as x IS NOT NULL. Kept to preserve existing estimates.
                return process(new IsNotNullPredicate(left));
            }

            SymbolStatsEstimate leftStats = getExpressionStats(left);
            Optional<Symbol> leftSymbol = left instanceof SymbolReference ? Optional.of(Symbol.from(left)) : Optional.empty();
            if (right instanceof Literal) {
                OptionalDouble literalValue = doubleValueFromLiteral(getType(left), (Literal) right);
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, literalValue, operator);
            }

            SymbolStatsEstimate rightStats = getExpressionStats(right);
            if (rightStats.isSingleValue()) {
                OptionalDouble rightValue = Double.isNaN(rightStats.getLowValue()) ? OptionalDouble.empty() : OptionalDouble.of(rightStats.getLowValue());
                return ComparisonStatsCalculator.estimateExpressionToLiteralComparison(input, leftStats, leftSymbol, rightValue, operator);
            }

            Optional<Symbol> rightSymbol = right instanceof SymbolReference ? Optional.of(Symbol.from(right)) : Optional.empty();
            return ComparisonStatsCalculator.estimateExpressionToExpressionComparison(input, leftStats, leftSymbol, rightStats, rightSymbol, operator);
        }

        /** Dynamic-filter calls are estimated as TRUE; any other function call is not estimated. */
        @Override
        protected PlanNodeStatsEstimate visitFunctionCall(FunctionCall node, Void context) {
            if (DynamicFilters.isDynamicFilter(node)) {
                return process(BooleanLiteral.TRUE_LITERAL, context);
            }
            return PlanNodeStatsEstimate.unknown();
        }

        /**
         * Returns the type of {@code expression}. For a symbol it is looked up in {@code types};
         * otherwise the expression is analyzed.
         *
         * @throws NullPointerException if a symbol has no registered type
         */
        private Type getType(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(types.get(symbol), () -> String.format("No type for symbol %s", symbol));
            }
            // The analyzer is configured to reject subqueries; none are expected at this point.
            AccessControl accessControl = new AllowAllAccessControl();
            Map<NodeRef<Parameter>, Expression> noParameters = ImmutableMap.of();
            Function<? super Node, ? extends RuntimeException> onSubquery = node -> new VerifyException("Unexpected subquery");
            ExpressionAnalyzer analyzer = ExpressionAnalyzer.createWithoutSubqueries(metadata, accessControl, session, types, noParameters, onSubquery, WarningCollector.NOOP, false);
            return analyzer.analyze(expression, Scope.create());
        }

        /**
         * Returns the statistics of {@code expression}. For a symbol they are read from the input
         * estimate; otherwise they come from the {@link ScalarStatsCalculator}.
         *
         * @throws NullPointerException if a symbol has no statistics in the input
         */
        private SymbolStatsEstimate getExpressionStats(Expression expression) {
            if (expression instanceof SymbolReference) {
                Symbol symbol = Symbol.from(expression);
                return Objects.requireNonNull(input.getSymbolStatistics(symbol), () -> String.format("No statistics for symbol %s", symbol));
            }
            return scalarStatsCalculator.calculate(expression, input, session, types);
        }

        /** Evaluates {@code literal} and converts the value to its double stats representation for {@code type}. */
        private OptionalDouble doubleValueFromLiteral(Type type, Literal literal) {
            Object literalValue = LiteralInterpreter.evaluate(metadata, session, FilterStatsCalculator.this.getExpressionTypes(session, literal, types), literal);
            return StatsUtil.toStatsRepresentation(type, literalValue);
        }
    }
}
