package io.trino.plugin.jdbc.expression;

import com.google.common.base.Verify;
import io.trino.matching.Capture;
import io.trino.matching.Captures;
import io.trino.matching.Pattern;
import io.trino.plugin.jdbc.JdbcColumnHandle;
import io.trino.plugin.jdbc.JdbcExpression;
import io.trino.plugin.jdbc.expression.AggregateFunctionRule;
import io.trino.spi.connector.AggregateFunction;
import io.trino.spi.expression.Variable;
import io.trino.spi.type.DoubleType;
import io.trino.spi.type.RealType;
import java.util.List;
import java.util.Optional;

/**
 * Aggregate function rule that rewrites a {@code corr} aggregation over two
 * variables into a {@code corr(<column>, <column>)} JDBC expression.
 *
 * <p>The rule only matches when both inputs are variables and their types are
 * either ({@code REAL}, {@code REAL}) or ({@code DOUBLE}, {@code DOUBLE}).
 */
public class ImplementCorr implements AggregateFunctionRule {
    /** Capture holding the two input variables matched by {@link #getPattern()}. */
    private static final Capture<List<Variable>> INPUTS = Capture.newCapture();

    /**
     * Returns the pattern selecting the aggregations this rule handles: a basic
     * aggregation (as defined by {@code AggregateFunctionPatterns.basicAggregation()})
     * named {@code corr} whose inputs are variables typed (REAL, REAL) or
     * (DOUBLE, DOUBLE). The matched inputs are captured as {@link #INPUTS}.
     *
     * @return the pattern for {@code corr} aggregations over two variables
     */
    @Override
    public Pattern<AggregateFunction> getPattern() {
        return AggregateFunctionPatterns.basicAggregation()
                .with(AggregateFunctionPatterns.functionName().equalTo("corr"))
                .with(AggregateFunctionPatterns.inputs().matching(
                        AggregateFunctionPatterns.variables()
                                .matching(AggregateFunctionPatterns.expressionTypes(RealType.REAL, RealType.REAL)
                                        .or(AggregateFunctionPatterns.expressionTypes(DoubleType.DOUBLE, DoubleType.DOUBLE)))
                                .capturedAs(INPUTS)));
    }

    /**
     * Builds the {@code corr(a, b)} expression for a matched aggregation.
     *
     * <p>Both captured variables are resolved to their column handles through
     * the rewrite context, and the column names are passed through the
     * context's identifier quote function before being placed in the expression.
     * The resulting expression carries the JDBC type handle of the first column.
     *
     * @param aggregateFunction the matched aggregation; its output type must
     *        equal the column type of the first argument
     * @param captures captures produced by matching {@link #getPattern()}
     * @param rewriteContext supplies column assignments and identifier quoting
     * @return the rewritten expression; this rule never returns an empty result
     * @throws com.google.common.base.VerifyException if the capture does not
     *         hold exactly two variables, or the aggregation output type differs
     *         from the first column's type
     */
    @Override
    public Optional<JdbcExpression> rewrite(AggregateFunction aggregateFunction, Captures captures, AggregateFunctionRule.RewriteContext rewriteContext) {
        List<Variable> arguments = captures.get(INPUTS);
        Verify.verify(arguments.size() == 2, "corr expects exactly 2 arguments, but got %s", arguments.size());

        JdbcColumnHandle firstColumn = (JdbcColumnHandle) rewriteContext.getAssignment(arguments.get(0).getName());
        JdbcColumnHandle secondColumn = (JdbcColumnHandle) rewriteContext.getAssignment(arguments.get(1).getName());

        // The expression reuses the first column's JDBC type handle as its result
        // type, so the aggregation's output type has to agree with that column.
        Verify.verify(
                aggregateFunction.getOutputType().equals(firstColumn.getColumnType()),
                "corr output type %s does not match type %s of the first argument column",
                aggregateFunction.getOutputType(),
                firstColumn.getColumnType());

        String expression = String.format(
                "corr(%s, %s)",
                rewriteContext.getIdentifierQuote().apply(firstColumn.getColumnName()),
                rewriteContext.getIdentifierQuote().apply(secondColumn.getColumnName()));
        return Optional.of(new JdbcExpression(expression, firstColumn.getJdbcTypeHandle()));
    }
}
