package io.trino.sql.planner.iterative.rule;

import com.google.common.base.Verify;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import io.trino.Session;
import io.trino.SystemSessionProperties;
import io.trino.cost.PlanNodeStatsEstimate;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.metadata.TableHandle;
import io.trino.spi.connector.BasicRelationStatistics;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.JoinApplicationResult;
import io.trino.spi.connector.JoinStatistics;
import io.trino.spi.connector.JoinType;
import io.trino.spi.predicate.Domain;
import io.trino.spi.predicate.TupleDomain;
import io.trino.sql.ExpressionUtils;
import io.trino.sql.PlannerContext;
import io.trino.sql.planner.ConnectorExpressionTranslator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.TypeAnalyzer;
import io.trino.sql.planner.TypeProvider;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.planner.plan.TableScanNode;
import io.trino.sql.tree.BooleanLiteral;
import io.trino.sql.tree.Expression;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/PushJoinIntoTableScan.class */
/**
 * Optimizer rule that replaces a {@link JoinNode} whose two inputs are both {@link TableScanNode}s
 * with a single table scan, when the connector accepts the join via {@code Metadata.applyJoin}.
 * <p>
 * The rule does nothing for:
 * <ul>
 *   <li>cross joins;</li>
 *   <li>joins whose condition (equi-criteria plus filter) cannot be fully translated into a
 *       connector expression;</li>
 *   <li>joins where either side's enforced constraint is {@code none}.</li>
 * </ul>
 * On success the new scan reuses the join's plan node id. It is wrapped in an identity
 * {@link ProjectNode} over the join's output symbols.
 */
public class PushJoinIntoTableScan implements Rule<JoinNode> {
    private static final Capture<TableScanNode> LEFT_TABLE_SCAN = Capture.newCapture();
    private static final Capture<TableScanNode> RIGHT_TABLE_SCAN = Capture.newCapture();
    private static final Pattern<JoinNode> PATTERN = Patterns.join()
            .with(Patterns.Join.left().matching(Patterns.tableScan().capturedAs(LEFT_TABLE_SCAN)))
            .with(Patterns.Join.right().matching(Patterns.tableScan().capturedAs(RIGHT_TABLE_SCAN)));

    private final PlannerContext plannerContext;
    private final TypeAnalyzer typeAnalyzer;

    /**
     * @param plannerContext provides metadata access used to call {@code applyJoin}; must not be null
     * @param typeAnalyzer used when translating the join condition into a connector expression; must not be null
     * @throws NullPointerException if either argument is null
     */
    public PushJoinIntoTableScan(PlannerContext plannerContext, TypeAnalyzer typeAnalyzer) {
        this.plannerContext = Objects.requireNonNull(plannerContext, "plannerContext is null");
        this.typeAnalyzer = Objects.requireNonNull(typeAnalyzer, "typeAnalyzer is null");
    }

    @Override
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /** Enabled only when the session allows pushdown into connectors. */
    @Override
    public boolean isEnabled(Session session) {
        return SystemSessionProperties.isAllowPushdownIntoConnectors(session);
    }

    /**
     * Attempts to push the matched join into the connector.
     *
     * @return a plan with a single joined table scan, or {@link Rule.Result#empty()} if pushdown
     *         is not possible or the connector declines it
     */
    @Override
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        if (joinNode.isCrossJoin()) {
            return Rule.Result.empty();
        }

        TableScanNode left = captures.get(LEFT_TABLE_SCAN);
        TableScanNode right = captures.get(RIGHT_TABLE_SCAN);
        Verify.verify(!left.isUpdateTarget() && !right.isUpdateTarget(), "Unexpected Join over for-update table scan");

        ConnectorExpressionTranslator.ConnectorExpressionTranslation translation = ConnectorExpressionTranslator.translateConjuncts(
                context.getSession(),
                getEffectiveFilter(joinNode),
                context.getSymbolAllocator().getTypes(),
                plannerContext,
                typeAnalyzer);
        if (!translation.remainingExpression().equals(BooleanLiteral.TRUE_LITERAL)) {
            // Only joins whose entire condition can be handed to the connector are pushed down.
            return Rule.Result.empty();
        }

        if (left.getEnforcedConstraint().isNone() || right.getEnforcedConstraint().isNone()) {
            // Bailing out here guarantees getDomains() is present on both derived constraints below.
            return Rule.Result.empty();
        }

        Optional<JoinApplicationResult<TableHandle>> applyJoinResult = plannerContext.getMetadata().applyJoin(
                context.getSession(),
                getJoinType(joinNode),
                left.getTable(),
                right.getTable(),
                translation.connectorExpression(),
                assignmentsByName(left),
                assignmentsByName(right),
                getJoinStatistics(joinNode, left, right, context));
        if (applyJoinResult.isEmpty()) {
            return Rule.Result.empty();
        }
        JoinApplicationResult<TableHandle> result = applyJoinResult.get();

        Map<ColumnHandle, ColumnHandle> leftColumnHandles = result.getLeftColumnHandles();
        Map<ColumnHandle, ColumnHandle> rightColumnHandles = result.getRightColumnHandles();

        // Every symbol of both original scans now reads from the joined table's column handles.
        Map<Symbol, ColumnHandle> assignments = ImmutableMap.<Symbol, ColumnHandle>builder()
                .putAll(remapAssignments(left.getAssignments(), leftColumnHandles))
                .putAll(remapAssignments(right.getAssignments(), rightColumnHandles))
                .buildOrThrow();

