package io.trino.sql.planner.iterative.rule;

import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.planner.PlanNodeIdAllocator;
import io.trino.sql.planner.Symbol;
import io.trino.sql.planner.SymbolAllocator;
import io.trino.sql.planner.iterative.Rule;
import io.trino.sql.planner.optimizations.QueryCardinalityUtil;
import io.trino.sql.planner.plan.Assignments;
import io.trino.sql.planner.plan.JoinNode;
import io.trino.sql.planner.plan.Patterns;
import io.trino.sql.planner.plan.PlanNode;
import io.trino.sql.planner.plan.ProjectNode;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.NullLiteral;
import java.util.List;

/* loaded from: input_file:io/trino/sql/planner/iterative/rule/ReplaceRedundantJoinWithProject.class */
/**
 * Replaces an outer join with a projection over one of its sources when the side that would be
 * null-padded is known to produce no rows and the preserved side is not.
 * <p>
 * For a LEFT join whose right source is empty, the join output is the left source with a typed
 * NULL for every right-side output symbol. That is exactly the projection built by
 * {@link #appendNulls}. RIGHT joins are handled symmetrically, and FULL joins are rewritten when
 * exactly one side is empty.
 * <p>
 * The rule never rewrites:
 * <ul>
 * <li>INNER joins;</li>
 * <li>joins where both sources are known to be empty;</li>
 * <li>joins where neither source is known to be empty.</li>
 * </ul>
 */
public class ReplaceRedundantJoinWithProject implements Rule<JoinNode> {
    private static final Pattern<JoinNode> PATTERN = Patterns.join();

    @Override // io.trino.sql.planner.iterative.Rule
    public Pattern<JoinNode> getPattern() {
        return PATTERN;
    }

    /**
     * Attempts to replace {@code joinNode} with a projection over its non-empty source.
     *
     * @param joinNode the matched join
     * @param captures unused; the pattern captures nothing
     * @param context rule context providing the lookup, id allocator and symbol allocator
     * @return a result holding the replacement {@link ProjectNode}, or an empty result when the
     *     join cannot be replaced
     */
    @Override // io.trino.sql.planner.iterative.Rule
    public Rule.Result apply(JoinNode joinNode, Captures captures, Rule.Context context) {
        PlanNode left = joinNode.getLeft();
        PlanNode right = joinNode.getRight();
        switch (joinNode.getType()) {
            case INNER:
                // No side of an INNER join is null-padded, so there is nothing to replace.
                // Return before running any cardinality analysis on the sources.
                return Rule.Result.empty();
            case LEFT:
                // Check the null-padded (right) side first; if it may produce rows, the
                // left side does not need to be analyzed at all.
                if (isKnownEmpty(right, context) && !isKnownEmpty(left, context)) {
                    return Rule.Result.ofPlanNode(appendNulls(left, joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), context.getIdAllocator(), context.getSymbolAllocator()));
                }
                break;
            case RIGHT:
                if (isKnownEmpty(left, context) && !isKnownEmpty(right, context)) {
                    return Rule.Result.ofPlanNode(appendNulls(right, joinNode.getRightOutputSymbols(), joinNode.getLeftOutputSymbols(), context.getIdAllocator(), context.getSymbolAllocator()));
                }
                break;
            case FULL: {
                boolean leftEmpty = isKnownEmpty(left, context);
                boolean rightEmpty = isKnownEmpty(right, context);
                if (leftEmpty && !rightEmpty) {
                    return Rule.Result.ofPlanNode(appendNulls(right, joinNode.getRightOutputSymbols(), joinNode.getLeftOutputSymbols(), context.getIdAllocator(), context.getSymbolAllocator()));
                }
                if (!leftEmpty && rightEmpty) {
                    return Rule.Result.ofPlanNode(appendNulls(left, joinNode.getLeftOutputSymbols(), joinNode.getRightOutputSymbols(), context.getIdAllocator(), context.getSymbolAllocator()));
                }
                break;
            }
        }
        return Rule.Result.empty();
    }

    /**
     * Returns whether {@code node} is known to produce at most zero rows, per
     * {@link QueryCardinalityUtil#isAtMost}.
     */
    private static boolean isKnownEmpty(PlanNode node, Rule.Context context) {
        return QueryCardinalityUtil.isAtMost(node, context.getLookup(), 0L);
    }

    /**
     * Builds a projection over {@code source} that passes {@code sourceOutputs} through unchanged
     * and assigns a typed NULL to each symbol in {@code nullSymbols}.
     *
     * @param source the preserved join input
     * @param sourceOutputs symbols of {@code source} to forward as identity assignments
     * @param nullSymbols output symbols of the eliminated side, each assigned a typed NULL
     * @param idAllocator allocator for the new node's id
     * @param symbolAllocator used to look up the type of each symbol in {@code nullSymbols}
     * @return the new projection node
     */
    private static ProjectNode appendNulls(PlanNode source, List<Symbol> sourceOutputs, List<Symbol> nullSymbols, PlanNodeIdAllocator idAllocator, SymbolAllocator symbolAllocator) {
        Assignments.Builder assignments = Assignments.builder().putIdentities(sourceOutputs);
        for (Symbol nullSymbol : nullSymbols) {
            // Cast the NULL to the type recorded for the symbol, so the projected expression
            // carries the same type as the join output it replaces.
            assignments.put(nullSymbol, new Cast(new NullLiteral(), TypeSignatureTranslator.toSqlType(symbolAllocator.getTypes().get(nullSymbol))));
        }
        return new ProjectNode(idAllocator.getNextId(), source, assignments.build());
    }
}
