Skip to content

Commit

Permalink
unit testing
Browse files (browse the repository at this point in the history)
  • Loading branch information
abikouo committed Oct 13, 2022
1 parent 795bad0 commit 931fbf9
Show file tree
Hide file tree
Showing 3 changed files with 498 additions and 195 deletions.
3 changes: 3 additions & 0 deletions changelogs/fragments/module_utils_s3-unit-testing.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
trivial:
- module_utils.s3 - Refactor get_s3_connection into a module_utils for S3 modules and expand module_utils.s3 unit tests.
169 changes: 82 additions & 87 deletions plugins/module_utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,9 @@

from ansible.module_utils.basic import to_text
from ansible.module_utils.six.moves.urllib.parse import urlparse
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info


try:
import botocore
from botocore.client import Config
from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
pass # Handled by the calling module
Expand All @@ -28,33 +26,49 @@
import string


def s3_head_objects(client, parts, bucket, obj, versionId):
    """Yield the ``head_object`` response for each part of an object.

    ``client.head_object`` is called once per part number, from 1 through
    ``parts`` inclusive. Responses are yielded lazily, so a request is only
    issued when the consumer asks for the next item.

    :param client: object exposing ``head_object(**kwargs)``.
    :param parts: number of parts to query; nothing is yielded when < 1.
    :param bucket: value sent as ``Bucket``.
    :param obj: value sent as ``Key``.
    :param versionId: value sent as ``VersionId``; omitted when falsy.
    """
    args = {"Bucket": bucket, "Key": obj}
    if versionId:
        args["VersionId"] = versionId

    for part in range(1, parts + 1):
        # Only the part number changes between requests.
        args["PartNumber"] = part
        yield client.head_object(**args)


def calculate_checksum_with_file(client, parts, bucket, obj, versionId, filename):
    """Compute a multi-part style ETag for a local file.

    The file is read sequentially; for each part reported by
    ``s3_head_objects`` a chunk of that part's ``ContentLength`` bytes is
    hashed with MD5. The result is the MD5 of the concatenated part digests
    followed by the number of parts.

    :param client: client passed through to ``s3_head_objects``.
    :param parts: number of parts to query.
    :param bucket: bucket name passed through to ``s3_head_objects``.
    :param obj: object key passed through to ``s3_head_objects``.
    :param versionId: object version passed through to ``s3_head_objects``.
    :param filename: path of the local file, opened in binary mode.
    :returns: string of the form ``"<md5 hex>-<part count>"``, including the
        surrounding double quotes.
    """
    digests = []
    with open(filename, 'rb') as f:
        for head in s3_head_objects(client, parts, bucket, obj, versionId):
            # Each part's length determines how much of the file to hash next.
            digests.append(md5(f.read(int(head['ContentLength']))).digest())

    digest_squared = b''.join(digests)
    return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))


def calculate_checksum_with_content(client, parts, bucket, obj, versionId, content):
    """Compute a multi-part style ETag for in-memory content.

    ``content`` is sliced into consecutive chunks whose sizes are the
    ``ContentLength`` values reported by ``s3_head_objects``; each chunk is
    hashed with MD5. The result is the MD5 of the concatenated part digests
    followed by the number of parts.

    :param client: client passed through to ``s3_head_objects``.
    :param parts: number of parts to query.
    :param bucket: bucket name passed through to ``s3_head_objects``.
    :param obj: object key passed through to ``s3_head_objects``.
    :param versionId: object version passed through to ``s3_head_objects``.
    :param content: sliceable data accepted by ``md5`` (presumably bytes;
        confirm against callers).
    :returns: string of the form ``"<md5 hex>-<part count>"``, including the
        surrounding double quotes.
    """
    digests = []
    offset = 0
    for head in s3_head_objects(client, parts, bucket, obj, versionId):
        length = int(head['ContentLength'])
        digests.append(md5(content[offset:offset + length]).digest())
        offset += length

    digest_squared = b''.join(digests)
    return '"{0}-{1}"'.format(md5(digest_squared).hexdigest(), len(digests))


