diff --git a/.github/workflows/integration_tests.yml b/.github/workflows/integration_tests.yml index 937545b8..3bbc2572 100644 --- a/.github/workflows/integration_tests.yml +++ b/.github/workflows/integration_tests.yml @@ -19,7 +19,7 @@ jobs: matrix: python-version: [ "3.8", "3.11" ] engine-version: [ "lts", "latest"] - environment: ["mysql", "pg"] + environment: ["mysql", "pg", "dsql"] steps: - name: 'Clone repository' diff --git a/README.md b/README.md index 2f7122f7..e7123796 100644 --- a/README.md +++ b/README.md @@ -26,6 +26,12 @@ Since a database failover is usually identified by reaching a network or a conne Enhanced Failure Monitoring (EFM) is a feature available from the [Host Monitoring Connection Plugin](./docs/using-the-python-driver/using-plugins/UsingTheHostMonitoringPlugin.md#enhanced-failure-monitoring) that periodically checks the connected database host's health and availability. If a database host is determined to be unhealthy, the connection is aborted (and potentially routed to another healthy host in the cluster). +### Using the AWS Advanced Python Driver with AWS Aurora DSQL +The AWS Advanced Python Driver is able to handle IAM authentication when working with AWS Aurora DSQL clusters. + +Please visit [this page](./docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) for more information. + + ### Using the AWS Advanced Python Driver with plain RDS databases The AWS Advanced Python Driver also works with RDS provided databases that are not Aurora. diff --git a/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py new file mode 100644 index 00000000..62068870 --- /dev/null +++ b/aws_advanced_python_wrapper/dsql_iam_auth_plugin_factory.py @@ -0,0 +1,30 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin +from aws_advanced_python_wrapper.plugin import Plugin, PluginFactory +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.properties import Properties + + +class DsqlIamAuthPluginFactory(PluginFactory): + def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: + return IamAuthPlugin(plugin_service, DSQLTokenUtils()) diff --git a/aws_advanced_python_wrapper/federated_plugin.py b/aws_advanced_python_wrapper/federated_plugin.py index 0eb8258f..186f6d16 100644 --- a/aws_advanced_python_wrapper/federated_plugin.py +++ b/aws_advanced_python_wrapper/federated_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -43,6 +44,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -55,12 +57,17 @@ class FederatedAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, + plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("federated.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("federated.token_cache.size", lambda: len(FederatedAuthPlugin._token_cache)) @@ -145,7 +152,7 @@ def _update_authentication_token(self, credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -159,7 +166,7 @@ def _update_authentication_token(self, class FederatedAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return FederatedAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> AdfsCredentialsProviderFactory: idp_name = WrapperProperties.IDP_NAME.get(props) diff --git a/aws_advanced_python_wrapper/iam_plugin.py b/aws_advanced_python_wrapper/iam_plugin.py index 1a26c58a..06fb3b07 100644 --- a/aws_advanced_python_wrapper/iam_plugin.py +++ b/aws_advanced_python_wrapper/iam_plugin.py @@ -17,6 +17,8 @@ from typing import TYPE_CHECKING from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils, TokenInfo +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils +from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils from aws_advanced_python_wrapper.utils.region_utils import RegionUtils if TYPE_CHECKING: @@ -25,6 +27,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils from datetime import datetime, timedelta from typing import Callable, Dict, Optional, Set @@ -35,7 +38,6 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) -from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -48,11 +50,12 @@ class IamAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, session: Optional[Session] = None): + def __init__(self, plugin_service: PluginService, token_utils: TokenUtils, session: Optional[Session] = None): self._plugin_service = plugin_service self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("iam.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge( @@ -102,7 +105,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl else: token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token: str = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token: str = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -120,7 +123,7 @@ def _connect(self, host_info: HostInfo, props: Properties, connect_func: Callabl # Try to generate a new token and try to connect again token_expiry = datetime.now() + timedelta(seconds=token_expiration_sec) self._fetch_token_counter.inc() - token = IamAuthUtils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) + token = self._token_utils.generate_authentication_token(self._plugin_service, user, host, port, region, client_session=self._session) self._plugin_service.driver_dialect.set_password(props, token) IamAuthPlugin._token_cache[cache_key] = TokenInfo(token, token_expiry) @@ -142,4 +145,4 @@ def force_connect( class IamAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return IamAuthPlugin(plugin_service) + return IamAuthPlugin(plugin_service, RDSTokenUtils()) diff --git a/aws_advanced_python_wrapper/okta_plugin.py b/aws_advanced_python_wrapper/okta_plugin.py index 55bd9980..d1e9e19e 100644 --- a/aws_advanced_python_wrapper/okta_plugin.py +++ b/aws_advanced_python_wrapper/okta_plugin.py @@ -31,6 +31,7 @@ from aws_advanced_python_wrapper.hostinfo import HostInfo from aws_advanced_python_wrapper.pep249 import Connection from aws_advanced_python_wrapper.plugin_service import PluginService + from aws_advanced_python_wrapper.utils.token_utils import TokenUtils import requests @@ -40,6 +41,7 @@ from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils logger = Logger(__name__) @@ -51,12 +53,17 @@ class OktaAuthPlugin(Plugin): _rds_utils: RdsUtils = RdsUtils() _token_cache: Dict[str, TokenInfo] = {} - def __init__(self, plugin_service: PluginService, credentials_provider_factory: CredentialsProviderFactory, session: Optional[Session] = None): + def __init__(self, + plugin_service: PluginService, + credentials_provider_factory: CredentialsProviderFactory, + token_utils: TokenUtils, + session: Optional[Session] = None): self._plugin_service = plugin_service self._credentials_provider_factory = credentials_provider_factory self._session = session self._region_utils = RegionUtils() + self._token_utils = token_utils telemetry_factory = self._plugin_service.get_telemetry_factory() self._fetch_token_counter = telemetry_factory.create_counter("okta.fetch_token.count") self._cache_size_gauge = telemetry_factory.create_gauge("okta.token_cache.size", lambda: len(OktaAuthPlugin._token_cache)) @@ -140,7 +147,7 @@ def _update_authentication_token(self, port: int = IamAuthUtils.get_port(props, host_info, self._plugin_service.database_dialect.default_port) credentials: Optional[Dict[str, str]] = self._credentials_provider_factory.get_aws_credentials(region, props) - token: str = IamAuthUtils.generate_authentication_token( + token: str = self._token_utils.generate_authentication_token( self._plugin_service, user, host_info.host, @@ -228,7 +235,7 @@ def get_saml_assertion(self, props: Properties): class OktaAuthPluginFactory(PluginFactory): def get_instance(self, plugin_service: PluginService, props: Properties) -> Plugin: - return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props)) + return OktaAuthPlugin(plugin_service, self.get_credentials_provider_factory(plugin_service, props), RDSTokenUtils()) def get_credentials_provider_factory(self, plugin_service: PluginService, props: Properties) -> OktaCredentialsProviderFactory: return OktaCredentialsProviderFactory(plugin_service, props) diff --git a/aws_advanced_python_wrapper/plugin_service.py b/aws_advanced_python_wrapper/plugin_service.py index 4be4fea3..8d2d4fea 100644 --- a/aws_advanced_python_wrapper/plugin_service.py +++ b/aws_advanced_python_wrapper/plugin_service.py @@ -57,6 +57,8 @@ from aws_advanced_python_wrapper.developer_plugin import DeveloperPluginFactory from aws_advanced_python_wrapper.driver_configuration_profiles import \ DriverConfigurationProfiles +from aws_advanced_python_wrapper.dsql_iam_auth_plugin_factory import \ + DsqlIamAuthPluginFactory from aws_advanced_python_wrapper.errors import (AwsWrapperError, QueryTimeoutError, UnsupportedOperationError) @@ -716,6 +718,7 @@ class PluginManager(CanReleaseResources): PLUGIN_FACTORIES: Dict[str, Type[PluginFactory]] = { "iam": IamAuthPluginFactory, + "iam_dsql": DsqlIamAuthPluginFactory, "aws_secrets_manager": AwsSecretsManagerPluginFactory, "aurora_connection_tracker": AuroraConnectionTrackerPluginFactory, "host_monitoring": HostMonitoringPluginFactory, @@ -748,6 +751,7 @@ class PluginManager(CanReleaseResources): HostMonitoringPluginFactory: 500, FastestResponseStrategyPluginFactory: 600, IamAuthPluginFactory: 700, + DsqlIamAuthPluginFactory: 710, AwsSecretsManagerPluginFactory: 800, FederatedAuthPluginFactory: 900, LimitlessPluginFactory: 950, diff --git a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties index a7c37c48..56d13449 100644 --- a/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties +++ b/aws_advanced_python_wrapper/resources/aws_advanced_python_wrapper_messages.properties @@ -151,7 +151,6 @@ IamAuthPlugin.UnhandledException=[IamAuthPlugin] Unhandled exception: {} IamAuthPlugin.UseCachedIamToken=[IamAuthPlugin] Used cached IAM token = {} IamAuthPlugin.InvalidHost=[IamAuthPlugin] Invalid IAM host {}. The IAM host must be a valid RDS or Aurora endpoint. IamAuthPlugin.IsNoneOrEmpty=[IamAuthPlugin] Property "{}" is None or empty. -IamAuthUtils.GeneratedNewAuthToken=Generated new authentication token = {} LimitlessPlugin.FailedToConnectToHost=[LimitlessPlugin] Failed to connect to host {}. LimitlessPlugin.UnsupportedDialectOrDatabase=[LimitlessPlugin] Unsupported dialect '{}' encountered. Please ensure the connection parameters are correct, and refer to the documentation to ensure that the connecting database is compatible with the Limitless Connection Plugin. @@ -316,6 +315,8 @@ RoundRobinHostSelector.ClusterInfoNone=[RoundRobinHostSelector] The round robin RoundRobinHostSelector.RoundRobinInvalidDefaultWeight=[RoundRobinHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. RoundRobinHostSelector.RoundRobinInvalidHostWeightPairs= [RoundRobinHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' +TokenUtils.GeneratedNewAuthTokenLength=Generated new authentication token length = {} + WeightedRandomHostSelector.WeightedRandomInvalidHostWeightPairs= [WeightedRandomHostSelector] The provided host weight pairs have not been configured correctly. Please ensure the provided host weight pairs is a comma separated list of pairs, each pair in the format of :. Weight values must be an integer greater than or equal to the default weight value of 1. Weight pair: '{}' WeightedRandomHostSelector.WeightedRandomInvalidDefaultWeight=[WeightedRandomHostSelector] The provided default weight value is not valid. Weight values must be an integer greater than or equal to 1. diff --git a/aws_advanced_python_wrapper/utils/dsql_token_utils.py b/aws_advanced_python_wrapper/utils/dsql_token_utils.py new file mode 100644 index 00000000..aa61e02e --- /dev/null +++ b/aws_advanced_python_wrapper/utils/dsql_token_utils.py @@ -0,0 +1,75 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from boto3 import Session + +import boto3 + +logger = Logger(__name__) + + +class DSQLTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch DSQL authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + if credentials is not None: + client = session.client( + 'dsql', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'dsql', + region_name=region + ) + + if user == "admin": + token = client.generate_db_connect_admin_auth_token(host_name, region) + else: + token = client.generate_db_connect_auth_token(host_name, region) + + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) + client.close() + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/iam_utils.py b/aws_advanced_python_wrapper/utils/iam_utils.py index ecb5868f..412499d0 100644 --- a/aws_advanced_python_wrapper/utils/iam_utils.py +++ b/aws_advanced_python_wrapper/utils/iam_utils.py @@ -15,22 +15,16 @@ from __future__ import annotations from datetime import datetime -from typing import TYPE_CHECKING, Dict, Optional - -import boto3 +from typing import TYPE_CHECKING, Optional from aws_advanced_python_wrapper.errors import AwsWrapperError from aws_advanced_python_wrapper.utils.log import Logger from aws_advanced_python_wrapper.utils.messages import Messages from aws_advanced_python_wrapper.utils.rds_url_type import RdsUrlType from aws_advanced_python_wrapper.utils.rdsutils import RdsUtils -from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ - TelemetryTraceLevel if TYPE_CHECKING: from aws_advanced_python_wrapper.hostinfo import HostInfo - from aws_advanced_python_wrapper.plugin_service import PluginService - from boto3 import Session from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) @@ -70,52 +64,6 @@ def get_port(props: Properties, host_info: HostInfo, dialect_default_port: int) def get_cache_key(user: Optional[str], hostname: Optional[str], port: int, region: Optional[str]) -> str: return f"{region}:{hostname}:{port}:{user}" - @staticmethod - def generate_authentication_token( - plugin_service: PluginService, - user: Optional[str], - host_name: Optional[str], - port: Optional[int], - region: Optional[str], - credentials: Optional[Dict[str, str]] = None, - client_session: Optional[Session] = None) -> str: - telemetry_factory = plugin_service.get_telemetry_factory() - context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) - - try: - session = client_session if client_session else boto3.Session() - - if credentials is not None: - client = session.client( - 'rds', - region_name=region, - aws_access_key_id=credentials.get('AccessKeyId'), - aws_secret_access_key=credentials.get('SecretAccessKey'), - aws_session_token=credentials.get('SessionToken') - ) - else: - client = session.client( - 'rds', - region_name=region - ) - - token = client.generate_db_auth_token( - DBHostname=host_name, - Port=port, - DBUsername=user - ) - - client.close() - - logger.debug("IamAuthUtils.GeneratedNewAuthToken", token) - return token - except Exception as ex: - context.set_success(False) - context.set_exception(ex) - raise ex - finally: - context.close_context() - class TokenInfo: @property diff --git a/aws_advanced_python_wrapper/utils/rds_token_utils.py b/aws_advanced_python_wrapper/utils/rds_token_utils.py new file mode 100644 index 00000000..3497cc65 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/rds_token_utils.py @@ -0,0 +1,79 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING, Dict, Optional + +from aws_advanced_python_wrapper.utils.log import Logger +from aws_advanced_python_wrapper.utils.telemetry.telemetry import \ + TelemetryTraceLevel +from aws_advanced_python_wrapper.utils.token_utils import TokenUtils + +if TYPE_CHECKING: + from aws_advanced_python_wrapper.plugin_service import PluginService + from boto3 import Session + +import boto3 + +logger = Logger(__name__) + + +class RDSTokenUtils(TokenUtils): + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + + telemetry_factory = plugin_service.get_telemetry_factory() + context = telemetry_factory.open_telemetry_context("fetch authentication token", TelemetryTraceLevel.NESTED) + + try: + session = client_session if client_session else boto3.Session() + + if credentials is not None: + client = session.client( + 'rds', + region_name=region, + aws_access_key_id=credentials.get('AccessKeyId'), + aws_secret_access_key=credentials.get('SecretAccessKey'), + aws_session_token=credentials.get('SessionToken') + ) + else: + client = session.client( + 'rds', + region_name=region + ) + + token = client.generate_db_auth_token( + DBHostname=host_name, + Port=port, + DBUsername=user + ) + + client.close() + + logger.debug("TokenUtils.GeneratedNewAuthTokenLength", len(token) if token else 0) + return token + except Exception as ex: + context.set_success(False) + context.set_exception(ex) + raise ex + finally: + context.close_context() diff --git a/aws_advanced_python_wrapper/utils/rds_url_type.py b/aws_advanced_python_wrapper/utils/rds_url_type.py index 7226c33c..911c25c6 100644 --- a/aws_advanced_python_wrapper/utils/rds_url_type.py +++ b/aws_advanced_python_wrapper/utils/rds_url_type.py @@ -34,4 +34,5 @@ def __init__(self, is_rds: bool, is_rds_cluster: bool): RDS_PROXY = True, False, RDS_INSTANCE = True, False, RDS_AURORA_LIMITLESS_DB_SHARD_GROUP = True, False, + DSQL_CLUSTER = False, False, OTHER = False, False diff --git a/aws_advanced_python_wrapper/utils/rdsutils.py b/aws_advanced_python_wrapper/utils/rdsutils.py index 7e289d2d..41726035 100644 --- a/aws_advanced_python_wrapper/utils/rdsutils.py +++ b/aws_advanced_python_wrapper/utils/rdsutils.py @@ -108,6 +108,10 @@ class RdsUtils: r"(?Pcluster-|cluster-ro-)+" \ r"(?P[a-zA-Z0-9]+\.rds\.(?P[a-zA-Z0-9\-]+)" \ r"\.(amazonaws\.com|c2s\.ic\.gov|sc2s\.sgov\.gov))$" + AURORA_DSQL_CLUSTER_PATTERN = r"^(?P[^.]+)\." \ + r"(?Pdsql(?:-[^.]+)?)\." \ + r"(?P(?P[a-zA-Z0-9\-]+)" \ + r"\.on\.aws\.?)$" ELB_PATTERN = r"^(?.+)\.elb\.((?[a-zA-Z0-9\-]+)\.amazonaws\.com)$" IP_V4 = r"^(([1-9]|[1-9][0-9]|1[0-9]{2}|2[0-4][0-9]|25[0-5])\.){1}" \ @@ -149,6 +153,14 @@ def is_rds_dns(self, host: str) -> bool: def is_rds_instance(self, host: str) -> bool: return self._get_dns_group(host) is None and self.is_rds_dns(host) + def is_dsql_cluster(self, host: str) -> bool: + if not host or not host.strip(): + return False + + pattern = self._find(host, [RdsUtils.AURORA_DSQL_CLUSTER_PATTERN]) + + return pattern is not None + def is_rds_proxy_dns(self, host: str) -> bool: dns_group = self._get_dns_group(host) return dns_group is not None and dns_group.casefold() == "proxy-" @@ -257,6 +269,8 @@ def identify_rds_type(self, host: Optional[str]) -> RdsUrlType: return RdsUrlType.RDS_PROXY elif self.is_rds_instance(host): return RdsUrlType.RDS_INSTANCE + elif self.is_dsql_cluster(host): + return RdsUrlType.DSQL_CLUSTER return RdsUrlType.OTHER diff --git a/aws_advanced_python_wrapper/utils/token_utils.py b/aws_advanced_python_wrapper/utils/token_utils.py new file mode 100644 index 00000000..f23c61e7 --- /dev/null +++ b/aws_advanced_python_wrapper/utils/token_utils.py @@ -0,0 +1,37 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Dict, Optional + +if TYPE_CHECKING: + from boto3 import Session + + from aws_advanced_python_wrapper.plugin_service import PluginService + + +class TokenUtils(ABC): + @abstractmethod + def generate_authentication_token( + self, + plugin_service: PluginService, + user: Optional[str], + host_name: Optional[str], + port: Optional[int], + region: Optional[str], + credentials: Optional[Dict[str, str]] = None, + client_session: Optional[Session] = None) -> str: + pass diff --git a/docs/README.md b/docs/README.md index ef76cae1..ab7a88c9 100644 --- a/docs/README.md +++ b/docs/README.md @@ -17,6 +17,7 @@ - [Aurora Initial Connection Strategy Plugin](./using-the-python-driver/using-plugins/UsingTheAuroraInitialConnectionStrategyPlugin.md) - [Host Availability Strategy](./using-the-python-driver/HostAvailabilityStrategy.md) - [IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md) + - [DSQL IAM Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md) - [AWS Secrets Manager Plugin](./using-the-python-driver/using-plugins/UsingTheAwsSecretsManagerPlugin.md) - [Federated Authentication Plugin](./using-the-python-driver/using-plugins/UsingTheFederatedAuthenticationPlugin.md) - [Read Write Splitting Plugin](./using-the-python-driver/using-plugins/UsingTheReadWriteSplittingPlugin.md) diff --git a/docs/examples/DSQLIamAuthentication.py b/docs/examples/DSQLIamAuthentication.py new file mode 100644 index 00000000..4e9b2870 --- /dev/null +++ b/docs/examples/DSQLIamAuthentication.py @@ -0,0 +1,38 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import psycopg + +from aws_advanced_python_wrapper import AwsWrapperConnection + +if __name__ == "__main__": + with AwsWrapperConnection.connect( + psycopg.Connection.connect, + host="abcd.dsql.us-east-1.on.aws", + dbname="postgres", + user="admin", + plugins="iam_dsql", + iam_region="us-east-1", + wrapper_dialect="pg", + autocommit=True + ) as awsconn, awsconn.cursor() as awscursor: + awscursor.execute("CREATE TABLE IF NOT EXISTS bank_test (id int primary key, name varchar(40), account_balance int)") + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (0, "Jane Doe", 200)) + awscursor.execute("INSERT INTO bank_test VALUES (%s, %s, %s)", (1, "John Smith", 200)) + awscursor.execute("SELECT * FROM bank_test") + + res = awscursor.fetchall() + for record in res: + print(record) + awscursor.execute("DROP TABLE bank_test") diff --git a/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md new file mode 100644 index 00000000..87e36553 --- /dev/null +++ b/docs/using-the-python-driver/using-plugins/UsingTheDSQLIamAuthenticationPlugin.md @@ -0,0 +1,39 @@ +# AWS Aurora DSQL IAM Authentication Plugin + +This plugin enables connecting to AWS Aurora DSQL databases through AWS Identity and Access Management (IAM). + +## What is IAM? +AWS Identity and Access Management (IAM) grants users access control across all Amazon Web Services. IAM supports granular permissions, giving you the ability to grant different permissions to different users. For more information on IAM and its use cases, please refer to the [IAM documentation](https://docs.aws.amazon.com/IAM/latest/UserGuide/introduction.html). + +## Prerequisites +> [!WARNING]\ +> To preserve compatibility with customers using the community driver, IAM Authentication requires the AWS SDK for Python; [Boto3](https://pypi.org/project/boto3/). Boto3 is a runtime dependency and must be resolved. It can be installed via pip like so: `pip install boto3`. + +The DSQL IAM Authentication plugin requires authentication via AWS Credentials. These credentials can be defined in `~/.aws/credentials` or set as environment variables. All users must set `AWS_ACCESS_KEY_ID` and `AWS_SECRET_ACCESS_KEY`. Users who are using temporary security credentials will also need to additionally set `AWS_SESSION_TOKEN`. + +To enable the AWS Aurora DSQL IAM Authentication Plugin, add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. + +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + +## AWS IAM Database Authentication +The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid AWS Aurora DSQL endpoint, and not a custom domain or an IP address. +
i.e. `cluster-identifier.dsql.us-east-1.on.aws` + +Connections established by the `iam_dsql` plugin are beholden to the [Cluster quotas and database limits in Amazon Aurora DSQL](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/CHAP_quotas.html). In particular, applications need to consider the maximum transaction duration, and maximum connection duration limits. Ensure connections are returned to the pool regularly, and not retained for long periods. + + +## How do I use IAM with the AWS Python Driver? +1. Configure IAM roles for the cluster according to [Using database roles and IAM authentication](https://docs.aws.amazon.com/aurora-dsql/latest/userguide/using-database-and-iam-roles.html). +2. Add the plugin code `iam_dsql` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter value. + +| Parameter | Value | Required | Description | Example Value | +|--------------------|:-------:|:--------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------| +| `iam_host` | String | No | This property will override the default hostname that is used to generate the IAM token. The default hostname is derived from the connection string. This parameter is required when users are connecting with custom endpoints. | `cluster-identifier.dsql.us-east-1.on.aws` | +| `iam_region` | String | No | This property will override the default region that is used to generate the IAM token. The default region is parsed from the connection string where possible. Some connection string formats may not be supported, and the `iam_region` must be provided in these cases. | `us-east-2` | +| `iam_expiration` | Integer | No | This property determines how long an IAM token is kept in the driver cache before a new one is generated. The default expiration time is set to 14 minutes and 30 seconds. Note that IAM database authentication tokens have a lifetime of 15 minutes. | `600` | + +## Sample code + +[DSQLIamAuthentication.py](../../examples/DSQLIamAuthentication.py) + diff --git a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md index e879e2a6..a5fadfcc 100644 --- a/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md +++ b/docs/using-the-python-driver/using-plugins/UsingTheIamAuthenticationPlugin.md @@ -11,6 +11,9 @@ The IAM Authentication plugin requires authentication via AWS Credentials. These To enable the IAM Authentication Connection Plugin, add the plugin code `iam` to the [`plugins`](../UsingThePythonDriver.md#connection-plugin-manager-parameters) parameter. +> [!WARNING]\ +> The `iam` plugin must NOT be specified when using the `iam_dsql` plugin. + ## AWS IAM Database Authentication The AWS Python Driver supports Amazon AWS Identity and Access Management (IAM) authentication. When using AWS IAM database authentication, the host URL must be a valid Amazon endpoint, and not a custom domain or an IP address.
i.e. `db-identifier.cluster-XYZ.us-east-2.rds.amazonaws.com` diff --git a/tests/integration/container/test_aurora_failover.py b/tests/integration/container/test_aurora_failover.py index 1a303e33..f82dcaa3 100644 --- a/tests/integration/container/test_aurora_failover.py +++ b/tests/integration/container/test_aurora_failover.py @@ -42,7 +42,9 @@ @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAuroraFailover: IDLE_CONNECTIONS_NUM: int = 5 logger = Logger(__name__) diff --git a/tests/integration/container/test_autoscaling.py b/tests/integration/container/test_autoscaling.py index 61c9e7d2..ba328e82 100644 --- a/tests/integration/container/test_autoscaling.py +++ b/tests/integration/container/test_autoscaling.py @@ -42,7 +42,8 @@ @enable_on_num_instances(min_instances=5) -@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY]) +@enable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAutoScaling: @pytest.fixture def rds_utils(self): diff --git a/tests/integration/container/test_basic_connectivity.py b/tests/integration/container/test_basic_connectivity.py index ef03614e..bf56f14b 100644 --- a/tests/integration/container/test_basic_connectivity.py +++ b/tests/integration/container/test_basic_connectivity.py @@ -36,7 +36,9 @@ from .utils.test_environment_features import TestEnvironmentFeatures -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicConnectivity: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_basic_functionality.py b/tests/integration/container/test_basic_functionality.py index 34f66c62..193f2d39 100644 --- a/tests/integration/container/test_basic_functionality.py +++ b/tests/integration/container/test_basic_functionality.py @@ -46,7 +46,9 @@ from .utils.test_environment_features import TestEnvironmentFeatures -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestBasicFunctionality: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_custom_endpoint.py b/tests/integration/container/test_custom_endpoint.py index ee33bcfa..dbda019d 100644 --- a/tests/integration/container/test_custom_endpoint.py +++ b/tests/integration/container/test_custom_endpoint.py @@ -45,7 +45,9 @@ @enable_on_num_instances(min_instances=3) @enable_on_deployments([DatabaseEngineDeployment.AURORA]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestCustomEndpoint: logger: ClassVar[Logger] = Logger(__name__) endpoint_id: ClassVar[str] = f"test-endpoint-1-{uuid4()}" diff --git a/tests/integration/container/test_iam_authentication.py b/tests/integration/container/test_iam_authentication.py index 0e4e2e01..9ef98f21 100644 --- a/tests/integration/container/test_iam_authentication.py +++ b/tests/integration/container/test_iam_authentication.py @@ -39,7 +39,9 @@ @enable_on_features([TestEnvironmentFeatures.IAM]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestAwsIamAuthentication: @pytest.fixture(scope='class') diff --git a/tests/integration/container/test_iam_dsql_authentication.py b/tests/integration/container/test_iam_dsql_authentication.py new file mode 100644 index 00000000..3c4ce39c --- /dev/null +++ b/tests/integration/container/test_iam_dsql_authentication.py @@ -0,0 +1,118 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) +from tests.integration.container.utils.test_environment_features import \ + TestEnvironmentFeatures + +if TYPE_CHECKING: + from tests.integration.container.utils.test_driver import TestDriver + +from socket import gethostbyname +from typing import Callable + +import pytest + +from aws_advanced_python_wrapper import AwsWrapperConnection +from aws_advanced_python_wrapper.errors import AwsWrapperError +from tests.integration.container.utils.conditions import enable_on_features +from tests.integration.container.utils.driver_helper import DriverHelper +from tests.integration.container.utils.test_environment import TestEnvironment + + +@enable_on_features([TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) +class TestAwsIamDSQLAuthentication: + + @pytest.fixture(scope='class') + def props(self): + p: Properties = Properties() + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features() \ + or TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.ENABLE_TELEMETRY.set(p, "True") + WrapperProperties.TELEMETRY_SUBMIT_TOPLEVEL.set(p, "True") + + if TestEnvironmentFeatures.TELEMETRY_TRACES_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_TRACES_BACKEND.set(p, "XRAY") + + if TestEnvironmentFeatures.TELEMETRY_METRICS_ENABLED in TestEnvironment.get_current().get_features(): + WrapperProperties.TELEMETRY_METRICS_BACKEND.set(p, "OTLP") + + return p + + def test_iam_wrong_database_username(self, test_environment: TestEnvironment, + test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + user = f"WRONG_{conn_utils.iam_user}_USER" + params = conn_utils.get_connect_params(user=user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect( + target_driver_connect, + **params, + plugins="iam_dsql", + **props) + + def test_iam_no_database_username(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("user", None) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, plugins="iam_dsql", **props) + + def test_iam_invalid_host(self, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params() + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.update({"iam_host": "<>", "plugins": "iam_dsql"}) + + with pytest.raises(AwsWrapperError): + AwsWrapperConnection.connect(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user, password="") + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def test_iam_valid_connection_properties_no_password( + self, test_environment: TestEnvironment, test_driver: TestDriver, conn_utils, props): + target_driver_connect = DriverHelper.get_connect_func(test_driver) + params = conn_utils.get_connect_params(user=conn_utils.iam_user) + params.pop("use_pure", None) # AWS tokens are truncated when using the pure Python MySQL driver + params.pop("password", None) + params["plugins"] = "iam_dsql" + + self.validate_connection(target_driver_connect, **params, **props) + + def get_ip_address(self, hostname: str): + return gethostbyname(hostname) + + def validate_connection(self, target_driver_connect: Callable, **connect_params): + with AwsWrapperConnection.connect(target_driver_connect, **connect_params) as conn, \ + conn.cursor() as cursor: + cursor.execute("SELECT now()") + records = cursor.fetchall() + assert len(records) == 1 diff --git a/tests/integration/container/test_read_write_splitting.py b/tests/integration/container/test_read_write_splitting.py index 86ad049a..6e352419 100644 --- a/tests/integration/container/test_read_write_splitting.py +++ b/tests/integration/container/test_read_write_splitting.py @@ -43,7 +43,9 @@ @enable_on_num_instances(min_instances=2) @enable_on_deployments([DatabaseEngineDeployment.AURORA, DatabaseEngineDeployment.RDS_MULTI_AZ]) -@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, TestEnvironmentFeatures.PERFORMANCE]) +@disable_on_features([TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY, + TestEnvironmentFeatures.PERFORMANCE, + TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY]) class TestReadWriteSplitting: @pytest.fixture(scope='class') def rds_utils(self): diff --git a/tests/integration/container/utils/database_engine_deployment.py b/tests/integration/container/utils/database_engine_deployment.py index 3f542745..8bf401a3 100644 --- a/tests/integration/container/utils/database_engine_deployment.py +++ b/tests/integration/container/utils/database_engine_deployment.py @@ -20,3 +20,4 @@ class DatabaseEngineDeployment(str, Enum): RDS = "RDS" RDS_MULTI_AZ = "RDS_MULTI_AZ" AURORA = "AURORA" + DSQL = "DSQL" diff --git a/tests/integration/container/utils/test_environment_features.py b/tests/integration/container/utils/test_environment_features.py index dfbb7fd9..62c5b4a4 100644 --- a/tests/integration/container/utils/test_environment_features.py +++ b/tests/integration/container/utils/test_environment_features.py @@ -26,6 +26,7 @@ class TestEnvironmentFeatures(Enum): AWS_CREDENTIALS_ENABLED = "AWS_CREDENTIALS_ENABLED" PERFORMANCE = "PERFORMANCE" RUN_AUTOSCALING_TESTS_ONLY = "RUN_AUTOSCALING_TESTS_ONLY" + RUN_DSQL_TESTS_ONLY = "RUN_DSQL_TESTS_ONLY" SKIP_MYSQL_DRIVER_TESTS = "SKIP_MYSQL_DRIVER_TESTS" SKIP_PG_DRIVER_TESTS = "SKIP_PG_DRIVER_TESTS" TELEMETRY_TRACES_ENABLED = "TELEMETRY_TRACES_ENABLED" diff --git a/tests/integration/host/build.gradle.kts b/tests/integration/host/build.gradle.kts index 7351d814..112cde47 100644 --- a/tests/integration/host/build.gradle.kts +++ b/tests/integration/host/build.gradle.kts @@ -30,6 +30,7 @@ dependencies { testImplementation("software.amazon.awssdk:rds:2.20.49") testImplementation("software.amazon.awssdk:ec2:2.20.61") testImplementation("software.amazon.awssdk:secretsmanager:2.20.49") + testImplementation("software.amazon.awssdk:dsql:2.29.34") // Note: all org.testcontainers dependencies should have the same version testImplementation("org.testcontainers:testcontainers:1.21.2") testImplementation("org.testcontainers:mysql:1.21.2") @@ -72,6 +73,7 @@ tasks.register("test-python-3.11-mysql") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -84,6 +86,7 @@ tasks.register("test-python-3.8-mysql") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -98,6 +101,7 @@ tasks.register("test-python-3.11-pg") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -112,6 +116,43 @@ tasks.register("test-python-3.8-pg") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("test-python-3.11-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-38", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") + } +} + +tasks.register("test-python-3.8-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-python-311", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } @@ -123,6 +164,7 @@ tasks.register("test-docker") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -134,6 +176,7 @@ tasks.register("test-aurora") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -148,6 +191,7 @@ tasks.register("test-pg-aurora") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -160,6 +204,7 @@ tasks.register("test-mysql-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -171,6 +216,7 @@ tasks.register("test-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-aurora", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -185,6 +231,7 @@ tasks.register("test-pg-multi-az") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -197,6 +244,7 @@ tasks.register("test-mysql-multi-az") { systemProperty("exclude-aurora", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -209,6 +257,7 @@ tasks.register("test-autoscaling") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -224,6 +273,7 @@ tasks.register("test-pg-aurora-performance") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -237,6 +287,24 @@ tasks.register("test-mysql-aurora-performance") { systemProperty("exclude-secrets-manager", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("test-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.runTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } @@ -248,6 +316,7 @@ tasks.register("debug-all-environments") { doFirst { systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -259,6 +328,7 @@ tasks.register("debug-docker") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -270,6 +340,7 @@ tasks.register("debug-aurora") { systemProperty("exclude-multi-az", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -282,6 +353,7 @@ tasks.register("debug-pg-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -294,6 +366,7 @@ tasks.register("debug-mysql-aurora") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -307,6 +380,7 @@ tasks.register("debug-autoscaling") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -322,6 +396,7 @@ tasks.register("debug-pg-aurora-performance") { systemProperty("exclude-mysql-engine", "true") systemProperty("exclude-mariadb-driver", "true") systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -335,6 +410,7 @@ tasks.register("debug-mysql-aurora-performance") { systemProperty("exclude-secrets-manager", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -346,6 +422,7 @@ tasks.register("debug-multi-az") { systemProperty("exclude-aurora", "true") systemProperty("exclude-performance", "true") systemProperty("exclude-python-38", "true") + systemProperty("exclude-dsql", "true") } } @@ -358,6 +435,7 @@ tasks.register("debug-pg-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-mysql-driver", "true") systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-dsql", "true") } } @@ -370,5 +448,23 @@ tasks.register("debug-mysql-multi-az") { systemProperty("exclude-performance", "true") systemProperty("exclude-pg-driver", "true") systemProperty("exclude-pg-engine", "true") + systemProperty("exclude-dsql", "true") + } +} + +tasks.register("debug-all-dsql") { + group = "verification" + filter.includeTestsMatching("integration.host.TestRunner.debugTests") + doFirst { + systemProperty("exclude-aurora", "true") + systemProperty("exclude-autoscaling", "true") + systemProperty("exclude-docker", "true") + systemProperty("exclude-multi-az", "true") + systemProperty("exclude-mysql-driver", "true") + systemProperty("exclude-mysql-engine", "true") + systemProperty("exclude-mariadb-driver", "true") + systemProperty("exclude-mariadb-engine", "true") + systemProperty("exclude-performance", "true") + systemProperty("exclude-secrets-manager", "true") } } diff --git a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java index 0126e0f2..3a16b5f0 100644 --- a/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java +++ b/tests/integration/host/src/test/java/integration/DatabaseEngineDeployment.java @@ -20,5 +20,6 @@ public enum DatabaseEngineDeployment { DOCKER, RDS, RDS_MULTI_AZ, - AURORA + AURORA, + DSQL } diff --git a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java index 6cf8514a..800f332a 100644 --- a/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java +++ b/tests/integration/host/src/test/java/integration/TestEnvironmentFeatures.java @@ -25,6 +25,7 @@ public enum TestEnvironmentFeatures { AWS_CREDENTIALS_ENABLED, PERFORMANCE, RUN_AUTOSCALING_TESTS_ONLY, + RUN_DSQL_TESTS_ONLY, SKIP_MYSQL_DRIVER_TESTS, SKIP_PG_DRIVER_TESTS, TELEMETRY_TRACES_ENABLED, diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java index 7cabd4b2..f0a96de5 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironment.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironment.java @@ -41,6 +41,8 @@ import java.sql.SQLException; import java.sql.Statement; import java.util.ArrayList; +import java.util.LinkedList; +import java.util.List; import java.util.concurrent.ExecutionException; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; @@ -127,6 +129,7 @@ public static TestEnvironment build(TestEnvironmentRequest request) throws IOExc break; case AURORA: case RDS_MULTI_AZ: + case DSQL: env = createAuroraOrMultiAzEnvironment(request); authorizeIP(env); @@ -200,7 +203,12 @@ private static TestEnvironment createAuroraOrMultiAzEnvironment(TestEnvironmentR } else { TestEnvironment env = new TestEnvironment(request); initDatabaseParams(env); - createDbCluster(env); + if (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL) { + createDsqlCluster(env); + } + else { + createDbCluster(env); + } if (request.getFeatures().contains(TestEnvironmentFeatures.IAM)) { if (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.RDS_MULTI_AZ) { @@ -334,8 +342,11 @@ private static void createDbCluster(TestEnvironment env, int numOfInstances) thr ArrayList instances = new ArrayList<>(); if (env.reuseAuroraDbCluster) { + if (StringUtils.isNullOrEmpty(env.auroraClusterName)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_NAME is required."); + } if (StringUtils.isNullOrEmpty(env.auroraClusterDomain)) { - throw new RuntimeException("Environment variable AURORA_CLUSTER_DOMAIN is required."); + throw new RuntimeException("Environment variable RDS_CLUSTER_DOMAIN is required."); } if (!env.auroraUtil.doesClusterExist(env.auroraClusterName)) { @@ -439,6 +450,88 @@ private static void createDbCluster(TestEnvironment env, int numOfInstances) thr } } + + + private static void createDsqlCluster(TestEnvironment env) throws URISyntaxException { + + initAwsCredentials(env); + + env.info.setRegion( + !StringUtils.isNullOrEmpty(config.rdsDbRegion) + ? config.rdsDbRegion + : "us-east-2"); + + env.reuseAuroraDbCluster = config.reuseRdsCluster; + env.auroraClusterName = config.rdsClusterName; // "cluster-mysql" + env.auroraClusterDomain = config.rdsClusterDomain; // "XYZ.us-west-2.rds.amazonaws.com" + env.rdsEndpoint = config.rdsEndpoint; // "https://rds-int.amazon.com" + env.info.setRdsEndpoint(env.rdsEndpoint); + + env.auroraUtil = + new AuroraTestUtility( + env.info.getRegion(), + env.rdsEndpoint, + env.awsAccessKeyId, + env.awsSecretAccessKey, + env.awsSessionToken); + + + final String endpoint; + if (env.reuseAuroraDbCluster) { + if (StringUtils.isNullOrEmpty(env.auroraClusterName)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_NAME is required."); + } + if (StringUtils.isNullOrEmpty(env.auroraClusterDomain)) { + throw new RuntimeException("Environment variable RDS_CLUSTER_DOMAIN is required."); + } + + endpoint = env.auroraClusterName + "." + env.auroraClusterDomain; + + final String identifier = env.auroraUtil.getDsqlInstanceId(endpoint); + if (!env.auroraUtil.doesDsqlClusterExist(identifier)) { + throw new RuntimeException( + String.format("It's requested to reuse existing DSQL cluster '%s' but it doesn't exist in region %s ", + endpoint, + env.info.getRegion())); + } + + LOGGER.finer( + "Reuse existing cluster " + endpoint); + + } else { + final String name = getRandomName(env.info.getRequest()); + try { + final String identifier = env.auroraUtil.createDsqlCluster(name); + env.auroraClusterName = identifier; + endpoint = String.format("%s.dsql.%s.on.aws", identifier, env.info.getRegion()); + } catch (Exception e) { + LOGGER.finer("Error creating a cluster " + name + ". " + e.getMessage()); + throw new RuntimeException(e); + } + } + + env.info.setClusterName(env.auroraClusterName); + + int port = getPort(env.info.getRequest()); + + env.info + .getDatabaseInfo() + .setClusterEndpoint(endpoint, port); + env.info + .getDatabaseInfo() + .setClusterReadOnlyEndpoint(endpoint, port); + + List instances = new LinkedList<>(); + instances.add(new TestInstanceInfo(env.auroraClusterName, endpoint, port)); + + env.info.getDatabaseInfo().getInstances().clear(); + env.info.getDatabaseInfo().getInstances().addAll(instances); + + authorizeIP(env); + + } + + private static void authorizeIP(TestEnvironment env) { try { env.runnerIP = env.auroraUtil.getPublicIPAddress(); @@ -578,14 +671,21 @@ private static int getPort(TestEnvironmentRequest request) { } private static void initDatabaseParams(TestEnvironment env) { - final String dbName = - !StringUtils.isNullOrEmpty(config.dbName) - ? config.dbName - : "test_database"; - final String dbUsername = - !StringUtils.isNullOrEmpty(config.dbUsername) - ? config.dbUsername - : "test_user"; + + final TestEnvironmentRequest request = env.info.getRequest(); + final boolean isDsql = (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL); + + final String dbName = isDsql + ? "postgres" + : !StringUtils.isNullOrEmpty(config.dbName) + ? config.dbName + : "test_database"; + final String dbUsername = isDsql + ? "admin" + : !StringUtils.isNullOrEmpty(config.dbUsername) + ? config.dbUsername + : "test_user"; + final String dbPassword = !StringUtils.isNullOrEmpty(config.dbPassword) ? config.dbPassword @@ -805,17 +905,24 @@ private static String getContainerBaseImageName(TestEnvironmentRequest request) private static void configureIamAccess(TestEnvironment env) { - if (env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.AURORA) { + if (env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.AURORA && + env.info.getRequest().getDatabaseEngineDeployment() != DatabaseEngineDeployment.DSQL) + { throw new UnsupportedOperationException( env.info.getRequest().getDatabaseEngineDeployment().toString()); } + final TestEnvironmentRequest request = env.info.getRequest(); + final boolean isDsql = (request.getDatabaseEngineDeployment() == DatabaseEngineDeployment.DSQL); + env.info.setIamUsername( - !StringUtils.isNullOrEmpty(config.iamUser) - ? config.iamUser - : "jane_doe"); + isDsql + ? "admin" + : !StringUtils.isNullOrEmpty(config.iamUser) + ? config.iamUser + : "jane_doe"); - if (!env.reuseAuroraDbCluster) { + if (!env.reuseAuroraDbCluster && !isDsql) { try { Class.forName(DriverHelper.getDriverClassname(env.info.getRequest().getDatabaseEngine())); } catch (ClassNotFoundException e) { @@ -918,6 +1025,7 @@ public void close() throws Exception { switch (this.info.getRequest().getDatabaseEngineDeployment()) { case AURORA: case RDS_MULTI_AZ: + case DSQL: deleteDbCluster(); break; case RDS: @@ -932,10 +1040,19 @@ private void deleteDbCluster() { auroraUtil.ec2DeauthorizesIP(runnerIP); } + final DatabaseEngineDeployment deployment = this.info.getRequest().getDatabaseEngineDeployment(); + + final String identifier; + if (deployment == DatabaseEngineDeployment.DSQL) { + identifier = this.auroraClusterName; + } else { + identifier = this.auroraClusterName + ".cluster-" + this.auroraClusterDomain; + } + if (!this.reuseAuroraDbCluster) { - LOGGER.finest("Deleting cluster " + this.auroraClusterName + ".cluster-" + this.auroraClusterDomain); + LOGGER.finest("Deleting cluster " + identifier); auroraUtil.deleteCluster(this.auroraClusterName); - LOGGER.finest("Deleted cluster " + this.auroraClusterName + ".cluster-" + this.auroraClusterDomain); + LOGGER.finest("Deleted cluster " + identifier); } } diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java index 6789df0b..f98c0d13 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentConfiguration.java @@ -44,6 +44,8 @@ public class TestEnvironmentConfiguration { Boolean.parseBoolean(System.getProperty("exclude-secrets-manager", "false")); public boolean testAutoscalingOnly = Boolean.parseBoolean(System.getProperty("test-autoscaling", "false")); + public boolean excludeDsql = + Boolean.parseBoolean(System.getProperty("exclude-dsql", "false")); public boolean excludeInstances1 = Boolean.parseBoolean(System.getProperty("exclude-instances-1", "false")); diff --git a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java index d5cd972d..86929324 100644 --- a/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java +++ b/tests/integration/host/src/test/java/integration/host/TestEnvironmentProvider.java @@ -68,6 +68,9 @@ public Stream provideTestTemplateInvocationContex if (deployment == DatabaseEngineDeployment.RDS_MULTI_AZ && config.excludeMultiAz) { continue; } + if (deployment == DatabaseEngineDeployment.DSQL && config.excludeDsql) { + continue; + } for (DatabaseEngine engine : DatabaseEngine.values()) { if (engine == DatabaseEngine.PG && config.excludePgEngine) { @@ -76,9 +79,12 @@ public Stream provideTestTemplateInvocationContex if (engine == DatabaseEngine.MYSQL && config.excludeMysqlEngine) { continue; } + if (engine != DatabaseEngine.PG && DatabaseEngineDeployment.DSQL == deployment) { + continue; + } for (DatabaseInstances instances : DatabaseInstances.values()) { - if (deployment == DatabaseEngineDeployment.DOCKER + if ((deployment == DatabaseEngineDeployment.DOCKER || deployment == DatabaseEngineDeployment.DSQL) && instances != DatabaseInstances.SINGLE_INSTANCE) { continue; } @@ -141,6 +147,9 @@ public Stream provideTestTemplateInvocationContex || config.excludeIam ? null : TestEnvironmentFeatures.IAM, + deployment == DatabaseEngineDeployment.DSQL + ? TestEnvironmentFeatures.RUN_DSQL_TESTS_ONLY + : null, config.excludeSecretsManager ? null : TestEnvironmentFeatures.SECRETS_MANAGER, config.excludePerformance ? null : TestEnvironmentFeatures.PERFORMANCE, config.excludeMysqlDriver ? TestEnvironmentFeatures.SKIP_MYSQL_DRIVER_TESTS : null, diff --git a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java index 919737b8..a58c3a6d 100644 --- a/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java +++ b/tests/integration/host/src/test/java/integration/util/AuroraTestUtility.java @@ -34,13 +34,19 @@ import java.time.Duration; import java.util.ArrayList; import java.util.Comparator; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Random; import java.util.concurrent.TimeUnit; import java.util.logging.Logger; +import java.util.regex.Matcher; +import java.util.regex.Pattern; import java.util.stream.Collectors; +import org.apache.logging.log4j.CloseableThreadContext.Instance; + import software.amazon.awssdk.auth.credentials.AwsBasicCredentials; import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider; import software.amazon.awssdk.auth.credentials.AwsSessionCredentials; @@ -48,6 +54,12 @@ import software.amazon.awssdk.auth.credentials.StaticCredentialsProvider; import software.amazon.awssdk.core.waiters.WaiterResponse; import software.amazon.awssdk.regions.Region; +import software.amazon.awssdk.retries.api.BackoffStrategy; +import software.amazon.awssdk.services.dsql.DsqlClient; +import software.amazon.awssdk.services.dsql.model.CreateClusterRequest; +import software.amazon.awssdk.services.dsql.model.CreateClusterResponse; +import software.amazon.awssdk.services.dsql.model.GetClusterResponse; +import software.amazon.awssdk.services.dsql.model.ResourceNotFoundException; import software.amazon.awssdk.services.ec2.Ec2Client; import software.amazon.awssdk.services.ec2.model.DescribeSecurityGroupsResponse; import software.amazon.awssdk.services.ec2.model.Ec2Exception; @@ -98,10 +110,19 @@ public class AuroraTestUtility { private final RdsClient rdsClient; private final Ec2Client ec2Client; + private final DsqlClient dsqlClient; private static final Random rand = new Random(); private static final String DUPLICATE_IP_ERROR_CODE = "InvalidPermission.Duplicate"; + private static final Pattern AURORA_DSQL_CLUSTER_PATTERN = + Pattern.compile( + "^(?[^.]+)\\." + + "(?dsql(?:-[^.]+)?)\\." + + "(?(?[a-zA-Z0-9\\-]+)" + + "\\.on\\.aws\\.?)$", + Pattern.CASE_INSENSITIVE); + public AuroraTestUtility( String region, String rdsEndpoint, String awsAccessKeyId, String awsSecretAccessKey, String awsSessionToken) throws URISyntaxException { @@ -139,7 +160,11 @@ public AuroraTestUtility(Region region, String rdsEndpoint, AwsCredentialsProvid .region(dbRegion) .credentialsProvider(credentialsProvider) .build(); - } + dsqlClient = DsqlClient.builder() + .region(dbRegion) + .credentialsProvider(credentialsProvider) + .build(); + } protected static Region getRegionInternal(String rdsRegion) { Optional regionOptional = @@ -339,6 +364,41 @@ public String createMultiAzCluster() throws InterruptedException { return clusterDomainPrefix; } + /** + * Create a DSQL cluster. + * + * @param name A human-readable name to tag the cluster with. + * @return The unique identifier of the created cluster. + */ + public String createDsqlCluster(final String name) throws InterruptedException { + final Map tagMap = new HashMap<>(); + tagMap.put("Name", name); + + final CreateClusterRequest request = CreateClusterRequest.builder() + .deletionProtectionEnabled(false) + .tags(tagMap) + .build(); + final CreateClusterResponse cluster = dsqlClient.createCluster(request); + + this.dbEngineDeployment = DatabaseEngineDeployment.DSQL; + this.dbIdentifier = cluster.identifier(); + + final WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterActive( + getCluster -> getCluster.identifier(cluster.identifier()), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent()) { + deleteCluster(); + throw new InterruptedException( + "Unable to create DSQL cluster after waiting for 30 minutes"); + } + + return cluster.identifier(); + } + /** * Gets public IP. * @@ -435,14 +495,21 @@ public void deleteCluster(String identifier) { * Destroys all instances and clusters. Removes IP from EC2 whitelist. */ public void deleteCluster() { + final DatabaseEngineDeployment deployment = this.dbEngineDeployment; + if (deployment == null) { + throw new UnsupportedOperationException("DB engine deployment must be non-null"); + } - switch (this.dbEngineDeployment) { + switch (deployment) { case AURORA: this.deleteAuroraCluster(); break; case RDS_MULTI_AZ: this.deleteMultiAzCluster(); break; + case DSQL: + this.deleteDsqlCluster(); + break; default: throw new UnsupportedOperationException(this.dbEngineDeployment.toString()); } @@ -509,6 +576,23 @@ public void deleteMultiAzCluster() { } } + public void deleteDsqlCluster() { + dsqlClient.deleteCluster(r -> r.identifier(dbIdentifier)); + + WaiterResponse waiterResponse = dsqlClient.waiter().waitUntilClusterNotExists( + getCluster -> getCluster.identifier(dbIdentifier), + config -> config.backoffStrategyV2( + BackoffStrategy.fixedDelayWithoutJitter(Duration.ofSeconds(10)) + ).waitTimeout(Duration.ofMinutes(30)) + ); + + if (waiterResponse.matched().exception().isPresent() + && !(waiterResponse.matched().exception().get() instanceof ResourceNotFoundException)) { + throw new RuntimeException( + "Unable to delete DSQL cluster after waiting for 30 minutes"); + } + } + public boolean doesClusterExist(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build(); @@ -520,6 +604,28 @@ public boolean doesClusterExist(final String clusterId) { return true; } + public boolean doesDsqlClusterExist(final String identifier) { + try { + final GetClusterResponse response = dsqlClient.getCluster(r -> r.identifier(identifier)); + return response.sdkHttpResponse().isSuccessful(); + } catch (ResourceNotFoundException ex) { + return false; + } + } + + public String getDsqlInstanceId(final String host) { + + if (StringUtils.isNullOrEmpty(host)) { + return null; + } + + final Matcher matcher = AURORA_DSQL_CLUSTER_PATTERN.matcher(host); + if (!matcher.matches()) { + return null; + } + return matcher.group("instance"); + } + public DBCluster getClusterInfo(final String clusterId) { final DescribeDbClustersRequest request = DescribeDbClustersRequest.builder().dbClusterIdentifier(clusterId).build(); diff --git a/tests/unit/test_federated_auth_plugin.py b/tests/unit/test_federated_auth_plugin.py index 1c3a77e3..3f982a70 100644 --- a/tests/unit/test_federated_auth_plugin.py +++ b/tests/unit/test_federated_auth_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.iam_plugin import TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -101,6 +102,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -129,7 +131,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -154,7 +159,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "federated_auth", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -183,7 +191,9 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -229,7 +239,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: FederatedAuthPlugin = FederatedAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_iam_dsql_plugin.py b/tests/unit/test_iam_dsql_plugin.py new file mode 100644 index 00000000..edd81d48 --- /dev/null +++ b/tests/unit/test_iam_dsql_plugin.py @@ -0,0 +1,444 @@ +# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import annotations + +import urllib.request +from datetime import datetime, timedelta +from typing import Dict +from unittest.mock import patch + +import pytest + +from aws_advanced_python_wrapper.hostinfo import HostInfo +from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo +from aws_advanced_python_wrapper.utils.dsql_token_utils import DSQLTokenUtils +from aws_advanced_python_wrapper.utils.iam_utils import IamAuthUtils +from aws_advanced_python_wrapper.utils.properties import (Properties, + WrapperProperties) + +_GENERATED_TOKEN = "generated_token admin" +_GENERATED_TOKEN_NON_ADMIN = "generated_token non-admin" +_TEST_TOKEN = "test_token" +_DEFAULT_PG_PORT = 5432 + +_PG_HOST_INFO = HostInfo("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws") +_PG_HOST_INFO_WITH_PORT = HostInfo(_PG_HOST_INFO.url, port=1234) +_PG_REGION = "us-east-2" + +_PG_CACHE_KEY = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{_DEFAULT_PG_PORT}:admin" + + +_token_cache: Dict[str, TokenInfo] = {} + + +@pytest.fixture(autouse=True) +def clear_caches(): + _token_cache.clear() + + +@pytest.fixture +def mock_session(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_client(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_connection(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_func(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_plugin_service(mocker): + return mocker.MagicMock() + + +@pytest.fixture +def mock_dialect(mocker): + return mocker.MagicMock() + + +@pytest.fixture(autouse=True) +def mock_default_behavior(mock_session, mock_client, mock_func, mock_connection, mock_plugin_service, mock_dialect): + mock_session.client.return_value = mock_client + mock_client.generate_db_connect_admin_auth_token.return_value = _GENERATED_TOKEN + mock_client.generate_db_connect_auth_token.return_value = _GENERATED_TOKEN_NON_ADMIN + mock_session.get_available_regions.return_value = ['us-east-1', 'us-east-2', 'us-west-1', 'us-west-2'] + mock_func.return_value = mock_connection + mock_plugin_service.driver_dialect = mock_dialect + mock_plugin_service.database_dialect = mock_dialect + mock_dialect.default_port = _DEFAULT_PG_PORT + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def set_token_cache(user, host, port, region, expired=False): + if not expired: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) + else: + initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) + cache_key: str = IamAuthUtils.get_cache_key( + user, + host, + port, + region + ) + _token_cache[cache_key] = initial_token + + return cache_key, initial_token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_valid_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, _ = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + assert _GENERATED_TOKEN != actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_not_called() + assert _GENERATED_TOKEN_NON_ADMIN != actual_token.token + + assert _TEST_TOKEN == actual_token.token + assert not actual_token.is_expired() + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin") + assert _GENERATED_TOKEN == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( + mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + expected_default_pg_port = 5432 + invalid_port = "0" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = invalid_port + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + + actual_token = _token_cache.get( + f"{_PG_REGION}:{_PG_HOST_INFO.url}:{expected_default_pg_port}:admin") + assert _GENERATED_TOKEN == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "0"} + mock_dialect.set_password.assert_called_with(expected_props, _GENERATED_TOKEN) + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_expired_token_in_cache(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + cache_key, initial_token = set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION, True) + + mock_func.side_effect = Exception("generic exception") + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + with pytest.raises(Exception): + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + actual_token = _token_cache.get(cache_key) + assert initial_token != actual_token + assert not actual_token.is_expired() + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_empty_cache(user, mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + actual_connection = target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + cache_key: str = IamAuthUtils.get_cache_key( + user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION + ) + actual_token = _token_cache.get(cache_key) + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION + ) + assert _GENERATED_TOKEN == actual_token.token + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, _PG_REGION) + assert _GENERATED_TOKEN_NON_ADMIN == actual_token.token + + assert mock_connection == actual_connection + assert not actual_token.is_expired() + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + cache_key_with_new_port: str = f"{_PG_REGION}:{_PG_HOST_INFO.url}:1234:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:1234", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:1234" == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:1234") + + +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + iam_default_port: str = "9999" + test_props[WrapperProperties.IAM_DEFAULT_PORT.name] = iam_default_port + cache_key_with_new_port = f"{_PG_REGION}:{_PG_HOST_INFO.url}:{iam_default_port}:admin" + initial_token = TokenInfo(f"{_TEST_TOKEN}:{iam_default_port}", datetime.now() + timedelta(minutes=5)) + _token_cache[cache_key_with_new_port] = initial_token + + # Assert no password has been set + assert test_props.get("password") is None + + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=_PG_HOST_INFO_WITH_PORT, + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_not_called() + + actual_token = _token_cache.get(cache_key_with_new_port) + assert _token_cache.get(_PG_CACHE_KEY) is None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_default_port}" == actual_token.token + assert not actual_token.is_expired() + + # Assert password has been updated to the value in token cache + expected_props = {"user": "admin", "iam_default_port": "9999"} + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_default_port}") + + +@pytest.mark.parametrize("user", [ + pytest.param("admin"), + pytest.param("non-admin"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_region(user, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": user}) + iam_region: str = "us-east-1" + + # Cache a token with a different region + set_token_cache(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, _PG_REGION) + test_props[WrapperProperties.IAM_REGION.name] = iam_region + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" + mock_client.generate_db_connect_auth_token.return_value = f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo(_PG_HOST_INFO.url), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_session.client.assert_called_with( + "dsql", + region_name=iam_region + ) + + expected_props = {"iam_region": "us-east-1", "user": user} + actual_token = _token_cache.get(IamAuthUtils.get_cache_key(user, _PG_HOST_INFO.url, _DEFAULT_PG_PORT, iam_region)) + assert not actual_token.is_expired() + + if user == "admin": + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region + ) + assert f"{_TEST_TOKEN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_TEST_TOKEN}:{iam_region}") + else: + mock_client.generate_db_connect_auth_token.assert_called_with( + _PG_HOST_INFO.url, iam_region) + assert f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}" == actual_token.token + mock_dialect.set_password.assert_called_with(expected_props, f"{_GENERATED_TOKEN_NON_ADMIN}:{iam_region}") + + +@pytest.mark.parametrize("iam_host", [ + pytest.param("dsqltestclusternamefoobar1.dsql.us-east-2.on.aws"), + pytest.param("dsqltestclusternamefoobar2.dsql.us-east-2.on.aws"), +]) +@patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) +def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): + test_props: Properties = Properties({"user": "admin"}) + + test_props[WrapperProperties.IAM_HOST.name] = iam_host + + # Assert no password has been set + assert test_props.get("password") is None + + mock_client.generate_db_connect_admin_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + DSQLTokenUtils(), + mock_session) + target_plugin.connect( + target_driver_func=mocker.MagicMock(), + driver_dialect=mock_dialect, + host_info=HostInfo("bar.foo.com"), + props=test_props, + is_initial_connection=False, + connect_func=mock_func) + + mock_client.generate_db_connect_admin_auth_token.assert_called_with( + iam_host, _PG_REGION + ) + + actual_token = _token_cache.get(f"{_PG_REGION}:{iam_host}:5432:admin") + assert actual_token is not None + assert _GENERATED_TOKEN != actual_token.token + assert f"{_TEST_TOKEN}:{iam_host}" == actual_token.token + assert not actual_token.is_expired() + + +def test_aws_supported_regions_url_exists(): + url = "https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html" + assert 200 == urllib.request.urlopen(url).getcode() diff --git a/tests/unit/test_iam_plugin.py b/tests/unit/test_iam_plugin.py index 04273698..10a3a3cd 100644 --- a/tests/unit/test_iam_plugin.py +++ b/tests/unit/test_iam_plugin.py @@ -26,6 +26,7 @@ from aws_advanced_python_wrapper.iam_plugin import IamAuthPlugin, TokenInfo from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -99,6 +100,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi _token_cache[_PG_CACHE_KEY] = initial_token target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -127,6 +129,7 @@ def test_pg_connect_with_invalid_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -163,6 +166,7 @@ def test_pg_connect_with_invalid_port_and_no_host_port_fall_backs_to_host_port( assert test_props.get("password") is None target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -195,7 +199,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio _token_cache[_PG_CACHE_KEY] = initial_token mock_func.side_effect = Exception("generic exception") - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception): target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -220,7 +226,9 @@ def test_connect_expired_token_in_cache(mocker, mock_plugin_service, mock_sessio @patch("aws_advanced_python_wrapper.iam_plugin.IamAuthPlugin._token_cache", _token_cache) def test_connect_empty_cache(mocker, mock_plugin_service, mock_connection, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) actual_connection = target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -251,7 +259,9 @@ def test_connect_with_specified_port(mocker, mock_plugin_service, mock_session, # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -285,7 +295,9 @@ def test_connect_with_specified_iam_default_port(mocker, mock_plugin_service, mo # Assert no password has been set assert test_props.get("password") is None - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -323,7 +335,9 @@ def test_connect_with_specified_region(mocker, mock_plugin_service, mock_session assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_region}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -369,7 +383,9 @@ def test_connect_with_specified_host(iam_host: str, mocker, mock_plugin_service, assert test_props.get("password") is None mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{iam_host}" - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, @@ -411,7 +427,7 @@ def test_aws_supported_regions_url_exists(): def test_invalid_iam_host(host, mocker, mock_plugin_service, mock_session, mock_func, mock_client, mock_dialect): test_props: Properties = Properties({"user": "postgresqlUser"}) with pytest.raises(AwsWrapperError): - target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, mock_session) + target_plugin: IamAuthPlugin = IamAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect, diff --git a/tests/unit/test_okta_plugin.py b/tests/unit/test_okta_plugin.py index 72f9727a..e2823568 100644 --- a/tests/unit/test_okta_plugin.py +++ b/tests/unit/test_okta_plugin.py @@ -25,6 +25,7 @@ from aws_advanced_python_wrapper.okta_plugin import OktaAuthPlugin from aws_advanced_python_wrapper.utils.properties import (Properties, WrapperProperties) +from aws_advanced_python_wrapper.utils.rds_token_utils import RDSTokenUtils _GENERATED_TOKEN = "generated_token" _TEST_TOKEN = "test_token" @@ -100,7 +101,7 @@ def test_pg_connect_valid_token_in_cache(mocker, mock_plugin_service, mock_sessi initial_token = TokenInfo(_TEST_TOKEN, datetime.now() + timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, RDSTokenUtils(), mock_session) key = "us-east-2:pg.testdb.us-east-2.rds.amazonaws.com:" + str(_DEFAULT_PG_PORT) + ":postgesqlUser" _token_cache[key] = initial_token @@ -127,7 +128,10 @@ def test_expired_cached_token(mocker, mock_plugin_service, mock_session, mock_fu initial_token = TokenInfo(_TEST_TOKEN, datetime.now() - timedelta(minutes=5)) _token_cache[_PG_CACHE_KEY] = initial_token - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -151,7 +155,10 @@ def test_no_cached_token(mocker, mock_plugin_service, mock_session, mock_func, m test_props: Properties = Properties({"plugins": "okta", "user": "postgresqlUser", "idp_username": "user", "idp_password": "password"}) WrapperProperties.DB_USER.set(test_props, _DB_USER) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), @@ -179,7 +186,10 @@ def test_no_cached_token_raises_exception(mocker, mock_plugin_service, mock_sess exception_message = "generic exception" mock_func.side_effect = Exception(exception_message) - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) with pytest.raises(Exception) as e_info: target_plugin.connect( @@ -225,7 +235,10 @@ def test_connect_with_specified_iam_host_port_region(mocker, mock_client.generate_db_auth_token.return_value = f"{_TEST_TOKEN}:{expected_region}" - target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, mock_credentials_provider_factory, mock_session) + target_plugin: OktaAuthPlugin = OktaAuthPlugin(mock_plugin_service, + mock_credentials_provider_factory, + RDSTokenUtils(), + mock_session) target_plugin.connect( target_driver_func=mocker.MagicMock(), driver_dialect=mock_dialect,