Skip to content

Commit

Permalink
Sanitizing account names before using them in SFn Invocation
Browse files Browse the repository at this point in the history
  • Loading branch information
StewartW committed Jan 31, 2023
1 parent 349601e commit 889f030
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 32 deletions.
86 changes: 59 additions & 27 deletions src/lambda_codebase/account_processing/process_account_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ class AccountFileData(TypedDict):
Class used to return YAML account file data and its related
metadata like the execution_id of the CodePipeline that uploaded it.
"""

content: Any
execution_id: str

Expand All @@ -65,8 +66,8 @@ def get_file_from_s3(
try:
LOGGER.debug(
"Reading YAML from S3: %s from %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
)
s3_object = s3_resource.Object(**s3_object_location)
object_adf_version = s3_object.metadata.get(
Expand All @@ -80,12 +81,9 @@ def get_file_from_s3(
s3_object_location,
object_adf_version,
)
return {
"content": {},
"execution_id": ""
}
return {"content": {}, "execution_id": ""}

with tempfile.TemporaryFile(mode='w+b') as file_pointer:
with tempfile.TemporaryFile(mode="w+b") as file_pointer:
s3_object.download_fileobj(file_pointer)

# Move pointer to the start of the file
Expand All @@ -98,16 +96,16 @@ def get_file_from_s3(
except ClientError as error:
LOGGER.error(
"Failed to download %s from %s, due to %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
error,
)
raise
except YAMLError as yaml_error:
LOGGER.error(
"Failed to parse YAML file: %s from %s, due to %s",
s3_object_location.get('object_key'),
s3_object_location.get('bucket_name'),
s3_object_location.get("object_key"),
s3_object_location.get("bucket_name"),
yaml_error,
)
raise
Expand All @@ -129,19 +127,52 @@ def process_account(account_lookup, account):


def process_account_list(all_accounts, accounts_in_file):
    """
    Process every account definition found in the account file.

    Builds a lookup of account name to account id from all_accounts and
    passes it, together with each entry of accounts_in_file, to
    process_account.

    Args:
        all_accounts: Iterable of mappings that each have a "Name" and an
            "Id" key.
        accounts_in_file: Iterable of account definitions to process.

    Returns:
        list: The result of process_account for each entry of
            accounts_in_file, in the same order.
    """
    # Build the lookup exactly once; every account is then resolved with a
    # single dictionary access inside process_account.
    account_lookup = {account["Name"]: account["Id"] for account in all_accounts}

    # Walk accounts_in_file a single time so one-shot iterables (such as
    # generators) are processed correctly and no account is handled twice.
    return [
        process_account(account_lookup=account_lookup, account=account)
        for account in accounts_in_file
    ]


def sanitize_account_name_for_snf(account_name):
    """
    Return a version of the account name that is safe to use as the
    prefix of an AWS Step Functions execution name.

    The name is first truncated to 30 characters. After that, every
    character that is not allowed in an execution name is removed:
    whitespace, brackets, wildcards, the special characters listed below
    and control characters (U+0000-001F and U+007F-009F). Because the
    truncation happens before the removal, the result can be shorter
    than 30 characters, but never longer.

    Args:
        account_name (str): The full account name to sanitize.

    Returns:
        str: The sanitized name, at most 30 characters long. It is empty
            when none of the first 30 characters are supported.
    """
    # Brackets, wildcards and special characters. Whitespace and control
    # characters are matched by category in is_supported instead.
    unsupported_characters = frozenset('<>{}[]?*"#%\\^|~`$&,;:/')

    def is_supported(char):
        code_point = ord(char)
        is_control = code_point <= 0x1F or 0x7F <= code_point <= 0x9F
        return not (
            is_control
            or char.isspace()
            or char in unsupported_characters
        )

    # Truncate first, then filter, so the length budget is always honoured.
    return "".join(filter(is_supported, account_name[:30]))


def start_executions(
sfn_client,
processed_account_list,
Expand All @@ -158,14 +189,14 @@ def start_executions(
run_id,
)
for account in processed_account_list:
full_account_name = account.get('account_full_name', 'no-account-name')
full_account_name = account.get("account_full_name", "no-account-name")
# AWS Step Functions execution names support at most 80 characters.
# Since the run_id takes 49 characters plus the dash, we have 30
# characters available. To ensure we don't run over, let's use a
# truncated version instead:
truncated_account_name = full_account_name[:30]
sfn_execution_name = f"{truncated_account_name}-{run_id}"

sfn_execution_name = (
f"{sanitize_account_name_for_snf(full_account_name)}-{run_id}"
)
LOGGER.debug(
"Payload for %s: %s",
sfn_execution_name,
Expand All @@ -182,8 +213,9 @@ def lambda_handler(event, context):
"""Main Lambda Entry point"""
LOGGER.debug(
"Processing event: %s",
json.dumps(event, indent=2) if LOGGER.isEnabledFor(logging.DEBUG)
else "--data-hidden--"
json.dumps(event, indent=2)
if LOGGER.isEnabledFor(logging.DEBUG)
else "--data-hidden--",
)
sfn_client = boto3.client("stepfunctions")
s3_resource = boto3.resource("s3")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@
Tests the account file processing lambda
"""
import unittest
from ..process_account_files import process_account, process_account_list, get_details_from_event
from ..process_account_files import (
process_account,
process_account_list,
get_details_from_event,
sanitize_account_name_for_snf,
)


class SuccessTestCase(unittest.TestCase):
Expand All @@ -20,7 +25,7 @@ def test_process_account_when_account_exists(self):
"account_full_name": "myTestAccountName",
"account_id": 123456789012,
"needs_created": False,
}
},
)

def test_process_account_when_account_does_not_exist(self):
Expand All @@ -35,7 +40,7 @@ def test_process_account_when_account_does_not_exist(self):
"alias": "MyCoolAlias",
"account_full_name": "myTestAccountName",
"needs_created": True,
}
},
)

def test_process_account_list(self):
Expand All @@ -59,6 +64,43 @@ def test_process_account_list(self):
],
)

def test_get_sanitize_account_name(self):
    """
    Verify that account names are truncated to 30 characters and that
    illegal characters are removed from the truncated name.
    """
    # Pairs of (raw account name, expected sanitized name).
    expected_by_raw_name = (
        ("myTestAccountName", "myTestAccountName"),
        (
            "thisIsALongerAccountNameForTestingTruncatedNames",
            "thisIsALongerAccountNameForTes",
        ),
        (
            "thisIsALongerAccountName ForTestingTruncatedNames",
            "thisIsALongerAccountNameForTe",
        ),
        ("this accountname <has illegal> chars", "thisaccountnamehasillegal"),
        ("this accountname \\has illegal chars", "thisaccountnamehasillegal"),
        ("^startswithanillegalchar", "startswithanillegalchar"),
    )
    for raw_name, expected_name in expected_by_raw_name:
        self.assertEqual(sanitize_account_name_for_snf(raw_name), expected_name)

    # A name without illegal characters uses the full 30 character budget.
    overlong_name = "ReallyLongAccountNameThatShouldBeTruncatedBecauseItsTooLong"
    self.assertEqual(len(sanitize_account_name_for_snf(overlong_name)), 30)


class FailureTestCase(unittest.TestCase):
# pylint: disable=W0106
Expand All @@ -67,6 +109,5 @@ def test_event_parsing(self):
with self.assertRaises(ValueError) as _error:
get_details_from_event(sample_event)
self.assertEqual(
str(_error.exception),
"No S3 Event details present in event trigger"
str(_error.exception), "No S3 Event details present in event trigger"
)

0 comments on commit 889f030

Please sign in to comment.