From 50e26da6d0ee6b7521cfed37d1fa457f361ca07c Mon Sep 17 00:00:00 2001 From: Bikouo Aubin <79859644+abikouo@users.noreply.github.com> Date: Tue, 15 Nov 2022 21:18:44 +0100 Subject: [PATCH] Refactor module_utils/cloudfront_facts and add unit tests (#1265) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Refactor module_utils/cloudfront_facts and add unit tests Depends-On: ansible/ansible-zuul-jobs#1695 SUMMARY refactor module_utils/cloudfront_facts.py and unit tests ISSUE TYPE Feature Pull Request Reviewed-by: Gonéri Le Bouder Reviewed-by: Bikouo Aubin --- ...dule_utils_cloudfront_facts_unit_tests.yml | 3 + plugins/module_utils/cloudfront_facts.py | 249 +++++---- .../module_utils/test_cloudfront_facts.py | 522 ++++++++++++++++++ 3 files changed, 657 insertions(+), 117 deletions(-) create mode 100644 changelogs/fragments/module_utils_cloudfront_facts_unit_tests.yml create mode 100644 tests/unit/module_utils/test_cloudfront_facts.py diff --git a/changelogs/fragments/module_utils_cloudfront_facts_unit_tests.yml b/changelogs/fragments/module_utils_cloudfront_facts_unit_tests.yml new file mode 100644 index 00000000000..a252fa106fa --- /dev/null +++ b/changelogs/fragments/module_utils_cloudfront_facts_unit_tests.yml @@ -0,0 +1,3 @@ +--- +minor_changes: +- Refactor module_utils/cloudfront_facts.py and add unit tests (https://github.com/ansible-collections/amazon.aws/pull/1265). diff --git a/plugins/module_utils/cloudfront_facts.py b/plugins/module_utils/cloudfront_facts.py index 36e3603f3a7..103f3ae9718 100644 --- a/plugins/module_utils/cloudfront_facts.py +++ b/plugins/module_utils/cloudfront_facts.py @@ -26,6 +26,7 @@ Common cloudfront facts shared between modules """ +from functools import partial try: import botocore except ImportError: @@ -33,102 +34,134 @@ from .ec2 import AWSRetry from .ec2 import boto3_tag_list_to_ansible_dict +from .ec2 import snake_dict_to_camel_dict -class CloudFrontFactsServiceManager: - """Handles CloudFront Facts Services""" +class CloudFrontFactsServiceManagerFailure(Exception): + pass - def __init__(self, module): - self.module = module - self.client = module.client('cloudfront', retry_decorator=AWSRetry.jittered_backoff()) - def get_distribution(self, distribution_id): - try: - return self.client.get_distribution(Id=distribution_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing distribution") +def cloudfront_facts_keyed_list_helper(list_to_key): + result = dict() + for item in list_to_key: + distribution_id = item['Id'] + if 'Items' in item['Aliases']: + result.update({alias: item for alias in item['Aliases']['Items']}) + result.update({distribution_id: item}) + return result - def get_distribution_config(self, distribution_id): - try: - return self.client.get_distribution_config(Id=distribution_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing distribution configuration") - def get_origin_access_identity(self, origin_access_identity_id): - try: - return self.client.get_cloud_front_origin_access_identity(Id=origin_access_identity_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing origin access identity") +@AWSRetry.jittered_backoff() +def _cloudfront_paginate_build_full_result(client, client_method, **kwargs): + paginator = client.get_paginator(client_method) + return paginator.paginate(**kwargs).build_full_result() - def get_origin_access_identity_config(self, origin_access_identity_id): - try: - return self.client.get_cloud_front_origin_access_identity_config(Id=origin_access_identity_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing origin access identity configuration") - def get_invalidation(self, distribution_id, invalidation_id): - try: - return self.client.get_invalidation(DistributionId=distribution_id, Id=invalidation_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing invalidation") +class CloudFrontFactsServiceManager: + """Handles CloudFront Facts Services""" - def get_streaming_distribution(self, distribution_id): - try: - return self.client.get_streaming_distribution(Id=distribution_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing streaming distribution") + CLOUDFRONT_CLIENT_API_MAPPING = { + "get_distribution": { + "error": "Error describing distribution", + }, + "get_distribution_config": { + "error": "Error describing distribution configuration", + }, + "get_origin_access_identity": { + "error": "Error describing origin access identity", + "client_api": "get_cloud_front_origin_access_identity" + }, + "get_origin_access_identity_config": { + "error": "Error describing origin access identity configuration", + "client_api": "get_cloud_front_origin_access_identity_config" + }, + "get_streaming_distribution": { + "error": "Error describing streaming distribution", + }, + "get_streaming_distribution_config": { + "error": "Error describing streaming distribution", + }, + "get_invalidation": { + "error": "Error describing invalidation" + }, + "list_distributions_by_web_acl_id": { + "error": "Error listing distributions by web acl id", + "post_process": lambda x: cloudfront_facts_keyed_list_helper(x.get('DistributionList', {}).get('Items', [])) + } + } + + CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING = { + "list_origin_access_identities": { + "error": "Error listing cloud front origin access identities", + "client_api": "list_cloud_front_origin_access_identities", + "key": "CloudFrontOriginAccessIdentityList" + }, + "list_distributions": { + "error": "Error listing distributions", + "key": "DistributionList", + "keyed": True, + }, + "list_invalidations": { + "error": "Error listing invalidations", + "key": "InvalidationList" + }, + "list_streaming_distributions": { + "error": "Error listing streaming distributions", + "key": "StreamingDistributionList", + "keyed": True, + } + } - def get_streaming_distribution_config(self, distribution_id): - try: - return self.client.get_streaming_distribution_config(Id=distribution_id, aws_retry=True) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error describing streaming distribution") + def __init__(self, module): + self.module = module + self.client = module.client('cloudfront', retry_decorator=AWSRetry.jittered_backoff()) - def list_origin_access_identities(self): + def describe_cloudfront_property(self, client_method, error, post_process, **kwargs): + fail_if_error = kwargs.pop('fail_if_error', True) try: - paginator = self.client.get_paginator('list_cloud_front_origin_access_identities') - result = paginator.paginate().build_full_result().get('CloudFrontOriginAccessIdentityList', {}) - return result.get('Items', []) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error listing cloud front origin access identities") + method = getattr(self.client, client_method) + api_kwargs = snake_dict_to_camel_dict(kwargs, capitalize_first=True) + result = method(aws_retry=True, **api_kwargs) + result.pop('ResponseMetadata', None) + if post_process: + result = post_process(result) + return result + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not fail_if_error: + raise + self.module.fail_json_aws(e, msg=error) + + def paginate_list_cloudfront_property(self, client_method, key, default_keyed, error, **kwargs): + fail_if_error = kwargs.pop('fail_if_error', True) + try: + keyed = kwargs.pop("keyed", default_keyed) + api_kwargs = snake_dict_to_camel_dict(kwargs, capitalize_first=True) + result = _cloudfront_paginate_build_full_result(self.client, client_method, **api_kwargs) + items = result.get(key, {}).get('Items', []) + if keyed: + items = cloudfront_facts_keyed_list_helper(items) + return items + except (botocore.exceptions.ClientError, botocore.exceptions.BotoCoreError) as e: + if not fail_if_error: + raise + self.module.fail_json_aws(e, msg=error) - def list_distributions(self, keyed=True): - try: - paginator = self.client.get_paginator('list_distributions') - result = paginator.paginate().build_full_result().get('DistributionList', {}) - distribution_list = result.get('Items', []) - if not keyed: - return distribution_list - return self.keyed_list_helper(distribution_list) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error listing distributions") + def __getattr__(self, name): - def list_distributions_by_web_acl_id(self, web_acl_id): - try: - result = self.client.list_distributions_by_web_acl_id(WebAclId=web_acl_id, aws_retry=True) - distribution_list = result.get('DistributionList', {}).get('Items', []) - return self.keyed_list_helper(distribution_list) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error listing distributions by web acl id") + if name in self.CLOUDFRONT_CLIENT_API_MAPPING: + client_method = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('client_api', name) + error = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('error', '') + post_process = self.CLOUDFRONT_CLIENT_API_MAPPING[name].get('post_process') + return partial(self.describe_cloudfront_property, client_method, error, post_process) - def list_invalidations(self, distribution_id): - try: - paginator = self.client.get_paginator('list_invalidations') - result = paginator.paginate(DistributionId=distribution_id).build_full_result() - return result.get('InvalidationList', {}).get('Items', []) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error listing invalidations") + elif name in self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING: + client_method = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('client_api', name) + error = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('error', '') + key = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('key') + keyed = self.CLOUDFRONT_CLIENT_PAGINATE_API_MAPPING[name].get('keyed', False) + return partial(self.paginate_list_cloudfront_property, client_method, key, keyed, error) - def list_streaming_distributions(self, keyed=True): - try: - paginator = self.client.get_paginator('list_streaming_distributions') - result = paginator.paginate().build_full_result() - streaming_distribution_list = result.get('StreamingDistributionList', {}).get('Items', []) - if not keyed: - return streaming_distribution_list - return self.keyed_list_helper(streaming_distribution_list) - except botocore.exceptions.ClientError as e: - self.module.fail_json_aws(e, msg="Error listing streaming distributions") + raise CloudFrontFactsServiceManagerFailure("Method {0} is not currently supported".format(name)) def summary(self): summary_dict = {} @@ -139,27 +172,27 @@ def summary(self): def summary_get_origin_access_identity_list(self): try: - origin_access_identity_list = {'origin_access_identities': []} - origin_access_identities = self.list_origin_access_identities() - for origin_access_identity in origin_access_identities: + origin_access_identities = [] + for origin_access_identity in self.list_origin_access_identities(): oai_id = origin_access_identity['Id'] oai_full_response = self.get_origin_access_identity(oai_id) oai_summary = {'Id': oai_id, 'ETag': oai_full_response['ETag']} - origin_access_identity_list['origin_access_identities'].append(oai_summary) - return origin_access_identity_list + origin_access_identities.append(oai_summary) + return {'origin_access_identities': origin_access_identities} except botocore.exceptions.ClientError as e: self.module.fail_json_aws(e, msg="Error generating summary of origin access identities") + def list_resource_tags(self, resource_arn): + return self.client.list_tags_for_resource(Resource=resource_arn, aws_retry=True) + def summary_get_distribution_list(self, streaming=False): try: list_name = 'streaming_distributions' if streaming else 'distributions' key_list = ['Id', 'ARN', 'Status', 'LastModifiedTime', 'DomainName', 'Comment', 'PriceClass', 'Enabled'] distribution_list = {list_name: []} - distributions = self.list_streaming_distributions(False) if streaming else self.list_distributions(False) + distributions = self.list_streaming_distributions(keyed=False) if streaming else self.list_distributions(keyed=False) for dist in distributions: - temp_distribution = {} - for key_name in key_list: - temp_distribution[key_name] = dist[key_name] + temp_distribution = {k: dist[k] for k in key_list} temp_distribution['Aliases'] = list(dist['Aliases'].get('Items', [])) temp_distribution['ETag'] = self.get_etag_from_distribution_id(dist['Id'], streaming) if not streaming: @@ -167,7 +200,7 @@ def summary_get_distribution_list(self, streaming=False): invalidation_ids = self.get_list_of_invalidation_ids_from_distribution_id(dist['Id']) if invalidation_ids: temp_distribution['Invalidations'] = invalidation_ids - resource_tags = self.client.list_tags_for_resource(Resource=dist['ARN'], aws_retry=True) + resource_tags = self.list_resource_tags(dist['ARN']) temp_distribution['Tags'] = boto3_tag_list_to_ansible_dict(resource_tags['Tags'].get('Items', [])) distribution_list[list_name].append(temp_distribution) return distribution_list @@ -177,50 +210,32 @@ def summary_get_distribution_list(self, streaming=False): def get_etag_from_distribution_id(self, distribution_id, streaming): distribution = {} if not streaming: - distribution = self.get_distribution(distribution_id) + distribution = self.get_distribution(id=distribution_id) else: - distribution = self.get_streaming_distribution(distribution_id) + distribution = self.get_streaming_distribution(id=distribution_id) return distribution['ETag'] def get_list_of_invalidation_ids_from_distribution_id(self, distribution_id): try: - invalidation_ids = [] - invalidations = self.list_invalidations(distribution_id) - for invalidation in invalidations: - invalidation_ids.append(invalidation['Id']) - return invalidation_ids + return list(map(lambda x: x['Id'], self.list_invalidations(distribution_id=distribution_id))) except botocore.exceptions.ClientError as e: self.module.fail_json_aws(e, msg="Error getting list of invalidation ids") def get_distribution_id_from_domain_name(self, domain_name): try: distribution_id = "" - distributions = self.list_distributions(False) - distributions += self.list_streaming_distributions(False) + distributions = self.list_distributions(keyed=False) + distributions += self.list_streaming_distributions(keyed=False) for dist in distributions: - if 'Items' in dist['Aliases']: - for alias in dist['Aliases']['Items']: - if str(alias).lower() == domain_name.lower(): - distribution_id = dist['Id'] - break + if any(str(alias).lower() == domain_name.lower() for alias in dist['Aliases'].get('Items', [])): + distribution_id = dist['Id'] return distribution_id except botocore.exceptions.ClientError as e: self.module.fail_json_aws(e, msg="Error getting distribution id from domain name") def get_aliases_from_distribution_id(self, distribution_id): try: - distribution = self.get_distribution(distribution_id) - return distribution['DistributionConfig']['Aliases'].get('Items', []) + distribution = self.get_distribution(id=distribution_id) + return distribution['Distribution']['DistributionConfig']['Aliases'].get('Items', []) except botocore.exceptions.ClientError as e: self.module.fail_json_aws(e, msg="Error getting list of aliases from distribution_id") - - def keyed_list_helper(self, list_to_key): - keyed_list = dict() - for item in list_to_key: - distribution_id = item['Id'] - if 'Items' in item['Aliases']: - aliases = item['Aliases']['Items'] - for alias in aliases: - keyed_list.update({alias: item}) - keyed_list.update({distribution_id: item}) - return keyed_list diff --git a/tests/unit/module_utils/test_cloudfront_facts.py b/tests/unit/module_utils/test_cloudfront_facts.py new file mode 100644 index 00000000000..8aca0573c9f --- /dev/null +++ b/tests/unit/module_utils/test_cloudfront_facts.py @@ -0,0 +1,522 @@ +# +# (c) 2022 Red Hat Inc. +# +# This file is part of Ansible +# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) + + +import pytest + +try: + import botocore +except ImportError: + # Handled by HAS_BOTO3 + pass + +from ansible_collections.amazon.aws.plugins.module_utils.cloudfront_facts import ( + CloudFrontFactsServiceManager, + CloudFrontFactsServiceManagerFailure, + cloudfront_facts_keyed_list_helper, +) +from unittest.mock import MagicMock, patch, call + + +MODULE_NAME = "ansible_collections.amazon.aws.plugins.module_utils.cloudfront_facts" +MOCK_CLOUDFRONT_FACTS_KEYED_LIST_HELPER = MODULE_NAME + ".cloudfront_facts_keyed_list_helper" + + +@pytest.fixture() +def cloudfront_facts_service(): + module = MagicMock() + cloudfront_facts = CloudFrontFactsServiceManager(module) + + cloudfront_facts.module = MagicMock() + cloudfront_facts.module.fail_json_aws.side_effect = SystemExit(1) + + cloudfront_facts.client = MagicMock() + + return cloudfront_facts + + +def raise_botocore_error(operation='getCloudFront'): + return botocore.exceptions.ClientError( + { + "Error": { + "Code": "AccessDenied", + "Message": "User: Unauthorized operation" + }, + "ResponseMetadata": { + "RequestId": "01234567-89ab-cdef-0123-456789abcdef" + } + }, operation) + + +def test_unsupported_api(cloudfront_facts_service): + with pytest.raises(CloudFrontFactsServiceManagerFailure) as err: + cloudfront_facts_service._unsupported_api() + assert "Method {0} is not currently supported".format("_unsupported_api") in err + + +def test_get_distribution(cloudfront_facts_service): + + cloudfront_facts = MagicMock() + cloudfront_id = MagicMock() + cloudfront_facts_service.client.get_distribution.return_value = cloudfront_facts + + assert cloudfront_facts == cloudfront_facts_service.get_distribution(id=cloudfront_id) + cloudfront_facts_service.client.get_distribution.assert_called_with( + Id=cloudfront_id, aws_retry=True + ) + + +def test_get_distribution_failure(cloudfront_facts_service): + cloudfront_id = MagicMock() + cloudfront_facts_service.client.get_distribution.side_effect = raise_botocore_error() + + with pytest.raises(SystemExit): + cloudfront_facts_service.get_distribution(id=cloudfront_id) + cloudfront_facts_service.client.get_distribution.assert_called_with( + Id=cloudfront_id, aws_retry=True + ) + + +def test_get_distribution_fail_if_error(cloudfront_facts_service): + cloudfront_id = MagicMock() + cloudfront_facts_service.client.get_distribution.side_effect = raise_botocore_error() + + with pytest.raises(botocore.exceptions.ClientError): + cloudfront_facts_service.get_distribution(id=cloudfront_id, fail_if_error=False) + cloudfront_facts_service.client.get_distribution.assert_called_with( + Id=cloudfront_id, aws_retry=True + ) + + +def test_get_invalidation(cloudfront_facts_service): + + cloudfront_facts = MagicMock() + cloudfront_id = MagicMock() + distribution_id = MagicMock() + cloudfront_facts_service.client.get_invalidation.return_value = cloudfront_facts + + assert cloudfront_facts == cloudfront_facts_service.get_invalidation(distribution_id=distribution_id, id=cloudfront_id) + cloudfront_facts_service.client.get_invalidation.assert_called_with( + DistributionId=distribution_id, Id=cloudfront_id, aws_retry=True + ) + + +def test_get_invalidation_failure(cloudfront_facts_service): + cloudfront_id = MagicMock() + distribution_id = MagicMock() + cloudfront_facts_service.client.get_invalidation.side_effect = raise_botocore_error() + + with pytest.raises(SystemExit): + cloudfront_facts_service.get_invalidation(distribution_id=distribution_id, id=cloudfront_id) + + +@patch(MOCK_CLOUDFRONT_FACTS_KEYED_LIST_HELPER) +def test_list_distributions_by_web_acl_id(m_cloudfront_facts_keyed_list_helper, cloudfront_facts_service): + + web_acl_id = MagicMock() + distribution_webacl = { + 'DistributionList': { + 'Items': ["webacl_%d" % d for d in range(10)] + } + } + cloudfront_facts_service.client.list_distributions_by_web_acl_id.return_value = distribution_webacl + m_cloudfront_facts_keyed_list_helper.return_value = distribution_webacl['DistributionList']['Items'] + + result = cloudfront_facts_service.list_distributions_by_web_acl_id(web_acl_id=web_acl_id) + assert distribution_webacl['DistributionList']['Items'] == result + cloudfront_facts_service.client.list_distributions_by_web_acl_id.assert_called_with( + WebAclId=web_acl_id, aws_retry=True + ) + m_cloudfront_facts_keyed_list_helper.assert_called_with(distribution_webacl['DistributionList']['Items']) + + +@patch(MOCK_CLOUDFRONT_FACTS_KEYED_LIST_HELPER) +@patch(MODULE_NAME + "._cloudfront_paginate_build_full_result") +def test_list_origin_access_identities(m_cloudfront_paginate_build_full_result, m_cloudfront_facts_keyed_list_helper, cloudfront_facts_service): + + items = ["item_%d" % d for d in range(10)] + result = { + 'CloudFrontOriginAccessIdentityList': { + 'Items': items + } + } + + m_cloudfront_paginate_build_full_result.return_value = result + assert items == cloudfront_facts_service.list_origin_access_identities() + m_cloudfront_facts_keyed_list_helper.assert_not_called() + + +@patch(MOCK_CLOUDFRONT_FACTS_KEYED_LIST_HELPER) +@patch(MODULE_NAME + "._cloudfront_paginate_build_full_result") +def test_list_distributions(m_cloudfront_paginate_build_full_result, m_cloudfront_facts_keyed_list_helper, cloudfront_facts_service): + + items = ["item_%d" % d for d in range(10)] + result = { + 'DistributionList': { + 'Items': items + } + } + + m_cloudfront_paginate_build_full_result.return_value = result + m_cloudfront_facts_keyed_list_helper.return_value = items + + assert items == cloudfront_facts_service.list_distributions() + m_cloudfront_facts_keyed_list_helper.assert_called_with(items) + + +@patch(MOCK_CLOUDFRONT_FACTS_KEYED_LIST_HELPER) +@patch(MODULE_NAME + "._cloudfront_paginate_build_full_result") +def test_list_invalidations(m_cloudfront_paginate_build_full_result, m_cloudfront_facts_keyed_list_helper, cloudfront_facts_service): + + items = ["item_%d" % d for d in range(10)] + result = { + 'InvalidationList': { + 'Items': items + } + } + distribution_id = MagicMock() + + m_cloudfront_paginate_build_full_result.return_value = result + m_cloudfront_facts_keyed_list_helper.return_value = items + + assert items == cloudfront_facts_service.list_invalidations(distribution_id=distribution_id) + m_cloudfront_facts_keyed_list_helper.assert_not_called() + m_cloudfront_paginate_build_full_result.assert_called_with( + cloudfront_facts_service.client, 'list_invalidations', DistributionId=distribution_id + ) + + +@pytest.mark.parametrize("fail_if_error", [True, False]) +@patch(MODULE_NAME + "._cloudfront_paginate_build_full_result") +def test_list_invalidations_failure(m_cloudfront_paginate_build_full_result, cloudfront_facts_service, fail_if_error): + + distribution_id = MagicMock() + m_cloudfront_paginate_build_full_result.side_effect = raise_botocore_error() + + if fail_if_error: + with pytest.raises(SystemExit): + cloudfront_facts_service.list_invalidations(distribution_id=distribution_id, fail_if_error=fail_if_error) + else: + with pytest.raises(botocore.exceptions.ClientError): + cloudfront_facts_service.list_invalidations(distribution_id=distribution_id, fail_if_error=fail_if_error) + m_cloudfront_paginate_build_full_result.assert_called_with( + cloudfront_facts_service.client, 'list_invalidations', DistributionId=distribution_id + ) + + +@pytest.mark.parametrize( + "list_to_key,expected", + [ + ([], {}), + ( + [ + {'Id': 'id_1', 'Aliases': {}}, + {'Id': 'id_2', 'Aliases': {'Items': ['alias_1', 'alias_2']}} + ], + { + 'id_1': {'Id': 'id_1', 'Aliases': {}}, + 'id_2': {'Id': 'id_2', 'Aliases': {'Items': ['alias_1', 'alias_2']}}, + 'alias_1': {'Id': 'id_2', 'Aliases': {'Items': ['alias_1', 'alias_2']}}, + 'alias_2': {'Id': 'id_2', 'Aliases': {'Items': ['alias_1', 'alias_2']}} + } + ), + ] +) +def test_cloudfront_facts_keyed_list_helper(list_to_key, expected): + assert expected == cloudfront_facts_keyed_list_helper(list_to_key) + + +@pytest.mark.parametrize( + "distribution,expected", + [ + ( + {'Distribution': {'DistributionConfig': {'Aliases': {'Items': ["item_1", "item_2"]}}}}, + ["item_1", "item_2"] + ), + ( + {'Distribution': {'DistributionConfig': {'Aliases': {}}}}, [] + ) + ] +) +def test_get_aliases_from_distribution_id(cloudfront_facts_service, distribution, expected): + + distribution_id = MagicMock() + + cloudfront_facts_service.get_distribution = MagicMock() + cloudfront_facts_service.get_distribution.return_value = distribution + assert expected == cloudfront_facts_service.get_aliases_from_distribution_id(distribution_id) + + +def test_get_aliases_from_distribution_id_failure(cloudfront_facts_service): + + distribution_id = MagicMock() + + cloudfront_facts_service.get_distribution = MagicMock() + cloudfront_facts_service.get_distribution.side_effect = raise_botocore_error() + + with pytest.raises(SystemExit): + cloudfront_facts_service.get_aliases_from_distribution_id(distribution_id) + cloudfront_facts_service.get_distribution.assert_called_once_with(id=distribution_id) + + +@pytest.mark.parametrize( + "distributions,streaming_distributions,domain_name,expected", + [ + ([], [], MagicMock(), ""), + ([{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "id-01"}], [], "domain01", ""), + ([{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "id-01"}], [], "domain_01", "id-01"), + ([{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "id-01"}], [], "DOMAIN_01", "id-01"), + ([{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "id-01"}], [], "domain_02", "id-01"), + ([], [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "stream-01"}], "DOMAIN", ""), + ([], [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "stream-01"}], "DOMAIN_01", "stream-01"), + ([], [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "stream-01"}], "domain_01", "stream-01"), + ([], [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "stream-01"}], "domain_02", "stream-01"), + ( + [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "id-01"}], + [{'Aliases': {'Items': ["domain_01", "domain_02"]}, 'Id': "stream-01"}], + "domain_01", + "stream-01" + ), + ] +) +def test_get_distribution_id_from_domain_name(cloudfront_facts_service, distributions, streaming_distributions, domain_name, expected): + + cloudfront_facts_service.list_distributions = MagicMock() + cloudfront_facts_service.list_streaming_distributions = MagicMock() + + cloudfront_facts_service.list_distributions.return_value = distributions + cloudfront_facts_service.list_streaming_distributions.return_value = streaming_distributions + + assert expected == cloudfront_facts_service.get_distribution_id_from_domain_name(domain_name) + + cloudfront_facts_service.list_distributions.assert_called_once_with(keyed=False) + cloudfront_facts_service.list_streaming_distributions.assert_called_once_with(keyed=False) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_get_etag_from_distribution_id(cloudfront_facts_service, streaming): + + distribution = {'ETag': MagicMock()} + streaming_distribution = {'ETag': MagicMock()} + + distribution_id = MagicMock() + + cloudfront_facts_service.get_distribution = MagicMock() + cloudfront_facts_service.get_distribution.return_value = distribution + + cloudfront_facts_service.get_streaming_distribution = MagicMock() + cloudfront_facts_service.get_streaming_distribution.return_value = streaming_distribution + + expected = distribution if not streaming else streaming_distribution + + assert expected['ETag'] == cloudfront_facts_service.get_etag_from_distribution_id(distribution_id, streaming) + if not streaming: + cloudfront_facts_service.get_distribution.assert_called_once_with(id=distribution_id) + else: + cloudfront_facts_service.get_streaming_distribution.assert_called_once_with(id=distribution_id) + + +@pytest.mark.parametrize( + "invalidations, expected", + [ + ([], []), + ([{'Id': "id-01"}], ["id-01"]), + ([{'Id': "id-01"}, {'Id': "id-02"}], ["id-01", "id-02"]), + ] +) +def test_get_list_of_invalidation_ids_from_distribution_id(cloudfront_facts_service, invalidations, expected): + + cloudfront_facts_service.list_invalidations = MagicMock() + cloudfront_facts_service.list_invalidations.return_value = invalidations + + distribution_id = MagicMock() + assert expected == cloudfront_facts_service.get_list_of_invalidation_ids_from_distribution_id(distribution_id) + cloudfront_facts_service.list_invalidations.assert_called_with(distribution_id=distribution_id) + + +def test_get_list_of_invalidation_ids_from_distribution_id_failure(cloudfront_facts_service): + + cloudfront_facts_service.list_invalidations = MagicMock() + cloudfront_facts_service.list_invalidations.side_effect = raise_botocore_error() + + distribution_id = MagicMock() + with pytest.raises(SystemExit): + cloudfront_facts_service.get_list_of_invalidation_ids_from_distribution_id(distribution_id) + + +@pytest.mark.parametrize("streaming", [True, False]) +@pytest.mark.parametrize( + "distributions, expected", + [ + ([], []), + ( + [ + { + 'Id': 'id_1', + 'Aliases': {'Items': ['item_1', 'item_2']}, + 'WebACLId': 'webacl_1', + 'ARN': 'arn:ditribution:us-east-1:1', + 'Status': 'available', + 'LastModifiedTime': '11102022120000', + 'DomainName': 'domain_01.com', + 'Comment': 'This is the first distribution', + 'PriceClass': 'low', + 'Enabled': 'False', + 'Tags': { + 'Items': [{'Name': 'tag1', 'Value': 'distribution1'}] + }, + 'ETag': 'abcdefgh', + '_ids': [] + }, + { + 'Id': 'id_2', + 'Aliases': {'Items': ['item_20']}, + 'WebACLId': 'webacl_2', + 'ARN': 'arn:ditribution:us-west:2', + 'Status': 'active', + 'LastModifiedTime': '11102022200000', + 'DomainName': 'another_domain_name.com', + 'Comment': 'This is the second distribution', + 'PriceClass': 'High', + 'Enabled': 'True', + 'Tags': { + 'Items': [{'Name': 'tag2', 'Value': 'distribution2'}, {'Name': 'another_tag', 'Value': 'item 2'}] + }, + 'ETag': 'ABCDEFGH', + '_ids': ["invalidation_1", "invalidation_2"] + } + ], + [ + { + 'Id': 'id_1', + 'ARN': 'arn:ditribution:us-east-1:1', + 'Status': 'available', + 'LastModifiedTime': '11102022120000', + 'DomainName': 'domain_01.com', + 'Comment': 'This is the first distribution', + 'PriceClass': 'low', + 'Enabled': 'False', + 'Aliases': ['item_1', 'item_2'], + 'ETag': 'abcdefgh', + 'WebACLId': 'webacl_1', + 'Tags': [{'Name': 'tag1', 'Value': 'distribution1'}] + }, + { + 'Id': 'id_2', + 'ARN': 'arn:ditribution:us-west:2', + 'Status': 'active', + 'LastModifiedTime': '11102022200000', + 'DomainName': 'another_domain_name.com', + 'Comment': 'This is the second distribution', + 'PriceClass': 'High', + 'Enabled': 'True', + 'Aliases': ['item_20'], + 'ETag': 'ABCDEFGH', + 'WebACLId': 'webacl_2', + 'Invalidations': ['invalidation_1', 'invalidation_2'], + 'Tags': [{'Name': 'tag2', 'Value': 'distribution2'}, {'Name': 'another_tag', 'Value': 'item 2'}] + } + ] + ) + ] +) +@patch(MODULE_NAME + ".boto3_tag_list_to_ansible_dict") +def test_summary_get_distribution_list(m_boto3_tag_list_to_ansible_dict, cloudfront_facts_service, streaming, distributions, expected): + + m_boto3_tag_list_to_ansible_dict.side_effect = lambda x: x + + cloudfront_facts_service.list_streaming_distributions = MagicMock() + cloudfront_facts_service.list_streaming_distributions.return_value = distributions + + cloudfront_facts_service.list_distributions = MagicMock() + cloudfront_facts_service.list_distributions.return_value = distributions + + cloudfront_facts_service.get_etag_from_distribution_id = MagicMock() + cloudfront_facts_service.get_etag_from_distribution_id.side_effect = lambda id, stream: [x['ETag'] for x in distributions if x['Id'] == id][0] + + cloudfront_facts_service.get_list_of_invalidation_ids_from_distribution_id = MagicMock() + cloudfront_facts_service.get_list_of_invalidation_ids_from_distribution_id.side_effect = lambda id: [x['_ids'] for x in distributions if x['Id'] == id][0] + + cloudfront_facts_service.list_resource_tags = MagicMock() + cloudfront_facts_service.list_resource_tags.side_effect = lambda arn: {'Tags': x['Tags'] for x in distributions if x['ARN'] == arn} + + key_name = 'streaming_distributions' + if not streaming: + key_name = 'distributions' + + if streaming: + expected = list(map(lambda x: {k: x[k] for k in x if k not in ('WebACLId', 'Invalidations')}, expected)) + assert {key_name: expected} == cloudfront_facts_service.summary_get_distribution_list(streaming) + + +@pytest.mark.parametrize("streaming", [True, False]) +def test_summary_get_distribution_list_failure(cloudfront_facts_service, streaming): + + cloudfront_facts_service.list_streaming_distributions = MagicMock() + cloudfront_facts_service.list_streaming_distributions.side_effect = raise_botocore_error() + + cloudfront_facts_service.list_distributions = MagicMock() + cloudfront_facts_service.list_distributions.side_effect = raise_botocore_error() + + with pytest.raises(SystemExit): + cloudfront_facts_service.summary_get_distribution_list(streaming) + + +def test_summary(cloudfront_facts_service): + + cloudfront_facts_service.summary_get_distribution_list = MagicMock() + cloudfront_facts_service.summary_get_distribution_list.side_effect = lambda x: {'called_with_true': True} if x else {'called_with_false': False} + + cloudfront_facts_service.summary_get_origin_access_identity_list = MagicMock() + cloudfront_facts_service.summary_get_origin_access_identity_list.return_value = {'origin_access_ids': ['access_1', 'access_2']} + + expected = { + 'called_with_true': True, + 'called_with_false': False, + 'origin_access_ids': ['access_1', 'access_2'] + } + + assert expected == cloudfront_facts_service.summary() + + cloudfront_facts_service.summary_get_origin_access_identity_list.assert_called_once() + cloudfront_facts_service.summary_get_distribution_list.assert_has_calls( + [call(True), call(False)], any_order=True + ) + + +@pytest.mark.parametrize( + "origin_access_identities,expected", + [ + ([], []), + ( + [ + {'Id': 'some_id', 'response': {'state': 'active', 'ETag': 'some_Etag'}}, + {'Id': 'another_id', 'response': {'ETag': 'another_Etag'}} + ], + [ + {'Id': 'some_id', 'ETag': 'some_Etag'}, + {'Id': 'another_id', 'ETag': 'another_Etag'} + ] + ) + ] +) +def test_summary_get_origin_access_identity_list(cloudfront_facts_service, origin_access_identities, expected): + + cloudfront_facts_service.list_origin_access_identities = MagicMock() + cloudfront_facts_service.list_origin_access_identities.return_value = origin_access_identities + cloudfront_facts_service.get_origin_access_identity = MagicMock() + cloudfront_facts_service.get_origin_access_identity.side_effect = lambda x: [o['response'] for o in origin_access_identities if o['Id'] == x][0] + + assert {'origin_access_identities': expected} == cloudfront_facts_service.summary_get_origin_access_identity_list() + + +def test_summary_get_origin_access_identity_list_failure(cloudfront_facts_service): + + cloudfront_facts_service.list_origin_access_identities = MagicMock() + cloudfront_facts_service.list_origin_access_identities.side_effect = raise_botocore_error() + + with pytest.raises(SystemExit): + cloudfront_facts_service.summary_get_origin_access_identity_list()