package io.trino.plugin.jdbc.expression;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.MoreCollectors;
import io.trino.matching.Match;
import io.trino.plugin.base.expression.ConnectorExpressionRule;
import io.trino.spi.connector.ColumnHandle;
import io.trino.spi.connector.ConnectorSession;
import io.trino.spi.expression.Call;
import io.trino.spi.expression.ConnectorExpression;
import io.trino.spi.expression.FunctionName;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.BigintType;
import io.trino.spi.type.DecimalType;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.IntegerType;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import org.assertj.core.api.Assertions;
import org.junit.jupiter.api.Test;

/* loaded from: input_file:io/trino/plugin/jdbc/expression/TestGenericRewrite.class */
/**
 * Tests for {@link GenericRewrite}: a call expression is matched against a textual
 * pattern and, on a match, rendered through a textual rewrite template.
 */
public class TestGenericRewrite {
    /**
     * A call {@code add(decimal(10, 2), bigint): decimal(21, 2)} matches the decimal/bigint
     * pattern; the rewritten expression carries the captured result precision and scale and
     * has no parameters.
     */
    @Test
    public void testRewriteCall() {
        GenericRewrite rewrite = new GenericRewrite(
                Map.of(),
                session -> true,
                "add(foo: decimal(p, s), bar: bigint): decimal(rp, rs)",
                "foo + bar::decimal(rp,rs)");
        ConnectorExpression expression = new Call(
                DecimalType.createDecimalType(21, 2),
                new FunctionName("add"),
                List.of(
                        new Variable("first", DecimalType.createDecimalType(10, 2)),
                        new Variable("second", BigintType.BIGINT)));

        ParameterizedExpression rewritten = apply(rewrite, expression).orElseThrow();

        Assertions.assertThat(rewritten.expression()).isEqualTo("(\"first\") + (\"second\")::decimal(21,2)");
        Assertions.assertThat(rewritten.parameters()).isEqualTo(List.of());
    }

    /**
     * With the type class {@code integer_class = {integer, bigint}}, the rule applies when both
     * the first argument and the result type belong to the class, and yields no result when
     * either of them is {@code double}.
     */
    @Test
    public void testRewriteCallWithTypeClass() {
        GenericRewrite rewrite = new GenericRewrite(
                Map.of("integer_class", Set.of("integer", "bigint")),
                session -> true,
                "add(foo: integer_class, bar: bigint): integer_class",
                "foo + bar");

        // Argument type and result type are both members of integer_class.
        ConnectorExpression matching = new Call(
                BigintType.BIGINT,
                new FunctionName("add"),
                List.of(new Variable("first", IntegerType.INTEGER), new Variable("second", BigintType.BIGINT)));
        Assertions.assertThat(apply(rewrite, matching).orElseThrow().expression()).isEqualTo("(\"first\") + (\"second\")");

        // First argument is double, which is not a member of integer_class.
        ConnectorExpression argumentOutsideClass = new Call(
                BigintType.BIGINT,
                new FunctionName("add"),
                List.of(new Variable("first", DoubleType.DOUBLE), new Variable("second", BigintType.BIGINT)));
        Assertions.assertThat(apply(rewrite, argumentOutsideClass)).isEmpty();

        // Result type is double, which is not a member of integer_class.
        ConnectorExpression resultOutsideClass = new Call(
                DoubleType.DOUBLE,
                new FunctionName("add"),
                List.of(new Variable("first", IntegerType.INTEGER), new Variable("second", BigintType.BIGINT)));
        Assertions.assertThat(apply(rewrite, resultOutsideClass)).isEmpty();
    }

    /**
     * Matches {@code connectorExpression} against the rule's pattern and, when it matches,
     * runs the rewrite with the match captures and a minimal stub context.
     *
     * <p>The stub context throws {@link UnsupportedOperationException} from
     * {@code getAssignments()} and {@code getSession()}. Its {@code defaultRewrite} renders a
     * {@link Variable} as a double-quoted identifier (embedded double quotes are doubled) with
     * no parameters, and returns empty for any other expression.
     *
     * @param genericRewrite the rule under test
     * @param connectorExpression the expression to match and rewrite
     * @return the rewrite result, or empty when the pattern does not match
     */
    private static Optional<ParameterizedExpression> apply(GenericRewrite genericRewrite, ConnectorExpression connectorExpression) {
        // Typed Optional instead of a raw Optional with unchecked casts.
        Optional<Match> match = genericRewrite.getPattern().match(connectorExpression).collect(MoreCollectors.toOptional());
        if (match.isEmpty()) {
            return Optional.empty();
        }
        return genericRewrite.rewrite(connectorExpression, match.get().captures(), new ConnectorExpressionRule.RewriteContext<ParameterizedExpression>() {
            @Override
            public Map<String, ColumnHandle> getAssignments() {
                throw new UnsupportedOperationException();
            }

            @Override
            public ConnectorSession getSession() {
                throw new UnsupportedOperationException();
            }

            @Override
            public Optional<ParameterizedExpression> defaultRewrite(ConnectorExpression candidate) {
                if (!(candidate instanceof Variable)) {
                    return Optional.empty();
                }
                String name = ((Variable) candidate).getName();
                // Quote as an identifier, escaping embedded double quotes by doubling them.
                return Optional.of(new ParameterizedExpression("\"" + name.replace("\"", "\"\"") + "\"", ImmutableList.of()));
            }
        });
    }
}
