package io.trino.plugin.mysql;

import io.trino.plugin.base.mapping.DefaultIdentifierMapping;
import io.trino.plugin.jdbc.BaseJdbcConfig;
import io.trino.plugin.jdbc.ColumnMapping;
import io.trino.plugin.jdbc.DefaultQueryBuilder;
import io.trino.plugin.jdbc.JdbcClient;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.JdbcStatisticsConfig;
import io.trino.plugin.jdbc.JdbcTypeHandle;
import io.trino.plugin.jdbc.logging.RemoteQueryModifier;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.BooleanType;
import io.trino.spi.type.DoubleType;
import io.trino.testing.TestingConnectorSession;
import io.trino.type.InternalTypeManager;
import java.sql.Connection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import org.assertj.core.api.Assertions;
import org.testng.Assert;
import org.testng.annotations.Test;

/**
 * Tests aggregation pushdown in {@link MySqlClient}.
 *
 * <p>Each case builds an {@link AggregateFunction} and asks the client to rewrite it into MySQL SQL.
 * It then checks either that the rewrite is refused, or that the generated SQL text matches and the
 * result's JDBC type handle maps back to the aggregate's declared output type.
 */
public class TestMySqlClient {
    private static final JdbcColumnHandle BIGINT_COLUMN = JdbcColumnHandle.builder().setColumnName("c_bigint").setColumnType(BigintType.BIGINT).setJdbcTypeHandle(new JdbcTypeHandle(-5, Optional.of("int8"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())).build();
    private static final JdbcColumnHandle DOUBLE_COLUMN = JdbcColumnHandle.builder().setColumnName("c_double").setColumnType(DoubleType.DOUBLE).setJdbcTypeHandle(new JdbcTypeHandle(8, Optional.of("double"), Optional.empty(), Optional.empty(), Optional.empty(), Optional.empty())).build();

    // The lambda argument (presumably the connection factory) throws on use, so these tests must
    // never need a live MySQL connection.
    private static final JdbcClient JDBC_CLIENT = new MySqlClient(new BaseJdbcConfig(), new JdbcStatisticsConfig(), connectorSession -> {
        throw new UnsupportedOperationException();
    }, new DefaultQueryBuilder(RemoteQueryModifier.NONE), InternalTypeManager.TESTING_TYPE_MANAGER, new DefaultIdentifierMapping(), RemoteQueryModifier.NONE);

    @Test
    public void testImplementCount() {
        Variable bigintVariable = new Variable("v_bigint", BigintType.BIGINT);
        Variable doubleVariable = new Variable("v_double", DoubleType.DOUBLE);
        Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BooleanType.BOOLEAN));

        // count(*)
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(), List.of(), false, Optional.empty()),
                Map.of(),
                Optional.of("count(*)"));

        // count(bigint)
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.of("count(`c_bigint`)"));

        // count(double)
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(doubleVariable), List.of(), false, Optional.empty()),
                Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
                Optional.of("count(`c_double`)"));

        // count(DISTINCT bigint) is expected not to be pushed down
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty());

        // count(*) FILTER (...) is expected not to be pushed down
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(), List.of(), false, filter),
                Map.of(),
                Optional.empty());

        // count(bigint) FILTER (...) is expected not to be pushed down
        testImplementAggregation(
                new AggregateFunction("count", BigintType.BIGINT, List.of(bigintVariable), List.of(), false, filter),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty());
    }

    @Test
    public void testImplementSum() {
        Variable bigintVariable = new Variable("v_bigint", BigintType.BIGINT);
        Variable doubleVariable = new Variable("v_double", DoubleType.DOUBLE);
        Optional<ConnectorExpression> filter = Optional.of(new Variable("a_filter", BooleanType.BOOLEAN));

        // sum(bigint)
        testImplementAggregation(
                new AggregateFunction("sum", BigintType.BIGINT, List.of(bigintVariable), List.of(), false, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.of("sum(`c_bigint`)"));

        // sum(double)
        testImplementAggregation(
                new AggregateFunction("sum", DoubleType.DOUBLE, List.of(doubleVariable), List.of(), false, Optional.empty()),
                Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
                Optional.of("sum(`c_double`)"));

        // sum(DISTINCT bigint)
        testImplementAggregation(
                new AggregateFunction("sum", BigintType.BIGINT, List.of(bigintVariable), List.of(), true, Optional.empty()),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.of("sum(DISTINCT `c_bigint`)"));

        // sum(DISTINCT double): uses the DOUBLE variable so that argument type and column type agree
        testImplementAggregation(
                new AggregateFunction("sum", DoubleType.DOUBLE, List.of(doubleVariable), List.of(), true, Optional.empty()),
                Map.of(doubleVariable.getName(), DOUBLE_COLUMN),
                Optional.of("sum(DISTINCT `c_double`)"));

        // sum(bigint) FILTER (...) is expected not to be pushed down
        testImplementAggregation(
                new AggregateFunction("sum", BigintType.BIGINT, List.of(bigintVariable), List.of(), false, filter),
                Map.of(bigintVariable.getName(), BIGINT_COLUMN),
                Optional.empty());
    }

    /**
     * Rewrites {@code aggregateFunction} with the shared MySQL client and verifies the outcome.
     *
     * @param aggregateFunction the aggregate to push down
     * @param assignments column handles keyed by each variable name referenced by the aggregate
     * @param expectedExpression the expected MySQL SQL text, or empty if the rewrite must be refused
     */
    private static void testImplementAggregation(AggregateFunction aggregateFunction, Map<String, ColumnHandle> assignments, Optional<String> expectedExpression) {
        Optional<JdbcExpression> result = JDBC_CLIENT.implementAggregation(TestingConnectorSession.SESSION, aggregateFunction, assignments);
        if (expectedExpression.isEmpty()) {
            Assertions.assertThat(result).isEmpty();
            return;
        }
        Assertions.assertThat(result).isPresent();
        JdbcExpression jdbcExpression = result.get();
        Assert.assertEquals(jdbcExpression.getExpression(), expectedExpression.get());

        // The pushed-down result's type handle must map back to the aggregate's declared output type.
        // A null connection is passed; this assumes the mapping for these types needs no live
        // connection (confirm if new types are added).
        Optional<ColumnMapping> columnMapping = JDBC_CLIENT.toColumnMapping(TestingConnectorSession.SESSION, (Connection) null, jdbcExpression.getJdbcTypeHandle());
        Assert.assertTrue(columnMapping.isPresent(), "No mapping for: " + jdbcExpression.getJdbcTypeHandle());
        Assert.assertEquals(columnMapping.get().getType(), aggregateFunction.getOutputType());
    }
}
