package io.trino.sql.planner;

import com.google.common.collect.ImmutableMap;
import io.airlift.slice.Slices;
import io.trino.Session;
import io.trino.operator.scalar.JoniRegexpCasts;
import io.trino.security.AllowAllAccessControl;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Constant;
import io.trino.spi.expression.FieldDereference;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.StandardFunctions;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import io.trino.spi.type.RealType;
import io.trino.spi.type.RowType;
import io.trino.spi.type.SmallintType;
import io.trino.spi.type.TinyintType;
import io.trino.spi.type.Type;
import io.trino.spi.type.VarcharType;
import io.trino.sql.analyzer.TypeSignatureTranslator;
import io.trino.sql.tree.ArithmeticBinaryExpression;
import io.trino.sql.tree.ArithmeticUnaryExpression;
import io.trino.sql.tree.BetweenPredicate;
import io.trino.sql.tree.Cast;
import io.trino.sql.tree.ComparisonExpression;
import io.trino.sql.tree.DoubleLiteral;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.IsNotNullPredicate;
import io.trino.sql.tree.IsNullPredicate;
import io.trino.sql.tree.LikePredicate;
import io.trino.sql.tree.LogicalExpression;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NotExpression;
import io.trino.sql.tree.NullIfExpression;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.StringLiteral;
import io.trino.sql.tree.SubscriptExpression;
import io.trino.sql.tree.SymbolReference;
import io.trino.testing.DataProviders;
import io.trino.testing.TestingSession;
import io.trino.transaction.TestingTransactionManager;
import io.trino.transaction.TransactionBuilder;
import io.trino.type.JoniRegexpType;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.stream.Stream;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;

/* loaded from: input_file:io/trino/sql/planner/TestConnectorExpressionTranslator.class */
/**
 * Tests for {@link ConnectorExpressionTranslator}.
 *
 * <p>Each test builds an engine {@link Expression} over a fixed set of typed symbols and checks
 * its translation to a {@link ConnectorExpression}. Where the translation is expected to be
 * lossless, it also checks the translation back to the original engine expression.
 */
public class TestConnectorExpressionTranslator {
    private static final Session TEST_SESSION = TestingSession.testSessionBuilder().build();
    private static final TypeAnalyzer TYPE_ANALYZER = TypeAnalyzer.createTestingTypeAnalyzer(TestingPlannerContext.PLANNER_CONTEXT);
    private static final Type ROW_TYPE = RowType.rowType(
            RowType.field("int_symbol_1", IntegerType.INTEGER),
            RowType.field("varchar_symbol_1", VarcharType.createVarcharType(5)));
    private static final VarcharType VARCHAR_TYPE = VarcharType.createVarcharType(25);
    private static final LiteralEncoder LITERAL_ENCODER = new LiteralEncoder(TestingPlannerContext.PLANNER_CONTEXT);

    // Symbols that the expressions under test may reference, with their declared types.
    // Explicit type witnesses are required: without them builder() infers <Object, Object>
    // and the result cannot be assigned to Map<Symbol, Type>.
    private static final Map<Symbol, Type> SYMBOLS = ImmutableMap.<Symbol, Type>builder()
            .put(new Symbol("double_symbol_1"), DoubleType.DOUBLE)
            .put(new Symbol("double_symbol_2"), DoubleType.DOUBLE)
            .put(new Symbol("row_symbol_1"), ROW_TYPE)
            .put(new Symbol("varchar_symbol_1"), VARCHAR_TYPE)
            .put(new Symbol("boolean_symbol_1"), BooleanType.BOOLEAN)
            .buildOrThrow();
    private static final TypeProvider TYPE_PROVIDER = TypeProvider.copyOf(SYMBOLS);

    // Connector variable name -> engine symbol, used when translating connector expressions back.
    private static final Map<String, Symbol> VARIABLE_MAPPINGS = SYMBOLS.keySet().stream()
            .collect(ImmutableMap.toImmutableMap(Symbol::getName, symbol -> symbol));

