package io.castled.warehouses.connectors.redshift;

import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.castled.ObjectRegistry;
import io.castled.constants.ConnectorExecutionConstants;
import io.castled.exceptions.CastledRuntimeException;
import io.castled.filemanager.RawFileWriter;
import io.castled.filestorage.CastledS3Client;
import io.castled.filestorage.ObjectStoreException;
import io.castled.schema.models.FieldSchema;
import io.castled.schema.models.Tuple;
import io.castled.utils.FileUtils;
import io.castled.utils.JsonUtils;
import io.castled.utils.SizeUtils;
import io.castled.warehouses.S3BasedWarehouseSyncFailureListener;
import io.castled.warehouses.WarehouseConnectorConfig;
import io.castled.warehouses.connectors.redshift.models.RedshiftS3CopyManifest;
import io.castled.warehouses.models.WarehousePollContext;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.nio.file.attribute.FileAttribute;
import java.sql.Connection;
import java.sql.SQLException;
import java.sql.Statement;
import java.util.HashMap;
import java.util.List;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:io/castled/warehouses/connectors/redshift/RedshiftSyncFailureListener.class */
/**
 * Failure listener for Redshift pipelines.
 *
 * <p>Failed records are serialized as JSON (only the trackable fields) into local files. Those
 * files are gzip-compressed and uploaded to S3, either when the local buffer exceeds the
 * configured size or on flush. On flush, the uploaded files are COPYed into a per-pipeline
 * failed-records table. Rows matching them are deleted from the uncommitted snapshot, and the
 * uncommitted snapshot is then renamed to become the committed snapshot.
 */
public class RedshiftSyncFailureListener extends S3BasedWarehouseSyncFailureListener {
    private static final Logger log = LoggerFactory.getLogger(RedshiftSyncFailureListener.class);

    /** Object name of the Redshift COPY manifest written inside the S3 upload directory. */
    private static final String MANIFEST_FILE_NAME = "manifest.json";

    private final RedshiftConnector redshiftConnector;
    private final RawFileWriter rawFileWriter;
    private final RedshiftWarehouseConfig warehouseConfig;
    private final WarehousePollContext warehousePollContext;
    private final String s3UploadDir;
    private final CastledS3Client encryptedS3Client;
    private final CastledS3Client simpleS3Client;
    // Bytes written to local files since the last upload. Kept as a long: the limit is configured
    // in GB, and an int counter would overflow (and never trigger an upload) past 2 GiB.
    private long totalBytes;
    // Total failed records seen; a non-zero value triggers the COPY/delete step on flush.
    private long failedRecords;

    public RedshiftSyncFailureListener(WarehousePollContext warehousePollContext) {
        super(warehousePollContext, RedshiftUtils.getS3Client(warehousePollContext.getWarehouseConfig(), warehousePollContext.getDataEncryptionKey()));
        this.totalBytes = 0L;
        this.failedRecords = 0L;
        this.warehouseConfig = (RedshiftWarehouseConfig) warehousePollContext.getWarehouseConfig();
        this.redshiftConnector = (RedshiftConnector) ObjectRegistry.getInstance(RedshiftConnector.class);
        // Local files roll over at 50 MB; each file gets a random UUID name.
        this.rawFileWriter = new RawFileWriter(SizeUtils.convertMBToBytes(50L), this.failureRecordsDirectory, () -> UUID.randomUUID().toString());
        this.warehousePollContext = warehousePollContext;
        this.encryptedS3Client = RedshiftUtils.getS3Client(this.warehouseConfig, warehousePollContext.getDataEncryptionKey());
        // Client without a data encryption key; used only to upload the COPY manifest.
        this.simpleS3Client = RedshiftUtils.getS3Client(this.warehouseConfig, null);
        this.s3UploadDir = getS3FailedRecordsDirectory(warehousePollContext.getPipelineUUID(), warehousePollContext.getPipelineRunId());
    }

