package io.castled.warehouses.connectors.snowflake;

import com.amazonaws.services.s3.model.S3ObjectSummary;
import com.google.common.collect.Lists;
import com.google.inject.Inject;
import com.google.inject.Singleton;
import io.castled.ObjectRegistry;
import io.castled.commons.models.FileFormat;
import io.castled.commons.models.FileStorageNamespace;
import io.castled.commons.streams.RecordInputStream;
import io.castled.commons.streams.S3FilesRecordInputStream;
import io.castled.constants.ConnectorExecutionConstants;
import io.castled.exceptions.CastledRuntimeException;
import io.castled.filestorage.CastledS3Client;
import io.castled.schema.models.RecordSchema;
import io.castled.warehouses.S3BasedDataPoller;
import io.castled.warehouses.WarehouseConfig;
import io.castled.warehouses.connectors.redshift.models.S3PolledFile;
import io.castled.warehouses.models.WarehousePollContext;
import io.castled.warehouses.models.WarehousePollResult;
import java.io.IOException;
import java.sql.Connection;
import java.sql.PreparedStatement;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Optional;
import java.util.stream.Collectors;
import org.apache.commons.collections4.CollectionUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Snowflake implementation of {@link S3BasedDataPoller}.
 *
 * <p>A fresh poll (re)creates an "uncommitted snapshot" table from the pipeline query. It then
 * unloads the snapshot rows that are not in the committed snapshot to S3 as GZIP-compressed CSV
 * via {@code COPY INTO}; when no committed snapshot exists, it unloads all rows. Finally it
 * returns a record stream that reads those files back. A resumed poll reuses files already
 * unloaded for the same pipeline run, when there are any.
 */
@Singleton
public class SnowflakeDataPoller extends S3BasedDataPoller {

    private static final Logger log = LoggerFactory.getLogger(SnowflakeDataPoller.class);

    /** Schema whose tables are listed to detect existing snapshot tables. */
    private static final String INTERNAL_SCHEMA = "castled";

    /**
     * {@code COPY INTO} statement that unloads a query result to S3. The placeholders, in order,
     * are: target S3 path, source query, AWS key id, AWS secret key, client-side encryption
     * master key.
     */
    private static final String COPY_INTO_TEMPLATE = "COPY INTO '%s/' FROM (%s) FILE_FORMAT = (TYPE = 'CSV' COMPRESSION = 'GZIP' FIELD_OPTIONALLY_ENCLOSED_BY = '\"' NULL_IF = ('NULL', 'null') EMPTY_FIELD_AS_NULL=FALSE DATE_FORMAT = 'YYYY-MM-DD' TIMESTAMP_FORMAT = 'YYYY-MM-DD HH24:MI:SS.FF3 TZHTZM' TIME_FORMAT = 'HH24:MI:SS')CREDENTIALS = (AWS_KEY_ID = '%s' AWS_SECRET_KEY = '%s') ENCRYPTION = (TYPE = 'AWS_CSE'  MASTER_KEY = '%s' ) OVERWRITE=TRUE HEADER=TRUE";

    private final SnowflakeClient snowflakeClient;
    private final SnowflakeResultSetSchemaMapper resultSetSchemaMapper;
    private final SnowflakeCsvSchemaMapper snowflakeCsvSchemaMapper;
    private final SnowflakeConnector snowflakeConnector;

    @Inject
    public SnowflakeDataPoller(SnowflakeClient snowflakeClient, SnowflakeConnector snowflakeConnector, SnowflakeResultSetSchemaMapper snowflakeResultSetSchemaMapper, SnowflakeCsvSchemaMapper snowflakeCsvSchemaMapper) {
        this.snowflakeClient = snowflakeClient;
        this.resultSetSchemaMapper = snowflakeResultSetSchemaMapper;
        this.snowflakeCsvSchemaMapper = snowflakeCsvSchemaMapper;
        this.snowflakeConnector = snowflakeConnector;
    }