    @Test
    public void testTranslateConstant() {
        testTranslateConstant(true, BooleanType.BOOLEAN);

        testTranslateConstant(42L, TinyintType.TINYINT);
        testTranslateConstant(42L, SmallintType.SMALLINT);
        testTranslateConstant(42L, IntegerType.INTEGER);
        testTranslateConstant(42L, BigintType.BIGINT);

        // NOTE(review): for REAL the native long presumably holds the float's int bits, so 42L is
        // not the number 42 here; the check only needs a value that round-trips. Confirm intent.
        testTranslateConstant(42L, RealType.REAL);
        testTranslateConstant(42.0d, DoubleType.DOUBLE);

        // Native long values with scale 2.
        testTranslateConstant(4200L, DecimalType.createDecimalType(4, 2));
        testTranslateConstant(4200L, DecimalType.createDecimalType(8, 2));

        testTranslateConstant(Slices.utf8Slice("abc"), VarcharType.createVarcharType(3));
        testTranslateConstant(Slices.utf8Slice("abc"), VarcharType.createVarcharType(33));
    }

    /**
     * Asserts that the literal encoding of {@code value} round-trips with a {@link Constant}.
     *
     * @param value native value of the constant
     * @param type type of the constant
     */
    private void testTranslateConstant(Object value, Type type) {
        assertTranslationRoundTrips(LITERAL_ENCODER.toExpression(TEST_SESSION, value, type), new Constant(value, type));
    }

    @Test
    public void testTranslateSymbol() {
        assertTranslationRoundTrips(
                new SymbolReference("double_symbol_1"),
                new Variable("double_symbol_1", DoubleType.DOUBLE));
    }

    @Test
    public void testTranslateRowSubscript() {
        // SQL subscripts are 1-based; the connector FieldDereference index is 0-based.
        assertTranslationRoundTrips(
                new SubscriptExpression(new SymbolReference("row_symbol_1"), new LongLiteral("1")),
                new FieldDereference(IntegerType.INTEGER, new Variable("row_symbol_1", ROW_TYPE), 0));
    }