def calculate_etag(module, filename, etag, s3, bucket, obj, version=None):
    """Compute the ETag expected for a local file, matching the shape of ``etag``.

    When ``etag`` contains a ``-`` it is treated as a multi-part ETag: the
    part count is the text after the ``-`` (with the first and last
    characters of ``etag`` stripped), and the checksum is delegated to
    ``calculate_checksum_with_file``. Otherwise the file's plain MD5 is used.

    :param module: module object providing ``md5`` and ``fail_json_aws``.
    :param filename: path of the local file to hash.
    :param etag: ETag string whose shape selects the algorithm.
    :param s3: client used to issue ``head_object`` requests for each part.
    :param bucket: bucket name.
    :param obj: object key.
    :param version: optional object version.
    :returns: quoted ETag string, or ``None`` when MD5 is unavailable.
    """
    if not HAS_MD5:
        return None

    if '-' in etag:
        # Multi-part ETag; a hash of the hashes of each part.
        parts = int(etag[1:-1].split('-')[1])
        try:
            return calculate_checksum_with_file(s3, parts, bucket, obj, version, filename)
        except (BotoCoreError, ClientError) as e:
            # Presumably fail_json_aws terminates the module run - confirm.
            module.fail_json_aws(e, msg="Failed to get head object")
    else:  # Compute the MD5 sum normally
        return '"{0}"'.format(module.md5(filename))

Expand All @@ -66,28 +80,10 @@ def calculate_etag_content(module, content, etag, s3, bucket, obj, version=None)
if '-' in etag:
# Multi-part ETag; a hash of the hashes of each part.
parts = int(etag[1:-1].split('-')[1])
digests = []
offset = 0

s3_kwargs = dict(
Bucket=bucket,
Key=obj,
)
if version:
s3_kwargs['VersionId'] = version

for part_num in range(1, parts + 1):
s3_kwargs['PartNumber'] = part_num
try:
head = s3.head_object(**s3_kwargs)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
length = int(head['ContentLength'])
digests.append(md5(content[offset:offset + length]))
offset += length

digest_squared = md5(b''.join(m.digest() for m in digests))
return '"{0}-{1}"'.format(digest_squared.hexdigest(), len(digests))
try:
return calculate_checksum_with_content(s3, parts, bucket, obj, version, content)
except (BotoCoreError, ClientError) as e:
module.fail_json_aws(e, msg="Failed to get head object")
else: # Compute the MD5 sum normally
return '"{0}"'.format(md5(content).hexdigest())

Expand All @@ -108,12 +104,42 @@ def validate_bucket_name(module, name):
return True


# To get S3 connection, in case of dealing with ceph, dualstack, etc.
def is_fakes3(url):
    """Return True if ``url`` has the scheme ``fakes3://`` or ``fakes3s://``.

    Spots the special case of a fakes3 endpoint.

    :param url: endpoint URL to inspect; ``None`` is accepted.
    :returns: ``True`` for the ``fakes3`` and ``fakes3s`` schemes, ``False``
        otherwise (including when ``url`` is ``None``).
    """
    result = False
    if url is not None:
        result = urlparse(url).scheme in ('fakes3', 'fakes3s')
    return result


def parse_fakes3_endpoint(url):
    """Translate a ``fakes3://`` or ``fakes3s://`` URL into connection settings.

    ``fakes3s`` maps to ``https`` with default port 443; any other scheme
    maps to ``http`` with default port 80. An explicit port in ``url``
    overrides the default.

    :param url: endpoint URL using the ``fakes3`` or ``fakes3s`` scheme.
    :returns: dict with ``endpoint`` (rebuilt ``protocol://host:port`` URL)
        and ``use_ssl`` (bool).
    """
    fakes3 = urlparse(url)
    use_ssl = fakes3.scheme == 'fakes3s'
    protocol = "https" if use_ssl else "http"
    port = fakes3.port or (443 if use_ssl else 80)

    hostname = fakes3.hostname
    # urlparse strips the brackets from IPv6 literals; they must be restored
    # or the host/port separator becomes ambiguous in the rebuilt URL.
    if hostname and ':' in hostname:
        hostname = f"[{hostname}]"

    endpoint_url = f"{protocol}://{hostname}:{port}"
    return {"endpoint": endpoint_url, "use_ssl": use_ssl}


