package io.trino.sql.gen;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.airlift.bytecode.BytecodeBlock;
import io.airlift.bytecode.BytecodeNode;
import io.airlift.bytecode.Scope;
import io.airlift.bytecode.Variable;
import io.airlift.bytecode.control.IfStatement;
import io.airlift.bytecode.expression.BytecodeExpressions;
import io.airlift.bytecode.instruction.LabelNode;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.Type;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.relational.SpecialForm;
import java.util.Objects;
import java.util.Optional;

/* loaded from: input_file:io/trino/sql/gen/NullIfCodeGenerator.class */
/**
 * Bytecode generator for a two-argument special form that compares its arguments with the
 * {@code EQUAL} operator and yields either a null (java default value with the
 * {@code wasNull} flag set) or the value of the first argument.
 *
 * <p>Both arguments may first be coerced through optional cast functions. Both casts target
 * the first parameter type of the resolved equality operator; this presumably relies on the
 * operator taking two parameters of the same type.
 */
public class NullIfCodeGenerator implements BytecodeGenerator {
    private final RowExpression first;
    private final RowExpression second;
    private final ResolvedFunction equalsFunction;
    private final Optional<ResolvedFunction> firstCast;
    private final Optional<ResolvedFunction> secondCast;

    /**
     * Captures the two arguments of the special form and resolves the equality operator plus
     * the optional casts needed to bring each argument to the operator's parameter type.
     *
     * @param specialForm the special form to generate code for; must have exactly two
     *                    arguments and at most three function dependencies
     * @throws NullPointerException if {@code specialForm} is null
     * @throws IllegalArgumentException if the argument or dependency counts are out of range
     */
    public NullIfCodeGenerator(SpecialForm specialForm) {
        Objects.requireNonNull(specialForm, "specialForm is null");
        Preconditions.checkArgument(specialForm.getArguments().size() == 2);
        this.first = specialForm.getArguments().get(0);
        this.second = specialForm.getArguments().get(1);

        // At most: the equality operator and one cast per argument.
        Preconditions.checkArgument(specialForm.getFunctionDependencies().size() <= 3);
        this.equalsFunction = specialForm.getOperatorDependency(OperatorType.EQUAL);

        // Both arguments are cast to the type of the equality operator's first parameter.
        Type comparisonType = this.equalsFunction.getSignature().getArgumentTypes().get(0);
        this.firstCast = specialForm.getCastDependency(this.first.getType(), comparisonType);
        this.secondCast = specialForm.getCastDependency(this.second.getType(), comparisonType);
    }

    /**
     * Builds the bytecode for this expression.
     *
     * <p>Layout of the generated block: evaluate the first argument, branch to the
     * {@code notMatch} label if it was null, otherwise keep a copy in a temp variable and
     * compare it against the second argument. When the comparison is true the value on the
     * stack is replaced by the java default and {@code wasNull} is set; otherwise control
     * reaches {@code notMatch} with the first argument's value still on the stack.
     *
     * @param bytecodeGeneratorContext context used to generate sub-expressions and calls
     * @return the block implementing the expression
     */
    @Override
    public BytecodeNode generateExpression(BytecodeGeneratorContext bytecodeGeneratorContext) {
        Scope scope = bytecodeGeneratorContext.getScope();
        LabelNode notMatch = new LabelNode("notMatch");
        Class<?> firstJavaType = this.first.getType().getJavaType();

        // Evaluate the first argument; keep it on the stack and stash a copy for the comparison.
        Variable firstValue = scope.createTempVariable(firstJavaType);
        BytecodeBlock block = new BytecodeBlock()
                .comment("check if first arg is null")
                .append(bytecodeGeneratorContext.generate(this.first))
                .append(BytecodeUtils.ifWasNullPopAndGoto(scope, notMatch, void.class))
                .dup(firstJavaType)
                .putVariable(firstValue);

        // equal(cast(first), cast(second)), where each cast is applied only if present.
        BytecodeNode secondValue = bytecodeGeneratorContext.generate(this.second);
        BytecodeNode leftOperand = castIfPresent(bytecodeGeneratorContext, this.firstCast, firstValue);
        BytecodeNode rightOperand = castIfPresent(bytecodeGeneratorContext, this.secondCast, secondValue);
        BytecodeNode equalsCall = bytecodeGeneratorContext.generateCall(
                this.equalsFunction,
                ImmutableList.of(leftOperand, rightOperand));

        // A null comparison result is routed to notMatch as well.
        BytecodeBlock condition = new BytecodeBlock()
                .append(equalsCall)
                .append(BytecodeUtils.ifWasNullClearPopAndGoto(scope, notMatch, void.class, boolean.class));

        // Arguments matched: flag null and swap the first value on the stack for the default.
        BytecodeBlock matched = new BytecodeBlock()
                .append(bytecodeGeneratorContext.wasNull().set(BytecodeExpressions.constantTrue()))
                .pop(firstJavaType)
                .pushJavaDefault(firstJavaType);

        // No match: fall to the label, leaving the first value on the stack as the result.
        IfStatement nullIfEqual = new IfStatement()
                .condition(condition)
                .ifTrue(matched)
                .ifFalse(notMatch);

        block.append(nullIfEqual);
        return block;
    }

    /** Wraps {@code value} in a call to {@code cast} when one is present; otherwise returns it unchanged. */
    private static BytecodeNode castIfPresent(
            BytecodeGeneratorContext context,
            Optional<ResolvedFunction> cast,
            BytecodeNode value) {
        if (cast.isPresent()) {
            return context.generateCall(cast.get(), ImmutableList.of(value));
        }
        return value;
    }
}
