package io.trino.sql.routine;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.MoreCollectors;
import io.trino.Session;
import io.trino.metadata.ResolvedFunction;
import io.trino.spi.function.OperatorType;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.Type;
import io.trino.sql.planner.TestingPlannerContext;
import io.trino.sql.relational.Expressions;
import io.trino.sql.relational.InputReferenceExpression;
import io.trino.sql.relational.RowExpression;
import io.trino.sql.routine.ir.IrBlock;
import io.trino.sql.routine.ir.IrBreak;
import io.trino.sql.routine.ir.IrContinue;
import io.trino.sql.routine.ir.IrIf;
import io.trino.sql.routine.ir.IrLabel;
import io.trino.sql.routine.ir.IrLoop;
import io.trino.sql.routine.ir.IrRepeat;
import io.trino.sql.routine.ir.IrReturn;
import io.trino.sql.routine.ir.IrRoutine;
import io.trino.sql.routine.ir.IrSet;
import io.trino.sql.routine.ir.IrStatement;
import io.trino.sql.routine.ir.IrVariable;
import io.trino.sql.routine.ir.IrWhile;
import io.trino.testing.TestingSession;
import io.trino.util.Reflection;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

import java.lang.invoke.MethodHandle;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Supplier;

/* loaded from: input_file:io/trino/sql/routine/TestSqlRoutineCompiler.class */
/**
 * Tests for {@link SqlRoutineCompiler}.
 *
 * <p>Each test builds a routine directly in the routine IR, compiles it to a class, and invokes
 * the generated {@code run} method through a {@link MethodHandle}.
 */
public class TestSqlRoutineCompiler {
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private final SqlRoutineCompiler compiler = new SqlRoutineCompiler(TestingPlannerContext.PLANNER_CONTEXT.getFunctionManager());

    /**
     * Routine with one BIGINT parameter {@code arg} and a local {@code result} initialized to 99:
     * sets {@code result = result * arg} and returns it.
     */
    @Test
    public void testSimpleExpression() throws Throwable {
        IrVariable arg = new IrVariable(0, BigintType.BIGINT, Expressions.constantNull(BigintType.BIGINT));
        IrVariable result = new IrVariable(1, BigintType.BIGINT, Expressions.constant(99L, BigintType.BIGINT));
        ResolvedFunction multiply = operator(OperatorType.MULTIPLY, BigintType.BIGINT, BigintType.BIGINT);

        MethodHandle handle = compile(new IrRoutine(
                BigintType.BIGINT,
                parameters(arg),
                new IrBlock(
                        variables(result),
                        statements(
                                new IrSet(result, Expressions.call(multiply, reference(result), reference(arg))),
                                new IrReturn(reference(result))))));

        Assertions.assertThat(handle.invoke(0L)).isEqualTo(0L);
        Assertions.assertThat(handle.invoke(1L)).isEqualTo(99L);
        Assertions.assertThat(handle.invoke(42L)).isEqualTo(4158L);
        Assertions.assertThat(handle.invoke(123L)).isEqualTo(12177L);
    }

