From 931fbf9fe04fdd173ade2cd8c51b3e5833d50378 Mon Sep 17 00:00:00 2001 From: abikouo Date: Thu, 13 Oct 2022 15:32:18 +0200 Subject: [PATCH] unit testing --- .../module_utils_s3-unit-testing.yml | 3 + plugins/module_utils/s3.py | 169 +++--- tests/unit/module_utils/test_s3.py | 521 ++++++++++++++---- 3 files changed, 498 insertions(+), 195 deletions(-) create mode 100644 changelogs/fragments/module_utils_s3-unit-testing.yml diff --git a/changelogs/fragments/module_utils_s3-unit-testing.yml b/changelogs/fragments/module_utils_s3-unit-testing.yml new file mode 100644 index 00000000000..47d5e4e46d3 --- /dev/null +++ b/changelogs/fragments/module_utils_s3-unit-testing.yml @@ -0,0 +1,3 @@ +--- +trivial: +- module_utils.s3 - Refactor get_s3_connection into a module_utils for S3 modules and expand module_utils.s3 unit tests. diff --git a/plugins/module_utils/s3.py b/plugins/module_utils/s3.py index 9677871b5ee..0719f5f3c1f 100644 --- a/plugins/module_utils/s3.py +++ b/plugins/module_utils/s3.py @@ -6,11 +6,9 @@ from ansible.module_utils.basic import to_text from ansible.module_utils.six.moves.urllib.parse import urlparse -from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info - try: - import botocore + from botocore.client import Config from botocore.exceptions import BotoCoreError, ClientError except ImportError: pass # Handled by the calling module @@ -28,6 +26,38 @@ import string +def s3_head_objects(client, parts, bucket, obj, versionId): + args = {"Bucket": bucket, "Key": obj} + if versionId: + args["VersionId"] = versionId + + for part in range(1, parts + 1): + args["PartNumber"] = part + yield client.head_object(**args) + + +def calculate_checksum_with_file(client, parts, bucket, obj, versionId, filename): + digests = [] + with open(filename, 'rb') as f: + for head in s3_head_objects(client, parts, bucket, obj, versionId): + digests.append(md5(f.read(int(head['ContentLength']))).digest()) + + digest_squared = b''.join(digests) + return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests)) + + +def calculate_checksum_with_content(client, parts, bucket, obj, versionId, content): + digests = [] + offset = 0 + for head in s3_head_objects(client, parts, bucket, obj, versionId): + length = int(head['ContentLength']) + digests.append(md5(content[offset:offset + length]).digest()) + offset += length + + digest_squared = b''.join(digests) + return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests)) + + def calculate_etag(module, filename, etag, s3, bucket, obj, version=None): if not HAS_MD5: return None @@ -35,26 +65,10 @@ def calculate_etag(module, filename, etag, s3, bucket, obj, version=None): if '-' in etag: # Multi-part ETag; a hash of the hashes of each part. parts = int(etag[1:-1].split('-')[1]) - digests = [] - - s3_kwargs = dict( - Bucket=bucket, - Key=obj, - ) - if version: - s3_kwargs['VersionId'] = version - - with open(filename, 'rb') as f: - for part_num in range(1, parts + 1): - s3_kwargs['PartNumber'] = part_num - try: - head = s3.head_object(**s3_kwargs) - except (BotoCoreError, ClientError) as e: - module.fail_json_aws(e, msg="Failed to get head object") - digests.append(md5(f.read(int(head['ContentLength'])))) - - digest_squared = md5(b''.join(m.digest() for m in digests)) - return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests)) + try: + return calculate_checksum_with_file(s3, parts, bucket, obj, version, filename) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get head object") else: # Compute the MD5 sum normally return '"{0}"'.format(module.md5(filename)) @@ -66,28 +80,10 @@ def calculate_etag_content(module, content, etag, s3, bucket, obj, version=None) if '-' in etag: # Multi-part ETag; a hash of the hashes of each part. parts = int(etag[1:-1].split('-')[1]) - digests = [] - offset = 0 - - s3_kwargs = dict( - Bucket=bucket, - Key=obj, - ) - if version: - s3_kwargs['VersionId'] = version - - for part_num in range(1, parts + 1): - s3_kwargs['PartNumber'] = part_num - try: - head = s3.head_object(**s3_kwargs) - except (BotoCoreError, ClientError) as e: - module.fail_json_aws(e, msg="Failed to get head object") - length = int(head['ContentLength']) - digests.append(md5(content[offset:offset + length])) - offset += length - - digest_squared = md5(b''.join(m.digest() for m in digests)) - return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests)) + try: + return calculate_checksum_with_content(s3, parts, bucket, obj, version, content) + except (BotoCoreError, ClientError) as e: + module.fail_json_aws(e, msg="Failed to get head object") else: # Compute the MD5 sum normally return '"{0}"'.format(md5(content).hexdigest()) @@ -108,12 +104,42 @@ def validate_bucket_name(module, name): return True -# To get S3 connection, in case of dealing with ceph, dualstack, etc. -def is_fakes3(endpoint_url): +# Spot special case of fakes3. +def is_fakes3(url): """ Return True if endpoint_url has scheme fakes3:// """ result = False - if endpoint_url is not None: - result = urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s') + if url is not None: + result = urlparse(url).scheme in ('fakes3', 'fakes3s') + return result + + +def parse_fakes3_endpoint(url): + fakes3 = urlparse(url) + protocol = "http" + port = fakes3.port or 80 + if fakes3.scheme == 'fakes3s': + protocol = "https" + port = fakes3.port or 443 + endpoint_url = f"{protocol}://{fakes3.hostname}:{to_text(port)}" + use_ssl = bool(fakes3.scheme == 'fakes3s') + return {"endpoint": endpoint_url, "use_ssl": use_ssl} + + +def parse_ceph_endpoint(url): + ceph = urlparse(url) + use_ssl = bool(ceph.scheme == 'https') + return {"endpoint": url, "use_ssl": use_ssl} + + +def parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4): + result = {"endpoint": url} + config = {} + if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"): + config["signature_version"] = "s3v4" + if dualstack: + config["s3"] = {"use_dualstack_endpoint": True} + if config != {}: + result["config"] = Config(**config) return result @@ -122,48 +148,17 @@ def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, conn_type='client', resource='s3', region=location, - endpoint=endpoint_url, **aws_connect_kwargs ) - if ceph: # TODO - test this - ceph = urlparse(endpoint_url) - use_ssl = bool(ceph.scheme == 'https') - params.update( - dict( - use_ssl=use_ssl - ) - ) + if ceph: + endpoint_p = parse_ceph_endpoint(endpoint_url) elif is_fakes3(endpoint_url): - fakes3 = urlparse(endpoint_url) - protocol = "http" - port = fakes3.port or 80 - if fakes3.scheme == 'fakes3s': - protocol = "https" - port = fakes3.port or 443 - endpoint_url = f"{protocol}://{fakes3.hostname}:{to_text(port)}" - use_ssl = bool(fakes3.scheme == 'fakes3s') - params.update( - dict( - endpoint=endpoint_url, - use_ssl=use_ssl, - ) - ) + endpoint_p = parse_fakes3_endpoint(endpoint_url) else: mode = module.params.get("mode") encryption_mode = module.params.get("encryption_mode") - config = None - if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"): - config = botocore.client.Config(signature_version='s3v4') - if module.params.get("dualstack"): - use_dualstack = dict(use_dualstack_endpoint=True) - if config is not None: - config.merge(botocore.client.Config(s3=use_dualstack)) - else: - config = botocore.client.Config(s3=use_dualstack) - if config: - params.update( - dict( - config=config - ) - ) + dualstack = module.params.get("dualstack") + endpoint_p = parse_default_endpoint(endpoint_url, mode, encryption_mode, dualstack, sig_4) + + params.update(endpoint_p) return module.boto3_conn(**params) diff --git a/tests/unit/module_utils/test_s3.py b/tests/unit/module_utils/test_s3.py index a630ac065f7..f4b652b7d0a 100644 --- a/tests/unit/module_utils/test_s3.py +++ b/tests/unit/module_utils/test_s3.py @@ -5,138 +5,443 @@ # GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt) from __future__ import (absolute_import, division, print_function) -from operator import mod - __metaclass__ = type -from ansible_collections.amazon.aws.tests.unit.compat.mock import MagicMock +import pytest +import random +import string + from ansible_collections.amazon.aws.plugins.module_utils import s3 -from ansible.module_utils.basic import AnsibleModule -from unittest.mock import MagicMock, Mock, patch, ANY, call +from unittest.mock import MagicMock, patch, call + +try: + import botocore +except ImportError: + pass + + +def generate_random_string(size, include_digits=True): + buffer = string.ascii_lowercase + if include_digits: + buffer += string.digits + + return ''.join(random.choice(buffer) for i in range(size)) + + +@pytest.mark.parametrize("parts", range(10)) +@pytest.mark.parametrize("version", [True, False]) +def test_s3_head_objects(parts, version): + + client = MagicMock() + + s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False)) + s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False)) + versionId = None + if version: + versionId = random.randint(0, 1000) + + total = 0 + for head in s3.s3_head_objects(client, parts, s3bucket_name, s3bucket_object, versionId): + assert head == client.head_object.return_value + total += 1 + + assert total == parts + params = {"Bucket": s3bucket_name, "Key": s3bucket_object} + if versionId: + params["VersionId"] = versionId + api_calls = [call(PartNumber=i, **params) for i in range(1, parts + 1)] + client.head_object.assert_has_calls(api_calls, any_order=True) -class FakeAnsibleModule(AnsibleModule): - def __init__(self): - pass +def raise_botoclient_exception(): + params = { + 'Error': { + 'Code': 1, + 'Message': 'Something went wrong' + }, + 'ResponseMetadata': { + 'RequestId': '01234567-89ab-cdef-0123-456789abcdef' + } + } + return botocore.exceptions.ClientError(params, 'some_called_method') -def test_calculate_etag_single_part(tmp_path_factory): - module = FakeAnsibleModule() - my_image = tmp_path_factory.mktemp("data") / "my.txt" - my_image.write_text("Hello World!") - etag = s3.calculate_etag( - module, str(my_image), etag="", s3=None, bucket=None, obj=None - ) - assert etag == '"ed076287532e86365e841e92bfc50d8c"' +@pytest.mark.parametrize("use_file", [False, True]) +@pytest.mark.parametrize("parts", range(10)) +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.md5") +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.s3_head_objects") +def test_calculate_checksum(m_s3_head_objects, m_s3_md5, use_file, parts, tmp_path): + client = MagicMock() + mock_md5 = m_s3_md5.return_value -def test_calculate_etag_multi_part(tmp_path_factory): - module = FakeAnsibleModule() - my_image = tmp_path_factory.mktemp("data") / "my.txt" - my_image.write_text("Hello World!" * 1000) + mock_md5.digest.return_value = b"1" + mock_md5.hexdigest.return_value = ''.join(["f" for i in range(32)]) - mocked_s3 = MagicMock() - mocked_s3.head_object.side_effect = [{"ContentLength": "1000"} for _i in range(12)] + m_s3_head_objects.return_value = [ + {"ContentLength": "%d" % (i + 1)} for i in range(parts) + ] - etag = s3.calculate_etag( - module, - str(my_image), - etag='"f20e84ac3d0c33cea77b3f29e3323a09-12"', - s3=mocked_s3, - bucket="my-bucket", - obj="my-obj", - ) - assert etag == '"f20e84ac3d0c33cea77b3f29e3323a09-12"' - mocked_s3.head_object.assert_called_with( - Bucket="my-bucket", Key="my-obj", PartNumber=12 - ) + content = b'"f20e84ac3d0c33cea77b3f29e3323a09"' + test_function = s3.calculate_checksum_with_content + if use_file: + test_function = s3.calculate_checksum_with_file + test_dir = tmp_path / "test_s3" + test_dir.mkdir() + etag_file = test_dir / "etag.bin" + etag_file.write_bytes(content) + content = str(etag_file) + + s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False)) + s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False)) + version = random.randint(0, 1000) + + result = test_function(client, parts, s3bucket_name, s3bucket_object, version, content) + + expected = '"{0}-{1}"'.format(mock_md5.hexdigest.return_value, parts) + assert result == expected + + mock_md5.digest.assert_has_calls([call() for i in range(parts)]) + mock_md5.hexdigest.assert_called_once() + + m_s3_head_objects.assert_called_once_with(client, parts, s3bucket_name, s3bucket_object, version) + + +@pytest.mark.parametrize("etag_multipart", [True, False]) +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_file") +def test_calculate_etag(m_checksum_file, etag_multipart): -def test_validate_bucket_name(): module = MagicMock() + client = MagicMock() - assert s3.validate_bucket_name(module, "docexamplebucket1") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "log-delivery-march-2020") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "my-hosted-content") is True - assert not module.fail_json.called - - assert s3.validate_bucket_name(module, "docexamplewebsite.com") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "www.docexamplewebsite.com") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "my.example.s3.bucket") is True - assert not module.fail_json.called - assert s3.validate_bucket_name(module, "doc") is True - assert not module.fail_json.called - - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "doc_example_bucket") - assert module.fail_json.called - - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "DocExampleBucket") - assert module.fail_json.called - module.fail_json.reset_mock() - s3.validate_bucket_name(module, "doc-example-bucket-") - assert module.fail_json.called - s3.validate_bucket_name(module, "my") - assert module.fail_json.called - - -def test_is_fakes3_with_none_arg(): - result = s3.is_fakes3(None) - assert not result + module.fail_json_aws.side_effect = SystemExit(2) + module.md5.return_value = generate_random_string(32) + s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False)) + s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False)) + version = random.randint(0, 1000) + parts = 3 -def test_is_fakes3_with_valid_protocol(): - assert s3.is_fakes3("https://test-s3.amazon.com") + etag = '"f20e84ac3d0c33cea77b3f29e3323a09"' + digest = '"9aa254f7f76fd14435b21e9448525b99"' + file_name = generate_random_string(32) -def test_is_fakes3_with_fakes3_protocol(): - assert s3.is_fakes3("fakes3://test-s3.amazon.com") + if not etag_multipart: + result = s3.calculate_etag(module, file_name, etag, client, s3bucket_name, s3bucket_object, version) + assert result == '"{0}"'.format(module.md5.return_value) + module.md5.assert_called_once_with(file_name) + else: + etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts) + m_checksum_file.return_value = digest + assert digest == s3.calculate_etag(module, file_name, etag, client, s3bucket_name, s3bucket_object, version) + m_checksum_file.assert_called_with( + client, parts, s3bucket_name, s3bucket_object, version, file_name + ) -def test_is_fakes3_with_fakes3s_protocol(): - assert s3.is_fakes3("fakes3s://test-s3.amazon.com") +@pytest.mark.parametrize("etag_multipart", [True, False]) +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_content") +def test_calculate_etag_content(m_checksum_content, etag_multipart): + + module = MagicMock() + client = MagicMock() -def test_get_s3_connection_ceph_with_https(): - aws_connect = dict( - aws_access_key_id="ACCESS012345", - aws_secret_access_key="SECRET123", - ) + module.fail_json_aws.side_effect = SystemExit(2) + + s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False)) + s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False)) + version = random.randint(0, 1000) + parts = 3 + + etag = '"f20e84ac3d0c33cea77b3f29e3323a09"' + content = b'"f20e84ac3d0c33cea77b3f29e3323a09"' + digest = '"9aa254f7f76fd14435b21e9448525b99"' + + if not etag_multipart: + assert digest == s3.calculate_etag_content(module, content, etag, client, s3bucket_name, s3bucket_object, version) + else: + etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts) + m_checksum_content.return_value = digest + result = s3.calculate_etag_content(module, content, etag, client, s3bucket_name, s3bucket_object, version) + assert result == digest + + m_checksum_content.assert_called_with( + client, parts, s3bucket_name, s3bucket_object, version, content + ) + + +@pytest.mark.parametrize("using_file", [True, False]) +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_content") +@patch("ansible_collections.amazon.aws.plugins.module_utils.s3.calculate_checksum_with_file") +def test_calculate_etag_failure(m_checksum_file, m_checksum_content, using_file): + + module = MagicMock() + client = MagicMock() + + module.fail_json_aws.side_effect = SystemExit(2) + + s3bucket_name = "s3-bucket-%s" % (generate_random_string(8, False)) + s3bucket_object = "s3-bucket-object-%s" % (generate_random_string(8, False)) + version = random.randint(0, 1000) + parts = 3 + + etag = '"f20e84ac3d0c33cea77b3f29e3323a09-{0}"'.format(parts) + content = "some content or file name" + + if using_file: + test_method = s3.calculate_etag + m_checksum_file.side_effect = raise_botoclient_exception() + else: + test_method = s3.calculate_etag_content + m_checksum_content.side_effect = raise_botoclient_exception() + + with pytest.raises(SystemExit): + test_method(module, content, etag, client, s3bucket_name, s3bucket_object, version) + module.fail_json_aws.assert_called() + + +@pytest.mark.parametrize( + "bucket_name,error", + [ + ("docexamplebucket1", None), + ("log-delivery-march-2020", None), + ("my-hosted-content", None), + ("docexamplewebsite.com", None), + ("www.docexamplewebsite.com", None), + ("my.example.s3.bucket", None), + ("doc", None), + ("doc_example_bucket", "invalid character(s) found in the bucket name"), + ("DocExampleBucket", "invalid character(s) found in the bucket name"), + ("doc-example-bucket-", "bucket names must begin and end with a letter or number"), + ( + "this.string.has.more.than.63.characters.so.it.should.not.passed.the.validated", + "the length of an S3 bucket cannot exceed 63 characters" + ), + ("my", "the length of an S3 bucket must be at least 3 characters") + ] +) +def test_validate_bucket_name(bucket_name, error): + + module = MagicMock() + module.fail_json.side_effect = SystemExit(1) + + if error: + with pytest.raises(SystemExit): + s3.validate_bucket_name(module, bucket_name) + + module.fail_json.assert_called_with(msg=error) + else: + assert s3.validate_bucket_name(module, bucket_name) + module.fail_json.assert_not_called() + + +mod_urlparse = "ansible_collections.amazon.aws.plugins.module_utils.s3.urlparse" + + +class UrlInfo(object): + + def __init__(self, scheme=None, hostname=None, port=None): + self.hostname = hostname + self.scheme = scheme + self.port = port + + +@patch(mod_urlparse) +def test_is_fakes3_with_none_arg(m_urlparse): + m_urlparse.side_effect = SystemExit(1) + result = s3.is_fakes3(None) + assert not result + m_urlparse.assert_not_called() + + +@pytest.mark.parametrize( + "url,scheme,result", + [ + ("https://test-s3.amazon.com", "https", False), + ("fakes3://test-s3.amazon.com", "fakes3", True), + ("fakes3s://test-s3.amazon.com", "fakes3s", True), + ] +) +@patch(mod_urlparse) +def test_is_fakes3(m_urlparse, url, scheme, result): + m_urlparse.return_value = UrlInfo(scheme=scheme) + assert result == s3.is_fakes3(url) + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,urlinfo,endpoint", + [ + ( + "fakes3://test-s3.amazon.com", + { + "scheme": "fakes3", + "hostname": "test-s3.amazon.com" + }, + { + "endpoint": "http://test-s3.amazon.com:80", + "use_ssl": False + } + ), + ( + "fakes3://test-s3.amazon.com:8080", + { + "scheme": "fakes3", + "hostname": "test-s3.amazon.com", + "port": 8080 + }, + { + "endpoint": "http://test-s3.amazon.com:8080", + "use_ssl": False + } + ), + ( + "fakes3s://test-s3.amazon.com", + { + "scheme": "fakes3s", + "hostname": "test-s3.amazon.com" + }, + { + "endpoint": "https://test-s3.amazon.com:443", + "use_ssl": True + } + ), + ( + "fakes3s://test-s3.amazon.com:9096", + { + "scheme": "fakes3s", + "hostname": "test-s3.amazon.com", + "port": 9096 + }, + { + "endpoint": "https://test-s3.amazon.com:9096", + "use_ssl": True + } + ) + ] +) +@patch(mod_urlparse) +def test_parse_fakes3_endpoint(m_urlparse, url, urlinfo, endpoint): + m_urlparse.return_value = UrlInfo(**urlinfo) + result = s3.parse_fakes3_endpoint(url) + assert endpoint == result + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize( + "url,scheme,use_ssl", + [ + ("https://test-s3-ceph.amazon.com", "https", True), + ("http://test-s3-ceph.amazon.com", "http", False), + ] +) +@patch(mod_urlparse) +def test_parse_ceph_endpoint(m_urlparse, url, scheme, use_ssl): + m_urlparse.return_value = UrlInfo(scheme=scheme) + result = s3.parse_ceph_endpoint(url) + assert result == {"endpoint": url, "use_ssl": use_ssl} + m_urlparse.assert_called_with(url) + + +@pytest.mark.parametrize("mode", ["put", "get", "getstr"]) +@pytest.mark.parametrize("encryption_mode", ["aws:kms", "aws:unknown"]) +@pytest.mark.parametrize("dualstack", [True, False]) +@pytest.mark.parametrize("sig_4", [True, False]) +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.Config') +def test_parse_default_endpoint(m_config, mode, encryption_mode, dualstack, sig_4): + + url = "https://my-bucket.s3.us-west-2.amazonaws.com" + + signature_version = False + if mode in ('get', 'getstr') and sig_4: + signature_version = True + if mode == "put" and encryption_mode == "aws:kms": + signature_version = True + + attributes = {} + if signature_version: + attributes["signature_version"] = "s3v4" + if dualstack: + attributes["s3"] = {"use_dualstack_endpoint": True} + + result = s3.parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4) + + expected = {"endpoint": url} + if attributes: + m_config.assert_called_with(**attributes) + expected["config"] = m_config.return_value + else: + m_config.assert_not_called() + + assert result == expected + + +@pytest.mark.parametrize( + "ceph,isfakes3", + [ + (True, False), + (False, True), + (False, False), + ] +) +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_default_endpoint') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_fakes3_endpoint') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.is_fakes3') +@patch('ansible_collections.amazon.aws.plugins.module_utils.s3.parse_ceph_endpoint') +def test_get_s3_connection(m_parse_ceph_endpoint, + m_is_fakes3, + m_parse_fakes3_endpoint, + m_parse_default_endpoint, + ceph, + isfakes3): + + url = "https://my-bucket.s3.us-west-2.amazonaws.com" region = "us-east-1" - s3_url = "https://test.ceph-s3.domain-name.com:8080" + aws_connect_kwargs = {"aws_secret_key": "secret123!", "aws_access_key": "ABCDEFG"} + params = {"mode": "put", "encryption_mode": "aws:test", "dualstack": False} + sig_4 = False + + endpoint = {"endpoint": url, "config": {"s3": True, "signature": "s123"}} module = MagicMock() - module.params = dict() - # module.boto3_conn.return_value = True - - result = s3.get_s3_connection( - module=module, - aws_connect_kwargs=aws_connect, - location=region, - ceph=True, - endpoint_url=s3_url - ) - - # print("Result MagicMock: {}".format(result)) - - print("called: {}".format(module.called)) - print("method call_count: {}".format(module.boto3_conn.call_count)) - # print("call_args: {}".format(result.call_args)) - # print("mock_calls: {}".format(result.mock_calls)) - # print("call_args_list: {}".format(result.call_args_list)) - print("method_calls: {}".format(module.method_calls)) - # assert module.boto3_conn.assert_called_once() - assert module.boto3_conn.assert_called_with( - conn_type='client', - resource='s3', - region=region, - endpoint=s3_url, - use_ssl=True, - **aws_connect - ) + module.params = params + + m_is_fakes3.return_value = isfakes3 + if ceph: + m_parse_ceph_endpoint.return_value = endpoint + elif isfakes3: + m_parse_fakes3_endpoint.return_value = endpoint + else: + m_parse_default_endpoint.return_value = endpoint + + expected = {"conn_type": "client", "resource": "s3", "region": region} + expected.update(aws_connect_kwargs) + expected.update(endpoint) + + result = s3.get_s3_connection(module, aws_connect_kwargs, region, ceph, url, sig_4) + + if ceph: + m_parse_ceph_endpoint.assert_called_with(url) + m_parse_fakes3_endpoint.assert_not_called() + m_parse_default_endpoint.assert_not_called() + elif isfakes3: + m_parse_fakes3_endpoint.assert_called_with(url) + m_parse_ceph_endpoint.assert_not_called() + m_parse_default_endpoint.assert_not_called() + else: + m_parse_default_endpoint.assert_called_with( + url, + params.get("mode"), + params.get("encryption_mode"), + params.get("dualstack"), + sig_4 + ) + m_parse_ceph_endpoint.assert_not_called() + m_parse_fakes3_endpoint.assert_not_called() + + assert result == module.boto3_conn.return_value + module.boto3_conn.assert_called_with(**expected)