    /**
     * Buffers one failed record locally and uploads the buffer to S3 once it exceeds the
     * configured failed-record buffer size (in GB).
     *
     * @param tuple the failed record
     * @throws Exception if serialization, local file I/O, or the S3 upload fails
     */
    @Override // io.castled.warehouses.WarehouseSyncFailureListener
    public synchronized void doWriteRecord(Tuple tuple) throws Exception {
        // Encode explicitly as UTF-8 (Redshift COPY's default encoding) instead of the platform charset.
        byte[] bytes = getCopyableRecord(tuple).getBytes(StandardCharsets.UTF_8);
        this.rawFileWriter.writeRecord(bytes);
        this.totalBytes += bytes.length;
        this.failedRecords++;
        if (this.totalBytes > SizeUtils.convertGBToBytes(((WarehouseConnectorConfig) ObjectRegistry.getInstance(WarehouseConnectorConfig.class)).getFailedRecordBufferSize())) {
            this.rawFileWriter.close();
            uploadFilesToS3();
            // uploadFilesToS3 deletes the local directory; recreate it so later records have a home.
            // NOTE(review): assumes RawFileWriter can keep writing after close() (e.g. opens a new file lazily) - confirm.
            if (!Files.exists(this.failureRecordsDirectory)) {
                Files.createDirectory(this.failureRecordsDirectory);
            }
        }
    }

    /**
     * Serializes the trackable fields of {@code tuple} to a JSON object string. Values are
     * converted with {@link RedshiftCopySchemaMapper}.
     */
    private String getCopyableRecord(Tuple tuple) throws Exception {
        RedshiftCopySchemaMapper schemaMapper = (RedshiftCopySchemaMapper) ObjectRegistry.getInstance(RedshiftCopySchemaMapper.class);
        HashMap<String, Object> record = Maps.newHashMap();
        for (FieldSchema fieldSchema : this.warehousePollContext.getWarehouseSchema().getFieldSchemas()) {
            String fieldName = fieldSchema.getName();
            if (this.trackableFields.contains(fieldName)) {
                record.put(fieldName, schemaMapper.transformValue(tuple.getValue(fieldName), fieldSchema.getSchema()));
            }
        }
        return JsonUtils.objectToString(record);
    }

    /**
     * Uploads any locally buffered records, then runs the snapshot finalization in one Redshift
     * connection. If any records failed, they are COPYed into the failed-records table and removed
     * from the uncommitted snapshot. The snapshot is then committed.
     *
     * @throws Exception if the upload, any SQL step, or closing the connection fails
     */
    @Override // io.castled.warehouses.WarehouseSyncFailureListener
    public void doFlush() throws Exception {
        if (this.totalBytes > 0) {
            this.rawFileWriter.close();
            uploadFilesToS3();
        }
        try (Connection connection = this.redshiftConnector.getConnection(this.warehouseConfig)) {
            if (this.failedRecords > 0) {
                copyFailedRecords(connection);
                removeFailedRecordsFromSnapshot(connection);
            }
            commitSnapshot(connection);
        }
    }

