package com.amazon.redshift.plugin;

import com.amazon.redshift.CredentialsHolder;
import com.amazon.redshift.IPlugin;
import com.amazon.redshift.RedshiftProperty;
import com.amazon.redshift.httpclient.log.IamCustomLogFactory;
import com.amazon.redshift.logger.RedshiftLogger;
import com.amazon.redshift.plugin.utils.RequestUtils;
import com.amazonaws.SdkClientException;
import com.amazonaws.auth.AWSStaticCredentialsProvider;
import com.amazonaws.auth.AnonymousAWSCredentials;
import com.amazonaws.auth.BasicSessionCredentials;
import com.amazonaws.services.securitytoken.AWSSecurityTokenServiceClientBuilder;
import com.amazonaws.services.securitytoken.model.AssumeRoleWithWebIdentityRequest;
import com.amazonaws.services.securitytoken.model.Credentials;
import com.amazonaws.util.StringUtils;
import com.amazonaws.util.json.Jackson;
import com.fasterxml.jackson.databind.JsonNode;
import java.io.IOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.codec.binary.Base64;
import org.apache.commons.logging.LogFactory;

/* loaded from: input_file:com/amazon/redshift/plugin/JwtCredentialsProvider.class */
public abstract class JwtCredentialsProvider implements IPlugin {
    private static final String KEY_ROLE_ARN = "roleArn";
    private static final String KEY_WEB_IDENTITY_TOKEN = "webIdentityToken";
    private static final String KEY_DURATION = "duration";
    private static final String KEY_ROLE_SESSION_NAME = "roleSessionName";
    private static final String DEFAULT_ROLE_SESSION_NAME = "jwt_redshift_session";
    protected String m_roleArn;
    protected String m_jwt;
    protected int m_duration;
    protected String m_dbUser;
    protected String m_stsEndpoint;
    protected String m_region;
    protected RedshiftLogger m_log;
    private CredentialsHolder m_lastRefreshCredentials;
    private static final String LOG_PROPERTIES_FILE_NAME = "log-factory.properties";
    private static final String LOG_PROPERTIES_FILE_PATH = "META-INF/services/org.apache.commons.logging.LogFactory";
    private static Map<String, CredentialsHolder> m_cache = new HashMap();
    private static final Class<?> CUSTOM_LOG_FACTORY_CLASS = IamCustomLogFactory.class;
    private static final ClassLoader CONTEXT_CLASS_LOADER = new ClassLoader(JwtCredentialsProvider.class.getClassLoader()) { // from class: com.amazon.redshift.plugin.JwtCredentialsProvider.1
        @Override // java.lang.ClassLoader
        public Class<?> loadClass(String str) throws ClassNotFoundException {
            Class<?> loadClass = getParent().loadClass(str);
            return LogFactory.class.isAssignableFrom(loadClass) ? JwtCredentialsProvider.CUSTOM_LOG_FACTORY_CLASS : loadClass;
        }

        @Override // java.lang.ClassLoader
        public Enumeration<URL> getResources(String str) throws IOException {
            return "commons-logging.properties".equals(str) ? Collections.enumeration(Collections.emptyList()) : super.getResources(str);
        }

        @Override // java.lang.ClassLoader
        public URL getResource(String str) {
            return JwtCredentialsProvider.LOG_PROPERTIES_FILE_PATH.equals(str) ? JwtCredentialsProvider.class.getResource(JwtCredentialsProvider.LOG_PROPERTIES_FILE_NAME) : super.getResource(str);
        }
    };
    protected String m_roleSessionName = DEFAULT_ROLE_SESSION_NAME;
    protected Boolean m_disableCache = false;
    protected Boolean m_groupFederation = false;

    protected abstract String processJwt(String str) throws IOException;

