package io.trino.verifier;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.util.concurrent.SimpleTimeLimiter;
import com.google.common.util.concurrent.UncheckedTimeoutException;
import io.airlift.units.Duration;
import io.trino.sql.SqlFormatter;
import io.trino.sql.parser.SqlParser;
import io.trino.sql.tree.CreateTable;
import io.trino.sql.tree.CreateTableAsSelect;
import io.trino.sql.tree.DropTable;
import io.trino.sql.tree.Expression;
import io.trino.sql.tree.FunctionCall;
import io.trino.sql.tree.Identifier;
import io.trino.sql.tree.Insert;
import io.trino.sql.tree.LikeClause;
import io.trino.sql.tree.Limit;
import io.trino.sql.tree.LongLiteral;
import io.trino.sql.tree.NodeLocation;
import io.trino.sql.tree.QualifiedName;
import io.trino.sql.tree.QueryBody;
import io.trino.sql.tree.QuerySpecification;
import io.trino.sql.tree.SaveMode;
import io.trino.sql.tree.Select;
import io.trino.sql.tree.SelectItem;
import io.trino.sql.tree.SingleColumn;
import io.trino.sql.tree.Statement;
import io.trino.sql.tree.Table;

import java.sql.Connection;
import java.sql.DriverManager;
import java.sql.ResultSet;
import java.sql.ResultSetMetaData;
import java.sql.SQLClientInfoException;
import java.sql.SQLException;
import java.sql.Types;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

/* loaded from: input_file:io/trino/verifier/QueryRewriter.class */
public class QueryRewriter {
    private static final Set<Integer> APPROXIMATE_TYPES = ImmutableSet.of(7, 6, 8);
    private final SqlParser parser;
    private final String gatewayUrl;
    private final QualifiedName rewritePrefix;
    private final Optional<String> catalogOverride;
    private final Optional<String> schemaOverride;
    private final Optional<String> usernameOverride;
    private final Optional<String> passwordOverride;
    private final int doublePrecision;
    private final Duration timeout;

    /* JADX INFO: Access modifiers changed from: private */
    /* loaded from: input_file:io/trino/verifier/QueryRewriter$Column.class */
    public static class Column {
        private final String name;
        private final boolean approximateType;

        private Column(String str, boolean z) {
            this.name = str;
            this.approximateType = z;
        }

        public String getName() {
            return this.name;
        }

        public boolean isApproximateType() {
            return this.approximateType;
        }
    }

    /* loaded from: input_file:io/trino/verifier/QueryRewriter$QueryRewriteException.class */
    public static class QueryRewriteException extends Exception {
        public QueryRewriteException(String str) {
            super(str);
        }
    }

    public QueryRewriter(SqlParser sqlParser, String str, QualifiedName qualifiedName, Optional<String> optional, Optional<String> optional2, Optional<String> optional3, Optional<String> optional4, int i, Duration duration) {
        this.parser = (SqlParser) Objects.requireNonNull(sqlParser, "parser is null");
        this.gatewayUrl = (String) Objects.requireNonNull(str, "gatewayUrl is null");
        this.rewritePrefix = (QualifiedName) Objects.requireNonNull(qualifiedName, "rewritePrefix is null");
        this.catalogOverride = (Optional) Objects.requireNonNull(optional, "catalogOverride is null");
        this.schemaOverride = (Optional) Objects.requireNonNull(optional2, "schemaOverride is null");
        this.usernameOverride = (Optional) Objects.requireNonNull(optional3, "usernameOverride is null");
        this.passwordOverride = (Optional) Objects.requireNonNull(optional4, "passwordOverride is null");
        this.doublePrecision = i;
        this.timeout = (Duration) Objects.requireNonNull(duration, "timeout is null");
    }