        // A side whose rows may be null-padded by the outer join must also admit NULL in its constraint.
        JoinNode.Type type = joinNode.getType();
        TupleDomain<ColumnHandle> leftConstraint = deriveConstraint(
                left.getEnforcedConstraint(),
                leftColumnHandles,
                type == JoinNode.Type.RIGHT || type == JoinNode.Type.FULL);
        TupleDomain<ColumnHandle> rightConstraint = deriveConstraint(
                right.getEnforcedConstraint(),
                rightColumnHandles,
                type == JoinNode.Type.LEFT || type == JoinNode.Type.FULL);
        TupleDomain<ColumnHandle> enforcedConstraint = TupleDomain.withColumnDomains(
                ImmutableMap.<ColumnHandle, Domain>builder()
                        .putAll(leftConstraint.getDomains().orElseThrow())
                        .putAll(rightConstraint.getDomains().orElseThrow())
                        .buildOrThrow());

        TableScanNode joinedScan = new TableScanNode(
                joinNode.getId(),
                result.getTableHandle(),
                ImmutableList.copyOf(assignments.keySet()),
                assignments,
                enforcedConstraint,
                Rules.deriveTableStatisticsForPushdown(context.getStatsProvider(), context.getSession(), result.isPrecalculateStatistics(), joinNode),
                false,
                Optional.empty());
        return Rule.Result.ofPlanNode(new ProjectNode(
                context.getIdAllocator().getNextId(),
                joinedScan,
                Assignments.identity(joinNode.getOutputSymbols())));
    }

    // Keys a scan's column assignments by symbol name, the form passed to Metadata.applyJoin.
    private static Map<String, ColumnHandle> assignmentsByName(TableScanNode tableScan) {
        return tableScan.getAssignments().entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(entry -> entry.getKey().getName(), Map.Entry::getValue));
    }

    // Rewrites each symbol's column handle through the connector-provided mapping to the joined table's handle.
    private static Map<Symbol, ColumnHandle> remapAssignments(Map<Symbol, ColumnHandle> assignments, Map<ColumnHandle, ColumnHandle> columnHandleMapping) {
        return assignments.entrySet().stream()
                .collect(ImmutableMap.toImmutableMap(Map.Entry::getKey, entry -> columnHandleMapping.get(entry.getValue())));
    }

    /**
     * Builds the statistics handed to the connector.
     * <p>
     * Each accessor computes its estimate from the stats provider only when the connector calls it.
     */
    private JoinStatistics getJoinStatistics(JoinNode joinNode, TableScanNode left, TableScanNode right, Rule.Context context) {
        return new JoinStatistics() {
            @Override
            public Optional<BasicRelationStatistics> getLeftStatistics() {
                return getBasicRelationStats(left, left.getOutputSymbols(), context);
            }

            @Override
            public Optional<BasicRelationStatistics> getRightStatistics() {
                return getBasicRelationStats(right, right.getOutputSymbols(), context);
            }

            @Override
            public Optional<BasicRelationStatistics> getJoinStatistics() {
                return getBasicRelationStats(joinNode, joinNode.getOutputSymbols(), context);
            }
        };
    }

    // Converts the engine's estimate for a node into row count and byte size; empty if either is unknown (NaN).
    private static Optional<BasicRelationStatistics> getBasicRelationStats(PlanNode node, List<Symbol> outputSymbols, Rule.Context context) {
        PlanNodeStatsEstimate stats = context.getStatsProvider().getStats(node);
        TypeProvider types = context.getSymbolAllocator().getTypes();
        double outputRowCount = stats.getOutputRowCount();
        double outputSizeInBytes = stats.getOutputSizeInBytes(outputSymbols, types);
        if (Double.isNaN(outputRowCount) || Double.isNaN(outputSizeInBytes)) {
            return Optional.empty();
        }
        return Optional.of(new BasicRelationStatistics((long) outputRowCount, (long) outputSizeInBytes));
    }

    /**
     * Re-keys one join side's enforced constraint onto the joined table's column handles.
     *
     * @param sourceConstraint the side's enforced constraint (the caller has already excluded {@code none})
     * @param columnMapping maps the side's original column handles to the joined table's handles
     * @param nullable whether this side may be null-padded by the outer join. In that case every
     *        domain is widened to also admit NULL.
     * @return the translated constraint
     */
    private TupleDomain<ColumnHandle> deriveConstraint(TupleDomain<ColumnHandle> sourceConstraint, Map<ColumnHandle, ColumnHandle> columnMapping, boolean nullable) {
        TupleDomain<ColumnHandle> constraint = sourceConstraint;
        if (nullable) {
            constraint = constraint.transformDomains((column, domain) -> domain.union(Domain.onlyNull(domain.getType())));
        }
        return constraint.transformKeys(columnMapping::get);
    }

    /**
     * Returns the full join condition.
     * <p>
     * This is the conjunction of all equi-join criteria, further AND-ed with the join's extra filter
     * when one is present.
     */
    public Expression getEffectiveFilter(JoinNode joinNode) {
        List<Expression> equiConjuncts = joinNode.getCriteria().stream()
                .map(clause -> clause.toExpression())
                .collect(ImmutableList.toImmutableList());
        Expression filter = ExpressionUtils.and(equiConjuncts);
        if (joinNode.getFilter().isPresent()) {
            filter = ExpressionUtils.and(filter, joinNode.getFilter().get());
        }
        return filter;
    }

    // Maps the planner's join type onto the SPI join type; the switch is exhaustive over JoinNode.Type.
    private JoinType getJoinType(JoinNode joinNode) {
        return switch (joinNode.getType()) {
            case INNER -> JoinType.INNER;
            case LEFT -> JoinType.LEFT_OUTER;
            case RIGHT -> JoinType.RIGHT_OUTER;
            case FULL -> JoinType.FULL_OUTER;
        };
    }
}