    /**
     * Iterative Fibonacci: returns 1 when {@code n <= 2}; otherwise loops while {@code 2 < n},
     * decrementing {@code n} and advancing the pair {@code (a, b)} via {@code c = a + b}, then returns {@code c}.
     */
    @Test
    public void testFibonacciWhileLoop() throws Throwable {
        IrVariable n = new IrVariable(0, BigintType.BIGINT, Expressions.constantNull(BigintType.BIGINT));
        IrVariable a = new IrVariable(1, BigintType.BIGINT, Expressions.constant(1L, BigintType.BIGINT));
        IrVariable b = new IrVariable(2, BigintType.BIGINT, Expressions.constant(1L, BigintType.BIGINT));
        IrVariable c = new IrVariable(3, BigintType.BIGINT, Expressions.constantNull(BigintType.BIGINT));

        ResolvedFunction add = operator(OperatorType.ADD, BigintType.BIGINT, BigintType.BIGINT);
        ResolvedFunction subtract = operator(OperatorType.SUBTRACT, BigintType.BIGINT, BigintType.BIGINT);
        ResolvedFunction lessThan = operator(OperatorType.LESS_THAN, BigintType.BIGINT, BigintType.BIGINT);
        ResolvedFunction lessThanOrEqual = operator(OperatorType.LESS_THAN_OR_EQUAL, BigintType.BIGINT, BigintType.BIGINT);

        IrStatement earlyReturn = new IrIf(
                Expressions.call(lessThanOrEqual, reference(n), Expressions.constant(2L, BigintType.BIGINT)),
                new IrReturn(Expressions.constant(1L, BigintType.BIGINT)),
                Optional.empty());

        IrStatement loop = new IrWhile(
                Optional.empty(),
                Expressions.call(lessThan, Expressions.constant(2L, BigintType.BIGINT), reference(n)),
                new IrBlock(
                        variables(),
                        statements(
                                new IrSet(n, Expressions.call(subtract, reference(n), Expressions.constant(1L, BigintType.BIGINT))),
                                new IrSet(c, Expressions.call(add, reference(a), reference(b))),
                                new IrSet(a, reference(b)),
                                new IrSet(b, reference(c)))));

        MethodHandle handle = compile(new IrRoutine(
                BigintType.BIGINT,
                parameters(n),
                new IrBlock(
                        variables(a, b, c),
                        statements(earlyReturn, loop, new IrReturn(reference(c))))));

        Assertions.assertThat(handle.invoke(1L)).isEqualTo(1L);
        Assertions.assertThat(handle.invoke(2L)).isEqualTo(1L);
        Assertions.assertThat(handle.invoke(3L)).isEqualTo(2L);
        Assertions.assertThat(handle.invoke(4L)).isEqualTo(3L);
        Assertions.assertThat(handle.invoke(5L)).isEqualTo(5L);
        Assertions.assertThat(handle.invoke(6L)).isEqualTo(8L);
        Assertions.assertThat(handle.invoke(7L)).isEqualTo(13L);
        Assertions.assertThat(handle.invoke(8L)).isEqualTo(21L);
    }

    /**
     * Labeled while loop over {@code value < 10}: increments {@code value}, continues while
     * {@code value < 3}, otherwise increments {@code count}, and breaks once {@code 6 < value}.
     * The counted iterations are values 3..7, so the routine returns 5.
     */
    @Test
    public void testBreakContinue() throws Throwable {
        IrVariable value = new IrVariable(0, BigintType.BIGINT, Expressions.constant(0L, BigintType.BIGINT));
        IrVariable count = new IrVariable(1, BigintType.BIGINT, Expressions.constant(0L, BigintType.BIGINT));

        ResolvedFunction add = operator(OperatorType.ADD, BigintType.BIGINT, BigintType.BIGINT);
        ResolvedFunction lessThan = operator(OperatorType.LESS_THAN, BigintType.BIGINT, BigintType.BIGINT);

        IrLabel label = new IrLabel("test");

        IrStatement loop = new IrWhile(
                Optional.of(label),
                Expressions.call(lessThan, reference(value), Expressions.constant(10L, BigintType.BIGINT)),
                new IrBlock(
                        variables(),
                        statements(
                                new IrSet(value, Expressions.call(add, reference(value), Expressions.constant(1L, BigintType.BIGINT))),
                                new IrIf(
                                        Expressions.call(lessThan, reference(value), Expressions.constant(3L, BigintType.BIGINT)),
                                        new IrContinue(label),
                                        Optional.empty()),
                                new IrSet(count, Expressions.call(add, reference(count), Expressions.constant(1L, BigintType.BIGINT))),
                                new IrIf(
                                        Expressions.call(lessThan, Expressions.constant(6L, BigintType.BIGINT), reference(value)),
                                        new IrBreak(label),
                                        Optional.empty()))));

        MethodHandle handle = compile(new IrRoutine(
                BigintType.BIGINT,
                parameters(),
                new IrBlock(
                        variables(value, count),
                        statements(loop, new IrReturn(reference(count))))));

        Assertions.assertThat(handle.invoke()).isEqualTo(5L);
    }

    /** A {@code WHILE true} loop with an empty body must stop when its thread is interrupted. */
    @Test
    public void testInterruptionWhile() throws Throwable {
        assertRoutineInterruption(() -> new IrWhile(
                Optional.empty(),
                Expressions.constant(true, BooleanType.BOOLEAN),
                new IrBlock(variables(), statements())));
    }