    @Test(dataProvider = "testTranslateLogicalExpressionDataProvider")
    public void testTranslateLogicalExpression(LogicalExpression.Operator operator) {
        Variable left = new Variable("double_symbol_1", DoubleType.DOUBLE);
        Variable right = new Variable("double_symbol_2", DoubleType.DOUBLE);
        assertTranslationRoundTrips(
                new LogicalExpression(operator, List.of(
                        new ComparisonExpression(ComparisonExpression.Operator.LESS_THAN, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
                        new ComparisonExpression(ComparisonExpression.Operator.EQUAL, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")))),
                new Call(
                        BooleanType.BOOLEAN,
                        operator == LogicalExpression.Operator.AND ? StandardFunctions.AND_FUNCTION_NAME : StandardFunctions.OR_FUNCTION_NAME,
                        List.of(
                                new Call(BooleanType.BOOLEAN, StandardFunctions.LESS_THAN_OPERATOR_FUNCTION_NAME, List.of(left, right)),
                                new Call(BooleanType.BOOLEAN, StandardFunctions.EQUAL_OPERATOR_FUNCTION_NAME, List.of(left, right)))));
    }

    @DataProvider
    public static Object[][] testTranslateLogicalExpressionDataProvider() {
        return Stream.of(LogicalExpression.Operator.values()).collect(DataProviders.toDataProvider());
    }

    @Test(dataProvider = "testTranslateComparisonExpressionDataProvider")
    public void testTranslateComparisonExpression(ComparisonExpression.Operator operator) {
        assertTranslationRoundTrips(
                new ComparisonExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
                new Call(
                        BooleanType.BOOLEAN,
                        ConnectorExpressionTranslator.functionNameForComparisonOperator(operator),
                        List.of(new Variable("double_symbol_1", DoubleType.DOUBLE), new Variable("double_symbol_2", DoubleType.DOUBLE))));
    }

    @DataProvider
    public static Object[][] testTranslateComparisonExpressionDataProvider() {
        return Stream.of(ComparisonExpression.Operator.values()).collect(DataProviders.toDataProvider());
    }

    @Test(dataProvider = "testTranslateArithmeticBinaryDataProvider")
    public void testTranslateArithmeticBinary(ArithmeticBinaryExpression.Operator operator) {
        assertTranslationRoundTrips(
                new ArithmeticBinaryExpression(operator, new SymbolReference("double_symbol_1"), new SymbolReference("double_symbol_2")),
                new Call(
                        DoubleType.DOUBLE,
                        ConnectorExpressionTranslator.functionNameForArithmeticBinaryOperator(operator),
                        List.of(new Variable("double_symbol_1", DoubleType.DOUBLE), new Variable("double_symbol_2", DoubleType.DOUBLE))));
    }

    @DataProvider
    public static Object[][] testTranslateArithmeticBinaryDataProvider() {
        return Stream.of(ArithmeticBinaryExpression.Operator.values()).collect(DataProviders.toDataProvider());
    }

    @Test
    public void testTranslateArithmeticUnaryMinus() {
        assertTranslationRoundTrips(
                new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.MINUS, new SymbolReference("double_symbol_1")),
                new Call(DoubleType.DOUBLE, StandardFunctions.NEGATE_FUNCTION_NAME, List.of(new Variable("double_symbol_1", DoubleType.DOUBLE))));
    }

    @Test
    public void testTranslateArithmeticUnaryPlus() {
        // Unary plus is dropped, so this is checked in the engine -> connector direction only.
        assertTranslationToConnectorExpression(
                TEST_SESSION,
                new ArithmeticUnaryExpression(ArithmeticUnaryExpression.Sign.PLUS, new SymbolReference("double_symbol_1")),
                new Variable("double_symbol_1", DoubleType.DOUBLE));
    }

    @Test
    public void testTranslateBetween() {
        // BETWEEN is expanded to (value >= low AND value <= high); checked one way only.
        Variable value = new Variable("double_symbol_1", DoubleType.DOUBLE);
        assertTranslationToConnectorExpression(
                TEST_SESSION,
                new BetweenPredicate(new SymbolReference("double_symbol_1"), new DoubleLiteral("1.2"), new SymbolReference("double_symbol_2")),
                new Call(BooleanType.BOOLEAN, StandardFunctions.AND_FUNCTION_NAME, List.of(
                        new Call(BooleanType.BOOLEAN, StandardFunctions.GREATER_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, List.of(value, new Constant(1.2d, DoubleType.DOUBLE))),
                        new Call(BooleanType.BOOLEAN, StandardFunctions.LESS_THAN_OR_EQUAL_OPERATOR_FUNCTION_NAME, List.of(value, new Variable("double_symbol_2", DoubleType.DOUBLE))))));
    }

    @Test
    public void testTranslateIsNull() {
        assertTranslationRoundTrips(
                new IsNullPredicate(new SymbolReference("varchar_symbol_1")),
                new Call(BooleanType.BOOLEAN, StandardFunctions.IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));
    }

    @Test
    public void testTranslateNotExpression() {
        assertTranslationRoundTrips(
                new NotExpression(new SymbolReference("boolean_symbol_1")),
                new Call(BooleanType.BOOLEAN, StandardFunctions.NOT_FUNCTION_NAME, List.of(new Variable("boolean_symbol_1", BooleanType.BOOLEAN))));
    }

    @Test
    public void testTranslateIsNotNull() {
        // IS NOT NULL is represented as NOT(IS_NULL(x)).
        assertTranslationRoundTrips(
                new IsNotNullPredicate(new SymbolReference("varchar_symbol_1")),
                new Call(BooleanType.BOOLEAN, StandardFunctions.NOT_FUNCTION_NAME, List.of(
                        new Call(BooleanType.BOOLEAN, StandardFunctions.IS_NULL_FUNCTION_NAME, List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))))));
    }

    @Test
    public void testTranslateCast() {
        Variable varcharVariable = new Variable("varchar_symbol_1", VARCHAR_TYPE);

        assertTranslationRoundTrips(
                new Cast(new SymbolReference("varchar_symbol_1"), TypeSignatureTranslator.toSqlType(VARCHAR_TYPE)),
                new Call(VARCHAR_TYPE, StandardFunctions.CAST_FUNCTION_NAME, List.of(varcharVariable)));

        // Cast with flags (false, true) to a wider varchar: checked one way only.
        // NOTE(review): flags presumably mean (safe=false, typeOnly=true); confirm against Cast.
        VarcharType widerVarcharType = VarcharType.createVarcharType(VARCHAR_TYPE.getBoundedLength() + 1);
        assertTranslationToConnectorExpression(
                TEST_SESSION,
                new Cast(new SymbolReference("varchar_symbol_1"), TypeSignatureTranslator.toSqlType(widerVarcharType), false, true),
                new Call(widerVarcharType, StandardFunctions.CAST_FUNCTION_NAME, List.of(varcharVariable)));

        // Cast with flags (true, true) to BIGINT is expected to have no connector translation.
        assertTranslationToConnectorExpression(
                TEST_SESSION,
                new Cast(new SymbolReference("varchar_symbol_1"), TypeSignatureTranslator.toSqlType(BigintType.BIGINT), true, true),
                Optional.empty());
    }

    @Test
    public void testTranslateLike() {
        Variable varcharVariable = new Variable("varchar_symbol_1", VARCHAR_TYPE);

        assertTranslationRoundTrips(
                new LikePredicate(new SymbolReference("varchar_symbol_1"), new StringLiteral("%pattern%"), Optional.empty()),
                new Call(BooleanType.BOOLEAN, StandardFunctions.LIKE_PATTERN_FUNCTION_NAME, List.of(
                        varcharVariable,
                        varcharConstant("%pattern%"))));

        assertTranslationRoundTrips(
                new LikePredicate(new SymbolReference("varchar_symbol_1"), new StringLiteral("%pattern%"), Optional.of(new StringLiteral("\\"))),
                new Call(BooleanType.BOOLEAN, StandardFunctions.LIKE_PATTERN_FUNCTION_NAME, List.of(
                        varcharVariable,
                        varcharConstant("%pattern%"),
                        varcharConstant("\\"))));
    }

    @Test
    public void testTranslateNullIf() {
        Variable varcharVariable = new Variable("varchar_symbol_1", VARCHAR_TYPE);
        assertTranslationRoundTrips(
                new NullIfExpression(new SymbolReference("varchar_symbol_1"), new SymbolReference("varchar_symbol_1")),
                new Call(VARCHAR_TYPE, StandardFunctions.NULLIF_FUNCTION_NAME, List.of(varcharVariable, varcharVariable)));
    }

    @Test
    public void testTranslateResolvedFunction() {
        TransactionBuilder.transaction(new TestingTransactionManager(), new AllowAllAccessControl())
                .readOnly()
                .execute(TEST_SESSION, transactionSession -> {
                    FunctionCall lowerCall = FunctionCallBuilder.resolve(TEST_SESSION, TestingPlannerContext.PLANNER_CONTEXT.getMetadata())
                            .setName(QualifiedName.of("lower"))
                            .addArgument(VARCHAR_TYPE, new SymbolReference("varchar_symbol_1"))
                            .build();
                    assertTranslationRoundTrips(
                            transactionSession,
                            lowerCall,
                            new Call(VARCHAR_TYPE, new FunctionName("lower"), List.of(new Variable("varchar_symbol_1", VARCHAR_TYPE))));
                });
    }

    @Test
    public void testTranslateRegularExpression() {
        TransactionBuilder.transaction(new TestingTransactionManager(), new AllowAllAccessControl())
                .readOnly()
                .execute(TEST_SESSION, transactionSession -> {
                    // Engine form whose pattern argument is an encoded JONI_REGEXP literal.
                    FunctionCall encodedPatternCall = FunctionCallBuilder.resolve(TEST_SESSION, TestingPlannerContext.PLANNER_CONTEXT.getMetadata())
                            .setName(QualifiedName.of("regexp_like"))
                            .addArgument(VARCHAR_TYPE, new SymbolReference("varchar_symbol_1"))
                            .addArgument(JoniRegexpType.JONI_REGEXP, LITERAL_ENCODER.toExpression(TEST_SESSION, JoniRegexpCasts.joniRegexp(Slices.utf8Slice("a+")), JoniRegexpType.JONI_REGEXP))
                            .build();

                    // Connector form: the pattern is exposed as a plain varchar(2) constant.
                    Call connectorCall = new Call(BooleanType.BOOLEAN, new FunctionName("regexp_like"), List.of(
                            new Variable("varchar_symbol_1", VARCHAR_TYPE),
                            new Constant(Slices.utf8Slice("a+"), VarcharType.createVarcharType(2))));

                    // Translating back yields a CAST of the string literal to JONI_REGEXP, so the
                    // round trip is asymmetric and each direction is asserted separately.
                    FunctionCall castPatternCall = FunctionCallBuilder.resolve(TEST_SESSION, TestingPlannerContext.PLANNER_CONTEXT.getMetadata())
                            .setName(QualifiedName.of("regexp_like"))
                            .addArgument(VARCHAR_TYPE, new SymbolReference("varchar_symbol_1"))
                            .addArgument(JoniRegexpType.JONI_REGEXP, new Cast(new StringLiteral("a+"), TypeSignatureTranslator.toSqlType(JoniRegexpType.JONI_REGEXP)))
                            .build();

                    assertTranslationToConnectorExpression(transactionSession, encodedPatternCall, connectorCall);
                    assertTranslationFromConnectorExpression(transactionSession, connectorCall, castPatternCall);
                });
    }

    /**
     * Builds a varchar {@link Constant} for {@code value}, with the varchar length set to the
     * number of code points (this equals {@code length()} for the ASCII values used here).
     */
    private static Constant varcharConstant(String value) {
        return new Constant(
                Slices.wrappedBuffer(value.getBytes(StandardCharsets.UTF_8)),
                VarcharType.createVarcharType(value.codePointCount(0, value.length())));
    }

    private void assertTranslationRoundTrips(Expression expression, ConnectorExpression connectorExpression) {
        assertTranslationRoundTrips(TEST_SESSION, expression, connectorExpression);
    }

    /**
     * Asserts that {@code expression} translates to {@code connectorExpression} and that
     * {@code connectorExpression} translates back to {@code expression}.
     */
    private void assertTranslationRoundTrips(Session session, Expression expression, ConnectorExpression connectorExpression) {
        assertTranslationToConnectorExpression(session, expression, Optional.of(connectorExpression));
        assertTranslationFromConnectorExpression(session, connectorExpression, expression);
    }

    private void assertTranslationToConnectorExpression(Session session, Expression expression, ConnectorExpression expected) {
        assertTranslationToConnectorExpression(session, expression, Optional.of(expected));
    }

    /**
     * Asserts the engine -> connector translation of {@code expression}.
     *
     * @param expected the expected connector expression, or empty if no translation is expected
     */
    private void assertTranslationToConnectorExpression(Session session, Expression expression, Optional<ConnectorExpression> expected) {
        Optional<ConnectorExpression> translation = ConnectorExpressionTranslator.translate(session, expression, TYPE_ANALYZER, TYPE_PROVIDER, TestingPlannerContext.PLANNER_CONTEXT);
        // Optional equality covers both presence and value; TestNG order is (actual, expected).
        Assert.assertEquals(translation, expected);
    }

    /** Asserts the connector -> engine translation of {@code connectorExpression}. */
    private void assertTranslationFromConnectorExpression(Session session, ConnectorExpression connectorExpression, Expression expected) {
        Expression translation = ConnectorExpressionTranslator.translate(session, connectorExpression, TestingPlannerContext.PLANNER_CONTEXT, VARIABLE_MAPPINGS, LITERAL_ENCODER);
        Assert.assertEquals(translation, expected);
    }
}