    @Override // com.amazon.redshift.IPlugin
    public void addParameter(String str, String str2) {
        if (RedshiftLogger.isEnable()) {
            this.m_log.logDebug("key: {0}", str);
        }
        if (KEY_ROLE_ARN.equalsIgnoreCase(str)) {
            this.m_roleArn = str2;
            return;
        }
        if (KEY_WEB_IDENTITY_TOKEN.equalsIgnoreCase(str)) {
            this.m_jwt = str2;
            return;
        }
        if (KEY_ROLE_SESSION_NAME.equalsIgnoreCase(str)) {
            this.m_roleSessionName = str2;
            return;
        }
        if (KEY_DURATION.equalsIgnoreCase(str)) {
            this.m_duration = Integer.parseInt(str2);
            return;
        }
        if (RedshiftProperty.DB_USER.getName().equalsIgnoreCase(str)) {
            return;
        }
        if (RedshiftProperty.AWS_REGION.getName().equalsIgnoreCase(str)) {
            this.m_region = str2;
        } else if (RedshiftProperty.STS_ENDPOINT_URL.getName().equalsIgnoreCase(str)) {
            this.m_stsEndpoint = str2;
        } else if (RedshiftProperty.IAM_DISABLE_CACHE.getName().equalsIgnoreCase(str)) {
            this.m_disableCache = Boolean.valueOf(str2);
        }
    }

    @Override // com.amazon.redshift.IPlugin
    public void setLogger(RedshiftLogger redshiftLogger) {
        this.m_log = redshiftLogger;
    }

    @Override // com.amazon.redshift.IPlugin
    public int getSubType() {
        return 2;
    }