    /**
     * Runs a fresh poll for the pipeline described by {@code warehousePollContext}.
     *
     * <p>Steps:
     * <ol>
     *   <li>Rebuild the uncommitted snapshot.</li>
     *   <li>Restore the committed snapshot from its backup, if a backup exists.</li>
     *   <li>Read the snapshot's schema.</li>
     *   <li>Unload the delta to S3.</li>
     * </ol>
     *
     * @return a poll result holding the snapshot schema and a stream over the unloaded files
     * @throws CastledRuntimeException wrapping any failure
     */
    @Override // io.castled.warehouses.WarehouseDataPoller
    public WarehousePollResult pollRecords(WarehousePollContext warehousePollContext) {
        String pipelineUUID = warehousePollContext.getPipelineUUID();
        // try-with-resources: the connection must be closed on failure too, not only on success.
        try (Connection connection = this.snowflakeConnector.getConnection((SnowflakeWarehouseConfig) warehousePollContext.getWarehouseConfig())) {
            // Mutable copy, so that it can reflect tables this poll renames.
            List<String> existingTables = new ArrayList<>(this.snowflakeClient.listTables(connection, tableKey(INTERNAL_SCHEMA)));
            createUncommittedSnapshot(connection, warehousePollContext, existingTables);
            if (recoverSnapshotFromBackup(connection, existingTables, warehousePollContext)) {
                // The backup now lives under the committed-snapshot name. Record that here, so
                // the fetch query diffs against it instead of unloading every row again.
                existingTables.add(tableKey(ConnectorExecutionConstants.getCommittedSnapshot(pipelineUUID)));
            }
            RecordSchema snapshotSchema = getSchemaFromQuery(connection, selectAllFrom(ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID)));
            return WarehousePollResult.builder()
                    .recordInputStream(createRecordStream(connection, warehousePollContext, existingTables, snapshotSchema))
                    .warehouseSchema(snapshotSchema)
                    .build();
        } catch (Exception e) {
            log.error("Snowflake data poll failed for pipeline {}", pipelineUUID, e);
            throw new CastledRuntimeException(e);
        }
    }

    /**
     * Resumes a poll by reusing files already unloaded for this pipeline run.
     *
     * <p>Falls back to {@link #pollRecords(WarehousePollContext)} in two cases: no unloaded files
     * exist, or the resume attempt itself fails. The fallback runs at most once.
     */
    @Override // io.castled.warehouses.WarehouseDataPoller
    public WarehousePollResult resumePoll(WarehousePollContext warehousePollContext) {
        Optional<WarehousePollResult> resumed;
        try {
            resumed = resumeFromUnloadedFiles(warehousePollContext);
        } catch (Exception e) {
            log.error("Snowflake data poll resume failed for pipeline {}", warehousePollContext.getPipelineUUID(), e);
            resumed = Optional.empty();
        }
        // Fall back outside the try block. Otherwise a failing fresh poll would be caught above
        // and executed a second time.
        return resumed.orElseGet(() -> pollRecords(warehousePollContext));
    }

    /**
     * Builds a resumed poll result from files already present in this run's unload directory.
     *
     * @return the resumed result, or empty when no unloaded files are found
     */
    private Optional<WarehousePollResult> resumeFromUnloadedFiles(WarehousePollContext warehousePollContext) throws SQLException, IOException {
        SnowflakeWarehouseConfig snowflakeWarehouseConfig = (SnowflakeWarehouseConfig) warehousePollContext.getWarehouseConfig();
        CastledS3Client s3Client = SnowflakeUtils.getS3Client(warehousePollContext.getWarehouseConfig(), warehousePollContext.getDataEncryptionKey());
        List<S3PolledFile> unloadedFiles = listUnloadedFiles(s3Client, warehousePollContext);
        if (CollectionUtils.isEmpty(unloadedFiles)) {
            return Optional.empty();
        }
        RecordSchema snapshotSchema = getQuerySchema(snowflakeWarehouseConfig, selectAllFrom(ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(warehousePollContext.getPipelineUUID())));
        return Optional.of(WarehousePollResult.builder()
                .warehouseSchema(snapshotSchema)
                .recordInputStream(buildRecordStream(snapshotSchema, unloadedFiles, s3Client, warehousePollContext))
                .resumed(true)
                .build());
    }

    /** Opens a short-lived connection, reads the result schema of {@code query}, and closes it. */
    private RecordSchema getQuerySchema(SnowflakeWarehouseConfig snowflakeWarehouseConfig, String query) throws SQLException {
        try (Connection connection = this.snowflakeConnector.getConnection(snowflakeWarehouseConfig)) {
            return getSchemaFromQuery(connection, query);
        }
    }

    /**
     * Drops the committed and uncommitted snapshot tables of a pipeline, if they exist.
     *
     * @param pipelineUUID pipeline whose snapshot tables should be dropped
     * @param warehouseConfig expected to be a {@link SnowflakeWarehouseConfig}
     * @throws CastledRuntimeException wrapping any {@link SQLException}
     */
    @Override // io.castled.warehouses.WarehouseDataPoller
    public void cleanupPipelineResources(String pipelineUUID, WarehouseConfig warehouseConfig) {
        // Both resources are closed even when a DROP fails.
        try (Connection connection = this.snowflakeConnector.getConnection((SnowflakeWarehouseConfig) warehouseConfig);
             Statement statement = connection.createStatement()) {
            statement.execute(String.format("drop table if exists %s", ConnectorExecutionConstants.getQualifiedCommittedSnapshot(pipelineUUID)));
            statement.execute(String.format("drop table if exists %s", ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID)));
        } catch (SQLException e) {
            log.error("Cleanup pipeline resources failed for pipeline {}", pipelineUUID, e);
            throw new CastledRuntimeException(e);
        }
    }

    /**
     * Renames the committed-snapshot backup table back to the committed-snapshot name, when a
     * backup appears in {@code existingTables}.
     *
     * <p>NOTE(review): the rename target is unqualified. Confirm that Snowflake keeps the table
     * in the internal schema rather than moving it into the session's current schema.
     *
     * @return {@code true} if the backup was renamed
     */
    private boolean recoverSnapshotFromBackup(Connection connection, List<String> existingTables, WarehousePollContext warehousePollContext) throws SQLException {
        String pipelineUUID = warehousePollContext.getPipelineUUID();
        if (!existingTables.contains(tableKey(ConnectorExecutionConstants.getCommittedSnapshotBackup(pipelineUUID)))) {
            return false;
        }
        try (Statement statement = connection.createStatement()) {
            statement.execute(String.format("alter table %s rename to %s",
                    ConnectorExecutionConstants.getQualifiedCommittedSnapshotBkp(pipelineUUID),
                    ConnectorExecutionConstants.getCommittedSnapshot(pipelineUUID)));
        }
        return true;
    }

    /**
     * Drops any previous uncommitted snapshot, then recreates it from the pipeline query.
     * The {@code false} flag passed to {@code createTableFromQuery} is forwarded unchanged.
     */
    private void createUncommittedSnapshot(Connection connection, WarehousePollContext warehousePollContext, List<String> existingTables) throws SQLException {
        String pipelineUUID = warehousePollContext.getPipelineUUID();
        if (existingTables.contains(tableKey(ConnectorExecutionConstants.getUncommittedSnapshot(pipelineUUID)))) {
            try (Statement statement = connection.createStatement()) {
                statement.execute(String.format("drop table if exists %s", ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID)));
            }
        }
        this.snowflakeClient.createTableFromQuery(connection, ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID), warehousePollContext.getQuery(), false);
    }

    /**
     * Creates the internal schema when it is missing.
     * NOTE(review): this method is not called anywhere in this class.
     */
    private void createInternalSchemaIfRequired(Connection connection) throws SQLException {
        try (Statement statement = connection.createStatement()) {
            statement.execute(String.format("create schema if not exists %s", INTERNAL_SCHEMA));
        }
    }

    /**
     * Unloads the data-fetch query to this run's S3 unload path with {@code COPY INTO}, then
     * returns a stream over the unloaded files.
     */
    private RecordInputStream createRecordStream(Connection connection, WarehousePollContext warehousePollContext, List<String> existingTables, RecordSchema recordSchema) throws SQLException, IOException {
        SnowflakeWarehouseConfig snowflakeWarehouseConfig = (SnowflakeWarehouseConfig) warehousePollContext.getWarehouseConfig();
        CastledS3Client s3Client = SnowflakeUtils.getS3Client(warehousePollContext.getWarehouseConfig(), warehousePollContext.getDataEncryptionKey());
        String unloadPath = CastledS3Client.constructS3Path(s3Client.getBucket(), Lists.newArrayList(
                FileStorageNamespace.PIPELINE_UNLOADS.getNamespace(),
                warehousePollContext.getPipelineUUID(),
                String.valueOf(warehousePollContext.getPipelineRunId())));
        // The statement embeds credentials; it must never be logged.
        String copyIntoSql = String.format(COPY_INTO_TEMPLATE, unloadPath, getDataFetchQuery(warehousePollContext, existingTables),
                snowflakeWarehouseConfig.getAccessKeyId(), snowflakeWarehouseConfig.getAccessKeySecret(), s3Client.getEncryptionKey());
        try (Statement statement = connection.createStatement()) {
            statement.execute(copyIntoSql);
        }
        return buildRecordStream(recordSchema, listUnloadedFiles(s3Client, warehousePollContext), s3Client, warehousePollContext);
    }

    /** Lists the objects in this run's S3 unload directory as {@link S3PolledFile}s. */
    private List<S3PolledFile> listUnloadedFiles(CastledS3Client s3Client, WarehousePollContext warehousePollContext) throws IOException {
        return s3Client.listObjects(getS3UnloadDirectory(warehousePollContext.getPipelineUUID(), warehousePollContext.getPipelineRunId()))
                .stream()
                .map(this::buildS3PolledFile)
                .collect(Collectors.toList());
    }

    /** Wraps unloaded CSV files in a record stream that decodes them with the CSV schema mapper. */
    private RecordInputStream buildRecordStream(RecordSchema recordSchema, List<S3PolledFile> files, CastledS3Client s3Client, WarehousePollContext warehousePollContext) throws IOException {
        // NOTE(review): the meaning of the trailing 20 / true arguments is defined by
        // S3FilesRecordInputStream's constructor. Confirm there before changing them.
        return new S3FilesRecordInputStream(recordSchema, this.snowflakeCsvSchemaMapper, files, s3Client, FileFormat.CSV,
                getPipelineRunUnloadDirectory(warehousePollContext.getPipelineUUID(), warehousePollContext.getPipelineRunId()), 20, true);
    }

    private S3PolledFile buildS3PolledFile(S3ObjectSummary s3ObjectSummary) {
        return new S3PolledFile(s3ObjectSummary.getBucketName(), s3ObjectSummary.getKey(), s3ObjectSummary.getSize());
    }

    /**
     * Returns the query whose result is unloaded. When a committed snapshot exists, this is the
     * uncommitted rows {@code EXCEPT} the committed rows; otherwise it is every uncommitted row.
     */
    private String getDataFetchQuery(WarehousePollContext warehousePollContext, List<String> existingTables) {
        String pipelineUUID = warehousePollContext.getPipelineUUID();
        String qualifiedUncommittedSnapshot = ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID);
        if (!existingTables.contains(tableKey(ConnectorExecutionConstants.getCommittedSnapshot(pipelineUUID)))) {
            return selectAllFrom(qualifiedUncommittedSnapshot);
        }
        String qualifiedCommittedSnapshot = ConnectorExecutionConstants.getQualifiedCommittedSnapshot(pipelineUUID);
        return String.format("select * from %s except select * from %s", qualifiedUncommittedSnapshot, qualifiedCommittedSnapshot);
    }

    /** Reads the result-set schema of {@code query} from a prepared (not executed) statement. */
    private RecordSchema getSchemaFromQuery(Connection connection, String query) throws SQLException {
        try (PreparedStatement statement = connection.prepareStatement(query)) {
            return this.resultSetSchemaMapper.getSchema(statement.getMetaData());
        }
    }

    /** Builds {@code select * from <table>}. */
    private static String selectAllFrom(String qualifiedTable) {
        return String.format("select * from %s", qualifiedTable);
    }

    /**
     * Normalises a table or schema name to the upper-case form used for comparison with
     * {@code listTables} results. Snowflake stores unquoted identifiers upper-cased.
     * {@link Locale#ROOT} avoids locale-specific mappings, e.g. Turkish dotted capital I for "i".
     */
    private static String tableKey(String name) {
        return name.toUpperCase(Locale.ROOT);
    }

    @Override // io.castled.warehouses.S3BasedDataPoller
    public CastledS3Client getS3Client(WarehouseConfig warehouseConfig, String str) {
        return SnowflakeUtils.getS3Client(warehouseConfig, str);
    }
}