    /** A {@code REPEAT ... UNTIL false} loop with an empty body must stop when its thread is interrupted. */
    @Test
    public void testInterruptionRepeat() throws Throwable {
        assertRoutineInterruption(() -> new IrRepeat(
                Optional.empty(),
                Expressions.constant(false, BooleanType.BOOLEAN),
                new IrBlock(variables(), statements())));
    }

    /** An unconditional {@code LOOP} with an empty body must stop when its thread is interrupted. */
    @Test
    public void testInterruptionLoop() throws Throwable {
        assertRoutineInterruption(() -> new IrLoop(
                Optional.empty(),
                new IrBlock(variables(), statements())));
    }

    /**
     * Compiles a parameterless routine whose body is the supplied statement followed by
     * {@code RETURN NULL}, runs it on a separate thread, interrupts that thread, and asserts that
     * the invocation fails with a message containing {@code "Thread interrupted"} within 10 seconds.
     *
     * @param supplier produces the (non-terminating) statement under test
     */
    private void assertRoutineInterruption(Supplier<IrStatement> supplier) throws Throwable {
        MethodHandle handle = compile(new IrRoutine(
                BigintType.BIGINT,
                parameters(),
                new IrBlock(
                        variables(),
                        statements(
                                supplier.get(),
                                new IrReturn(Expressions.constant((Object) null, BigintType.BIGINT))))));

        AtomicBoolean interrupted = new AtomicBoolean();
        // Assertion failures on the worker thread would otherwise go to the uncaught-exception
        // handler and be lost; capture them so they can be reported from the test thread.
        AtomicReference<Throwable> workerFailure = new AtomicReference<>();
        Thread thread = new Thread(() -> {
            try {
                Assertions.assertThatThrownBy(handle::invoke).hasMessageContaining("Thread interrupted");
                interrupted.set(true);
            }
            catch (Throwable t) {
                workerFailure.set(t);
            }
        });
        // The routine never terminates on its own; if interruption is not honored, a daemon thread
        // at least does not keep the JVM alive after the test gives up waiting.
        thread.setDaemon(true);
        thread.start();
        thread.interrupt();
        thread.join(TimeUnit.SECONDS.toMillis(10L));

        Throwable failure = workerFailure.get();
        if (failure != null) {
            throw new AssertionError("Routine interruption check failed on worker thread", failure);
        }
        Assertions.assertThat(thread.isAlive())
                .as("routine still running 10 seconds after its thread was interrupted")
                .isFalse();
        Assertions.assertThat(interrupted).isTrue();
    }

    /**
     * Compiles the routine, instantiates the generated class through its no-argument constructor,
     * and returns the class's single public {@code run} method bound to that instance and to the
     * test session's connector session. The remaining handle parameters are the routine parameters.
     */
    private MethodHandle compile(IrRoutine irRoutine) throws Throwable {
        Class<?> routineClass = compiler.compileClass(irRoutine);
        MethodHandle run = Arrays.stream(routineClass.getMethods())
                .filter(method -> method.getName().equals("run"))
                .map(Reflection::methodHandle)
                .collect(MoreCollectors.onlyElement());
        Object instance = Reflection.constructorMethodHandle(routineClass).invoke();
        return run.bindTo(instance).bindTo(TEST_SESSION.toConnectorSession());
    }

    /** Immutable list of routine parameters. */
    private static List<IrVariable> parameters(IrVariable... irVariableArr) {
        return ImmutableList.copyOf(irVariableArr);
    }

    /** Immutable list of block-local variables. */
    private static List<IrVariable> variables(IrVariable... irVariableArr) {
        return ImmutableList.copyOf(irVariableArr);
    }

    /** Immutable list of statements. */
    private static List<IrStatement> statements(IrStatement... irStatementArr) {
        return ImmutableList.copyOf(irStatementArr);
    }

    /** Expression reading the given variable, addressed by its field index and type. */
    private static RowExpression reference(IrVariable irVariable) {
        return new InputReferenceExpression(irVariable.field(), irVariable.type());
    }

    /** Resolves the operator for the given argument types via the testing planner context's metadata. */
    private static ResolvedFunction operator(OperatorType operatorType, Type... typeArr) {
        return TestingPlannerContext.PLANNER_CONTEXT.getMetadata().resolveOperator(operatorType, ImmutableList.copyOf(typeArr));
    }
}