    @Override // com.amazonaws.auth.AWSCredentialsProvider
    public CredentialsHolder getCredentials() {
        CredentialsHolder credentialsHolder = null;
        if (!this.m_disableCache.booleanValue()) {
            credentialsHolder = m_cache.get(getCacheKey());
        }
        if (credentialsHolder == null || credentialsHolder.isExpired()) {
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("JWT getCredentials NOT from cache", new Object[0]);
            }
            synchronized (this) {
                refresh();
                if (this.m_disableCache.booleanValue()) {
                    credentialsHolder = this.m_lastRefreshCredentials;
                    this.m_lastRefreshCredentials = null;
                }
            }
        } else {
            credentialsHolder.setRefresh(false);
            if (RedshiftLogger.isEnable()) {
                this.m_log.logInfo("SAML getCredentials from cache", new Object[0]);
            }
        }
        if (!this.m_disableCache.booleanValue()) {
            credentialsHolder = m_cache.get(getCacheKey());
        }
        if (credentialsHolder == null) {
            throw new SdkClientException("Unable to load AWS credentials from ADFS");
        }
        return credentialsHolder;
    }

    @Override // com.amazonaws.auth.AWSCredentialsProvider
    public void refresh() {
        Thread currentThread = Thread.currentThread();
        ClassLoader contextClassLoader = currentThread.getContextClassLoader();
        Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            try {
                String processJwt = processJwt(this.m_jwt);
                if (RedshiftLogger.isEnable()) {
                    this.m_log.logDebug(String.format("JWT : %s", processJwt), new Object[0]);
                }
                this.m_dbUser = deriveDatabaseUser(decodeJwt(this.m_jwt));
                AssumeRoleWithWebIdentityRequest assumeRoleWithWebIdentityRequest = new AssumeRoleWithWebIdentityRequest();
                assumeRoleWithWebIdentityRequest.setWebIdentityToken(processJwt);
                assumeRoleWithWebIdentityRequest.setRoleArn(this.m_roleArn);
                assumeRoleWithWebIdentityRequest.setRoleSessionName(this.m_roleSessionName);
                if (this.m_duration > 0) {
                    assumeRoleWithWebIdentityRequest.setDurationSeconds(Integer.valueOf(this.m_duration));
                }
                Credentials credentials = RequestUtils.buildSts(this.m_stsEndpoint, this.m_region, AWSSecurityTokenServiceClientBuilder.standard(), new AWSStaticCredentialsProvider(new AnonymousAWSCredentials()), this.m_log).assumeRoleWithWebIdentity(assumeRoleWithWebIdentityRequest).getCredentials();
                CredentialsHolder newInstance = CredentialsHolder.newInstance(new BasicSessionCredentials(credentials.getAccessKeyId(), credentials.getSecretAccessKey(), credentials.getSessionToken()), credentials.getExpiration());
                newInstance.setMetadata(readMetadata());
                newInstance.setRefresh(true);
                if (this.m_disableCache.booleanValue()) {
                    this.m_lastRefreshCredentials = newInstance;
                } else {
                    m_cache.put(getCacheKey(), newInstance);
                }
            } catch (Exception e) {
                if (RedshiftLogger.isEnable()) {
                    this.m_log.logError(e);
                }
                throw new SdkClientException("JWT error: " + e.getMessage(), e);
            }
        } finally {
            currentThread.setContextClassLoader(contextClassLoader);
        }
    }

    @Override // com.amazon.redshift.IPlugin
    public String getPluginSpecificCacheKey() {
        return "";
    }

    @Override // com.amazon.redshift.IPlugin
    public String getIdpToken() {
        Thread currentThread = Thread.currentThread();
        ClassLoader contextClassLoader = currentThread.getContextClassLoader();
        Thread.currentThread().setContextClassLoader(CONTEXT_CLASS_LOADER);
        try {
            try {
                String processJwt = processJwt(this.m_jwt);
                if (RedshiftLogger.isEnable()) {
                    this.m_log.logDebug(String.format("JWT : %s", processJwt), new Object[0]);
                }
                return processJwt;
            } catch (Exception e) {
                if (RedshiftLogger.isEnable()) {
                    this.m_log.logError(e);
                }
                throw new SdkClientException("JWT error: " + e.getMessage(), e);
            }
        } finally {
            currentThread.setContextClassLoader(contextClassLoader);
        }
    }

    @Override // com.amazon.redshift.IPlugin
    public void setGroupFederation(boolean z) {
        this.m_groupFederation = Boolean.valueOf(z);
    }

    @Override // com.amazon.redshift.IPlugin
    public String getCacheKey() {
        return this.m_roleArn + this.m_jwt + this.m_roleSessionName + this.m_duration + getPluginSpecificCacheKey();
    }

    /* JADX INFO: Access modifiers changed from: protected */
    public void checkRequiredParameters() throws IOException {
        if (StringUtils.isNullOrEmpty(this.m_roleArn)) {
            throw new IOException("Missing required property: roleArn");
        }
        if (StringUtils.isNullOrEmpty(this.m_jwt)) {
            throw new IOException("Missing required property: webIdentityToken");
        }
    }

    protected String[] decodeJwt(String str) {
        if (str == null) {
            return null;
        }
        String[] split = str.split("\\.");
        if (split.length != 3) {
            return null;
        }
        String str2 = new String(Base64.decodeBase64(split[0]));
        String str3 = new String(Base64.decodeBase64(split[1]));
        String str4 = split[2];
        if (RedshiftLogger.isEnable()) {
            this.m_log.logDebug(String.format("Decoded JWT : Header: %s payload: %s signature:%s", str2, str3, str4), new Object[0]);
        }
        return new String[]{str2, str3, str4};
    }

    protected String deriveDatabaseUser(String[] strArr) {
        String str = null;
        if (strArr == null || strArr.length != 3) {
            throw new SdkClientException("JWT decoding error");
        }
        String[] strArr2 = {"DbUser", "upn", "preferred_username", "email"};
        JsonNode jsonNodeOf = Jackson.jsonNodeOf(strArr[1]);
        int length = strArr2.length;
        int i = 0;
        while (true) {
            if (i >= length) {
                break;
            }
            String str2 = strArr2[i];
            JsonNode findValue = jsonNodeOf.findValue(str2);
            if (findValue != null) {
                str = findValue.textValue();
                if (!StringUtils.isNullOrEmpty(str)) {
                    if (RedshiftLogger.isEnable()) {
                        this.m_log.logDebug(String.format("JWT claim: %s as database user: %s", str2, str), new Object[0]);
                    }
                }
            }
            i++;
        }
        if (StringUtils.isNullOrEmpty(str)) {
            throw new SdkClientException("No database user claim found in JWT");
        }
        return str;
    }

    private CredentialsHolder.IamMetadata readMetadata() {
        CredentialsHolder.IamMetadata iamMetadata = new CredentialsHolder.IamMetadata();
        iamMetadata.setDbUser(this.m_dbUser);
        iamMetadata.setAutoCreate(true);
        return iamMetadata;
    }
}
