Skip to content

Commit

Permalink
RDS: Improve handling of boolean attributes
Browse files Browse the repository at this point in the history
* Use the `_get_bool_param` helper
* Explicitly guard against None values in the model
* Standardize XML output

Note:

* `_get_bool_param` allows for a default value to be set, but handling it in
  the model ensures that we do it in one place, regardless of where the input
  comes from (e.g. our input may not always come via the RDS response class).
* Using the Jinja2 `lower` filter won't work if we have a None value.  We guard
  against that now, but it's still better to be explicit here and only allow the
  strings 'true' or 'false'.
  • Loading branch information
bpandola authored and bblommers committed Jul 9, 2024

Verified

This commit was created on GitHub.com and signed with GitHub’s verified signature. The key has expired.
1 parent 8165c18 commit 1c8c0fe
Showing 2 changed files with 25 additions and 10 deletions.
27 changes: 21 additions & 6 deletions moto/rds/models.py
Original file line number Diff line number Diff line change
@@ -88,9 +88,13 @@ def __init__(
self.storage_encrypted = (
storage_encrypted and storage_encrypted.lower() == "true"
)
if self.storage_encrypted is None:
self.storage_encrypted = False
self.deletion_protection = (
deletion_protection and deletion_protection.lower() == "true"
)
if self.deletion_protection is None:
self.deletion_protection = False
self.members: List[Cluster] = []

def to_xml(self) -> str:
@@ -138,6 +142,8 @@ def __init__(self, **kwargs: Any):
self.db_cluster_identifier = kwargs.get("db_cluster_identifier")
self.db_cluster_instance_class = kwargs.get("db_cluster_instance_class")
self.deletion_protection = kwargs.get("deletion_protection")
if self.deletion_protection is None:
self.deletion_protection = False
self.engine = kwargs.get("engine")
if self.engine not in ClusterEngine.list_cluster_engines():
raise InvalidParameterValue(
@@ -237,6 +243,8 @@ def __init__(self, **kwargs: Any):
self.read_replica_identifiers: List[str] = list()
self.is_writer: bool = False
self.storage_encrypted = kwargs.get("storage_encrypted", False)
if self.storage_encrypted is None:
self.storage_encrypted = False
if self.storage_encrypted:
self.kms_key_id = kwargs.get("kms_key_id", "default_kms_key_id")
else:
@@ -264,13 +272,14 @@ def __init__(self, **kwargs: Any):
else:
self.backtrack_window = 0

self.iam_auth: bool = False
if auth := kwargs.get("enable_iam_database_authentication", False):
self.iam_auth = kwargs.get("enable_iam_database_authentication", False)
if self.iam_auth is None:
self.iam_auth = False
if self.iam_auth:
if not self.engine.startswith("aurora-"):
raise InvalidParameterCombination(
"IAM Authentication is currently not supported by Multi-AZ DB clusters."
)
self.iam_auth = auth

@property
def is_multi_az(self) -> bool:
@@ -409,7 +418,7 @@ def to_xml(self, initial: bool = False) -> str:
<DbClusterResourceId>{{ cluster.resource_id }}</DbClusterResourceId>
<DBClusterArn>{{ cluster.db_cluster_arn }}</DBClusterArn>
<AssociatedRoles></AssociatedRoles>
<IAMDatabaseAuthenticationEnabled>{{ cluster.iam_auth | string | lower }}</IAMDatabaseAuthenticationEnabled>
<IAMDatabaseAuthenticationEnabled>{{ 'true' if cluster.iam_auth else 'false' }}</IAMDatabaseAuthenticationEnabled>
<EngineMode>{{ cluster.engine_mode }}</EngineMode>
<DeletionProtection>{{ 'true' if cluster.deletion_protection else 'false' }}</DeletionProtection>
<HttpEndpointEnabled>{{ 'true' if cluster.enable_http_endpoint else 'false' }}</HttpEndpointEnabled>
@@ -667,6 +676,8 @@ def __init__(self, **kwargs: Any):
if not self.availability_zone:
self.availability_zone = f"{self.region_name}a"
self.multi_az = kwargs.get("multi_az")
if self.multi_az is None:
self.multi_az = False
self.db_subnet_group_name = kwargs.get("db_subnet_group_name")
self.db_subnet_group = None
if self.db_subnet_group_name:
@@ -713,9 +724,13 @@ def __init__(self, **kwargs: Any):
self.enable_iam_database_authentication = kwargs.get(
"enable_iam_database_authentication", False
)
if self.enable_iam_database_authentication is None:
self.enable_iam_database_authentication = False
self.dbi_resource_id = "db-M5ENSHXFPU6XHZ4G4ZEI5QIO2U"
self.tags = kwargs.get("tags", [])
self.deletion_protection = kwargs.get("deletion_protection", False)
if self.deletion_protection is None:
self.deletion_protection = False
self.enabled_cloudwatch_logs_exports = (
kwargs.get("enable_cloudwatch_logs_exports") or []
)
@@ -865,7 +880,7 @@ def to_xml(self) -> str:
<CopyTagsToSnapshot>{{ database.copy_tags_to_snapshot }}</CopyTagsToSnapshot>
<AutoMinorVersionUpgrade>{{ database.auto_minor_version_upgrade }}</AutoMinorVersionUpgrade>
<AllocatedStorage>{{ database.allocated_storage }}</AllocatedStorage>
<StorageEncrypted>{{ database.storage_encrypted }}</StorageEncrypted>
<StorageEncrypted>{{ 'true' if database.storage_encrypted else 'false' }}</StorageEncrypted>
{% if database.kms_key_id %}
<KmsKeyId>{{ database.kms_key_id }}</KmsKeyId>
{% endif %}
@@ -1222,7 +1237,7 @@ def to_xml(self) -> str:
<DBSnapshotArn>{{ snapshot.snapshot_arn }}</DBSnapshotArn>
<Timezone></Timezone>
{% if database.enable_iam_database_authentication %}
<IAMDatabaseAuthenticationEnabled>{{ database.enable_iam_database_authentication|lower }}</IAMDatabaseAuthenticationEnabled>
<IAMDatabaseAuthenticationEnabled>{{ 'true' if database.enable_iam_database_authentication else 'false' }}</IAMDatabaseAuthenticationEnabled>
{% endif %}
</DBSnapshot>"""
)
8 changes: 4 additions & 4 deletions moto/rds/responses.py
Original file line number Diff line number Diff line change
@@ -80,7 +80,7 @@ def _get_db_kwargs(self) -> Dict[str, Any]:
"security_groups": self._get_multi_param(
"DBSecurityGroups.DBSecurityGroupName"
),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_encrypted": self._get_bool_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", None),
"vpc_security_group_ids": self._get_multi_param(
"VpcSecurityGroupIds.VpcSecurityGroupId"
@@ -137,7 +137,7 @@ def _get_modify_db_cluster_kwargs(self) -> Dict[str, Any]:
"security_groups": self._get_multi_param(
"DBSecurityGroups.DBSecurityGroupName"
),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_encrypted": self._get_bool_param("StorageEncrypted"),
"storage_type": self._get_param("StorageType", None),
"vpc_security_group_ids": self._get_multi_param(
"VpcSecurityGroupIds.VpcSecurityGroupId"
@@ -200,7 +200,7 @@ def _get_db_cluster_kwargs(self) -> Dict[str, Any]:
"allocated_storage": self._get_param("AllocatedStorage"),
"global_cluster_identifier": self._get_param("GlobalClusterIdentifier"),
"iops": self._get_param("Iops"),
"storage_encrypted": self._get_param("StorageEncrypted"),
"storage_encrypted": self._get_bool_param("StorageEncrypted"),
"enable_global_write_forwarding": self._get_param(
"EnableGlobalWriteForwarding"
),
@@ -214,7 +214,7 @@ def _get_db_cluster_kwargs(self) -> Dict[str, Any]:
"region": self.region,
"db_cluster_instance_class": self._get_param("DBClusterInstanceClass"),
"enable_http_endpoint": self._get_bool_param("EnableHttpEndpoint"),
"copy_tags_to_snapshot": self._get_param("CopyTagsToSnapshot"),
"copy_tags_to_snapshot": self._get_bool_param("CopyTagsToSnapshot"),
"tags": self.unpack_list_params("Tags", "Tag"),
"scaling_configuration": self._get_dict_param("ScalingConfiguration."),
"serverless_v2_scaling_configuration": params.get(

0 comments on commit 1c8c0fe

Please sign in to comment.