    public Query shadowQuery(Query query) throws QueryRewriteException, SQLException {
        if (VerifyCommand.statementToQueryType(this.parser, query.getQuery()) == QueryType.READ) {
            return query;
        }
        if (!query.getPreQueries().isEmpty()) {
            throw new QueryRewriteException("Cannot rewrite queries that use pre-queries");
        }
        if (!query.getPostQueries().isEmpty()) {
            throw new QueryRewriteException("Cannot rewrite queries that use post-queries");
        }
        Statement createStatement = this.parser.createStatement(query.getQuery());
        Connection connection = DriverManager.getConnection(this.gatewayUrl, this.usernameOverride.orElse(query.getUsername()), this.passwordOverride.orElse(query.getPassword()));
        try {
            trySetConnectionProperties(query, connection);
            if (createStatement instanceof CreateTableAsSelect) {
                Query rewriteCreateTableAsSelect = rewriteCreateTableAsSelect(connection, query, (CreateTableAsSelect) createStatement);
                if (connection != null) {
                    connection.close();
                }
                return rewriteCreateTableAsSelect;
            }
            if (!(createStatement instanceof Insert)) {
                if (connection != null) {
                    connection.close();
                }
                throw new QueryRewriteException("Unsupported query type: " + String.valueOf(createStatement.getClass()));
            }
            Query rewriteInsertQuery = rewriteInsertQuery(connection, query, (Insert) createStatement);
            if (connection != null) {
                connection.close();
            }
            return rewriteInsertQuery;
        } catch (Throwable th) {
            if (connection != null) {
                try {
                    connection.close();
                } catch (Throwable th2) {
                    th.addSuppressed(th2);
                }
            }
            throw th;
        }
    }

    private Query rewriteCreateTableAsSelect(Connection connection, Query query, CreateTableAsSelect createTableAsSelect) throws SQLException, QueryRewriteException {
        QualifiedName generateTemporaryTableName = generateTemporaryTableName(createTableAsSelect.getName());
        String formatSql = SqlFormatter.formatSql(new CreateTableAsSelect(generateTemporaryTableName, createTableAsSelect.getQuery(), createTableAsSelect.getSaveMode(), createTableAsSelect.getProperties(), createTableAsSelect.isWithData(), createTableAsSelect.getColumnAliases(), Optional.empty()));
        return new Query(query.getCatalog(), query.getSchema(), ImmutableList.of(formatSql), checksumSql(getColumns(connection, createTableAsSelect), generateTemporaryTableName), ImmutableList.of(dropTableSql(generateTemporaryTableName)), query.getUsername(), query.getPassword(), query.getSessionProperties());
    }

    private Query rewriteInsertQuery(Connection connection, Query query, Insert insert) throws SQLException, QueryRewriteException {
        QualifiedName generateTemporaryTableName = generateTemporaryTableName(insert.getTarget());
        String formatSql = SqlFormatter.formatSql(new CreateTable(generateTemporaryTableName, ImmutableList.of(new LikeClause(insert.getTarget(), Optional.of(LikeClause.PropertiesOption.INCLUDING))), SaveMode.IGNORE, ImmutableList.of(), Optional.empty()));
        String formatSql2 = SqlFormatter.formatSql(new Insert(new Table(generateTemporaryTableName), insert.getColumns(), insert.getQuery()));
        return new Query(query.getCatalog(), query.getSchema(), ImmutableList.of(formatSql, formatSql2), checksumSql(getColumnsForTable(connection, query.getCatalog(), query.getSchema(), insert.getTarget().toString()), generateTemporaryTableName), ImmutableList.of(dropTableSql(generateTemporaryTableName)), query.getUsername(), query.getPassword(), query.getSessionProperties());
    }

    private QualifiedName generateTemporaryTableName(QualifiedName qualifiedName) {
        ArrayList arrayList = new ArrayList();
        int size = qualifiedName.getOriginalParts().size();
        int size2 = this.rewritePrefix.getOriginalParts().size();
        if (size > size2) {
            arrayList.addAll(qualifiedName.getOriginalParts().subList(0, size - size2));
        }
        arrayList.addAll(this.rewritePrefix.getOriginalParts());
        arrayList.set(arrayList.size() - 1, new Identifier(createTemporaryTableName()));
        return QualifiedName.of(arrayList);
    }

    private void trySetConnectionProperties(Query query, Connection connection) throws SQLException {
        try {
            connection.setClientInfo("ApplicationName", "verifier-rewrite");
            connection.setCatalog(this.catalogOverride.orElse(query.getCatalog()));
            connection.setSchema(this.schemaOverride.orElse(query.getSchema()));
        } catch (SQLClientInfoException e) {
        }
    }

    private String createTemporaryTableName() {
        return this.rewritePrefix.getSuffix() + UUID.randomUUID().toString().replace("-", "");
    }

    private List<Column> getColumnsForTable(Connection connection, String str, String str2, String str3) throws SQLException {
        ResultSet columns = connection.getMetaData().getColumns(str, escapeLikeExpression(connection, str2), escapeLikeExpression(connection, str3), null);
        ImmutableList.Builder builder = ImmutableList.builder();
        while (columns.next()) {
            builder.add(new Column(columns.getString("COLUMN_NAME"), APPROXIMATE_TYPES.contains(Integer.valueOf(columns.getInt("DATA_TYPE")))));
        }
        return builder.build();
    }