    /**
     * Replaces the committed snapshot with the uncommitted one in a single transaction. The old
     * committed table is dropped, then the uncommitted table is renamed to take its place.
     *
     * @throws CastledRuntimeException wrapping the cause if any step fails (after a rollback)
     * @throws SQLException if toggling auto-commit or the rollback itself fails
     */
    private void commitSnapshot(Connection connection) throws SQLException {
        String pipelineUUID = this.warehousePollContext.getPipelineUUID();
        String qualifiedUncommittedSnapshot = ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID);
        String qualifiedCommittedSnapshot = ConnectorExecutionConstants.getQualifiedCommittedSnapshot(pipelineUUID);
        connection.setAutoCommit(false);
        try {
            // try-with-resources closes the statement on every path, not only on success.
            try (Statement statement = connection.createStatement()) {
                statement.execute(String.format("drop table if exists %s", qualifiedCommittedSnapshot));
                // The rename target is the unqualified table name; the table stays in its schema.
                statement.execute(String.format("alter table %s rename to %s", qualifiedUncommittedSnapshot, ConnectorExecutionConstants.getCommittedSnapshot(pipelineUUID)));
            }
            connection.commit();
            connection.setAutoCommit(true);
        } catch (Exception e) {
            connection.rollback();
            log.error("Committing snapshot for pipeline {} failed", pipelineUUID, e);
            throw new CastledRuntimeException(e);
        }
    }

    /**
     * Deletes from the uncommitted snapshot every row that matches a row of the failed-records
     * table on all trackable fields. Two NULLs count as a match.
     */
    private void removeFailedRecordsFromSnapshot(Connection connection) throws SQLException {
        String qualifiedUncommittedSnapshot = ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(this.warehousePollContext.getPipelineUUID());
        String failedRecordsTable = ConnectorExecutionConstants.getFailedRecordsTable(this.warehousePollContext.getPipelineUUID());
        // "where 1 = 1" lets every per-field predicate be appended uniformly with AND.
        StringBuilder deleteSql = new StringBuilder(String.format("delete from %s using %s where 1 = 1", qualifiedUncommittedSnapshot, failedRecordsTable));
        for (String field : this.trackableFields) {
            deleteSql.append(String.format(" AND (%s.%s = %s.%s OR (%s.%s IS NULL and %s.%s IS NULL))", failedRecordsTable, field, qualifiedUncommittedSnapshot, field, failedRecordsTable, field, qualifiedUncommittedSnapshot, field));
        }
        try (Statement statement = connection.createStatement()) {
            statement.execute(deleteSql.toString());
        }
    }

    /**
     * Compresses every local failure file and uploads the directory to {@link #s3UploadDir}. It
     * then deletes the local directory and resets the buffered byte count.
     */
    private void uploadFilesToS3() throws IOException {
        FileUtils.listFiles(this.failureRecordsDirectory).forEach(this::compressFile);
        this.encryptedS3Client.uploadDirectory(this.failureRecordsDirectory, this.s3UploadDir);
        FileUtils.deleteDirectory(this.failureRecordsDirectory);
        this.totalBytes = 0L;
    }

    /** Writes {@code path + ".gzip"} via {@link FileUtils#compressFile} and deletes the original file. */
    private void compressFile(Path path) {
        try {
            FileUtils.compressFile(path, Paths.get(path.toString() + ".gzip"));
            Files.deleteIfExists(path);
        } catch (Exception e) {
            log.error("File compressed failed for file {}", path.toString());
            throw new CastledRuntimeException(e);
        }
    }

    /**
     * Loads the uploaded failure files into the failed-records table. It writes a Redshift COPY
     * manifest listing them, runs the COPY, and then removes the S3 upload directory.
     */
    private void copyFailedRecords(Connection connection) throws SQLException, ObjectStoreException {
        createFailedRecordsTable(connection);
        // The objects are listed before the manifest is uploaded, so the manifest does not list itself.
        List<RedshiftS3CopyManifest.ManifestEntry> manifestEntries = this.encryptedS3Client.listObjectUrls(this.s3UploadDir)
                .stream()
                .map(url -> new RedshiftS3CopyManifest.ManifestEntry(url, true))
                .collect(Collectors.toList());
        this.simpleS3Client.uploadText(
                CastledS3Client.constructObjectKey(Lists.newArrayList(this.s3UploadDir, MANIFEST_FILE_NAME)),
                JsonUtils.objectToString(new RedshiftS3CopyManifest(manifestEntries)));
        RedshiftClient redshiftClient = (RedshiftClient) ObjectRegistry.getInstance(RedshiftClient.class);
        redshiftClient.copyFilesToTable(
                connection,
                ConnectorExecutionConstants.getFailedRecordsTable(this.warehousePollContext.getPipelineUUID()),
                CastledS3Client.constructS3Path(this.encryptedS3Client.getBucket(), Lists.newArrayList(this.s3UploadDir, MANIFEST_FILE_NAME)),
                this.encryptedS3Client.getEncryptionKey(),
                this.warehouseConfig);
        // Clean up the same directory the files (and manifest) were uploaded to.
        this.encryptedS3Client.deleteDirectory(this.s3UploadDir);
    }

    /**
     * Creates the failed-records table with the trackable-field columns of the uncommitted
     * snapshot, using a {@code select ... limit 0} so that no rows are copied.
     */
    private void createFailedRecordsTable(Connection connection) throws SQLException {
        String pipelineUUID = this.warehousePollContext.getPipelineUUID();
        String schemaQuery = String.format("select %s from %s limit 0", String.join(",", this.trackableFields), ConnectorExecutionConstants.getQualifiedUncommittedSnapshot(pipelineUUID));
        RedshiftTableProperties tableProperties = (RedshiftTableProperties) this.redshiftConnector.getSnapshotTableProperties(this.warehousePollContext.getPrimaryKeys());
        ((RedshiftClient) ObjectRegistry.getInstance(RedshiftClient.class)).createTableFromQuery(connection, ConnectorExecutionConstants.getFailedRecordsTable(pipelineUUID), schemaQuery, tableProperties, true);
    }
}