def parse_ceph_endpoint(url):
    """Build connection settings for a Ceph endpoint.

    :param url: endpoint URL, returned unchanged as ``endpoint``.
    :returns: dict with ``endpoint`` and ``use_ssl``; ``use_ssl`` is ``True``
        only when the URL scheme is ``https``.
    """
    use_ssl = urlparse(url).scheme == 'https'
    return {"endpoint": url, "use_ssl": use_ssl}


def parse_default_endpoint(url, mode, encryption_mode, dualstack, sig_4):
    """Build connection settings for a regular (non-Ceph, non-fakes3) endpoint.

    A ``Config`` object is attached only when at least one option applies:

    * ``signature_version="s3v4"`` when ``mode`` is ``get``/``getstr`` and
      ``sig_4`` is truthy, or when ``mode`` is ``put`` and
      ``encryption_mode`` is ``aws:kms``;
    * ``s3={"use_dualstack_endpoint": True}`` when ``dualstack`` is truthy.

    :param url: endpoint URL, returned unchanged as ``endpoint``.
    :param mode: operation mode string.
    :param encryption_mode: encryption mode string.
    :param dualstack: whether to request the dual-stack endpoint.
    :param sig_4: whether signature version 4 was requested.
    :returns: dict with ``endpoint`` and, when any option applies, ``config``.
    """
    result = {"endpoint": url}
    config = {}
    if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"):
        config["signature_version"] = "s3v4"
    if dualstack:
        config["s3"] = {"use_dualstack_endpoint": True}
    if config:
        result["config"] = Config(**config)
    return result


Expand All @@ -122,48 +148,17 @@ def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url,
conn_type='client',
resource='s3',
region=location,
endpoint=endpoint_url,
**aws_connect_kwargs
)
if ceph: # TODO - test this
ceph = urlparse(endpoint_url)
use_ssl = bool(ceph.scheme == 'https')
params.update(
dict(
use_ssl=use_ssl
)
)
if ceph:
endpoint_p = parse_ceph_endpoint(endpoint_url)
elif is_fakes3(endpoint_url):
fakes3 = urlparse(endpoint_url)
protocol = "http"
port = fakes3.port or 80
if fakes3.scheme == 'fakes3s':
protocol = "https"
port = fakes3.port or 443
endpoint_url = f"{protocol}://{fakes3.hostname}:{to_text(port)}"
use_ssl = bool(fakes3.scheme == 'fakes3s')
params.update(
dict(
endpoint=endpoint_url,
use_ssl=use_ssl,
)
)
endpoint_p = parse_fakes3_endpoint(endpoint_url)
else:
mode = module.params.get("mode")
encryption_mode = module.params.get("encryption_mode")
config = None
if (mode in ('get', 'getstr') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"):
config = botocore.client.Config(signature_version='s3v4')
if module.params.get("dualstack"):
use_dualstack = dict(use_dualstack_endpoint=True)
if config is not None:
config.merge(botocore.client.Config(s3=use_dualstack))
else:
config = botocore.client.Config(s3=use_dualstack)
if config:
params.update(
dict(
config=config
)
)
dualstack = module.params.get("dualstack")
endpoint_p = parse_default_endpoint(endpoint_url, mode, encryption_mode, dualstack, sig_4)

params.update(endpoint_p)
return module.boto3_conn(**params)
Loading

0 comments on commit 931fbf9

Please sign in to comment.