    /* JADX WARN: Finally extract failed */
    private List<Column> getColumns(Connection connection, CreateTableAsSelect createTableAsSelect) throws SQLException {
        io.trino.sql.tree.Query query;
        io.trino.sql.tree.Query query2 = createTableAsSelect.getQuery();
        QuerySpecification queryBody = query2.getQueryBody();
        if (queryBody instanceof QuerySpecification) {
            QuerySpecification querySpecification = queryBody;
            query = new io.trino.sql.tree.Query(ImmutableList.of(), query2.getWith(), new QuerySpecification(querySpecification.getSelect(), querySpecification.getFrom(), querySpecification.getWhere(), querySpecification.getGroupBy(), querySpecification.getHaving(), querySpecification.getWindows(), querySpecification.getOrderBy(), querySpecification.getOffset(), Optional.of(new Limit(new LongLiteral("0")))), Optional.empty(), Optional.empty(), Optional.empty());
        } else {
            query = new io.trino.sql.tree.Query(ImmutableList.of(), query2.getWith(), queryBody, Optional.empty(), Optional.empty(), Optional.of(new Limit(new LongLiteral("0"))));
        }
        ImmutableList.Builder builder = ImmutableList.builder();
        java.sql.Statement createStatement = connection.createStatement();
        try {
            ExecutorService newSingleThreadExecutor = Executors.newSingleThreadExecutor();
            try {
                try {
                    ResultSet executeQuery = ((java.sql.Statement) SimpleTimeLimiter.create(newSingleThreadExecutor).newProxy(createStatement, java.sql.Statement.class, this.timeout.toMillis(), TimeUnit.MILLISECONDS)).executeQuery(SqlFormatter.formatSql(query));
                    try {
                        ResultSetMetaData metaData = executeQuery.getMetaData();
                        for (int i = 1; i <= metaData.getColumnCount(); i++) {
                            builder.add(new Column(metaData.getColumnName(i), APPROXIMATE_TYPES.contains(Integer.valueOf(metaData.getColumnType(i)))));
                        }
                        if (executeQuery != null) {
                            executeQuery.close();
                        }
                        newSingleThreadExecutor.shutdownNow();
                        if (createStatement != null) {
                            createStatement.close();
                        }
                        return builder.build();
                    } catch (Throwable th) {
                        if (executeQuery != null) {
                            try {
                                executeQuery.close();
                            } catch (Throwable th2) {
                                th.addSuppressed(th2);
                            }
                        }
                        throw th;
                    }
                } catch (Throwable th3) {
                    newSingleThreadExecutor.shutdownNow();
                    throw th3;
                }
            } catch (UncheckedTimeoutException e) {
                throw new SQLException("SQL statement execution timed out", (Throwable) e);
            }
        } catch (Throwable th4) {
            if (createStatement != null) {
                try {
                    createStatement.close();
                } catch (Throwable th5) {
                    th4.addSuppressed(th5);
                }
            }
            throw th4;
        }
    }

    private String checksumSql(List<Column> list, QualifiedName qualifiedName) throws QueryRewriteException {
        if (list.isEmpty()) {
            throw new QueryRewriteException("Table " + String.valueOf(qualifiedName) + " has no columns");
        }
        ImmutableList.Builder builder = ImmutableList.builder();
        for (Column column : list) {
            FunctionCall identifier = new Identifier(column.getName());
            if (column.isApproximateType()) {
                identifier = new FunctionCall(QualifiedName.of("round"), ImmutableList.of(identifier, new LongLiteral(Integer.toString(this.doublePrecision))));
            }
            builder.add(new SingleColumn(new FunctionCall(QualifiedName.of("checksum"), ImmutableList.of(identifier))));
        }
        return SqlFormatter.formatSql(new QuerySpecification(new Select(false, builder.build()), Optional.of(new Table(qualifiedName)), Optional.empty(), Optional.empty(), Optional.empty(), ImmutableList.of(), Optional.empty(), Optional.empty(), Optional.empty()));
    }

    private static String dropTableSql(QualifiedName qualifiedName) {
        return SqlFormatter.formatSql(new DropTable(new NodeLocation(1, 1), qualifiedName, true));
    }

    private static String escapeLikeExpression(Connection connection, String str) throws SQLException {
        String searchStringEscape = connection.getMetaData().getSearchStringEscape();
        return str.replace(searchStringEscape, searchStringEscape + searchStringEscape).replace("_", searchStringEscape + "_").replace("%", searchStringEscape + "%");
    }
}
