Skip to content

Commit

Permalink
move get_s3_connection, reduce complexity and increase coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
abikouo committed Oct 13, 2022
1 parent 7d2528a commit 795bad0
Show file tree
Hide file tree
Showing 6 changed files with 131 additions and 124 deletions.
3 changes: 3 additions & 0 deletions plugins/module_utils/modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,9 @@ def resource(self, service):
return boto3_conn(self, conn_type='resource', resource=service,
region=region, endpoint=endpoint_url, **aws_connect_kwargs)

def boto3_conn(self, **kwargs):
    # Thin wrapper exposing the module-level boto3_conn() helper as a method
    # on the module object, so shared utilities (e.g. module_utils.s3) can
    # build boto3 clients/resources via the module without importing the
    # helper themselves. All keyword arguments are passed through unchanged.
    return boto3_conn(self, **kwargs)

@property
def region(self):
return get_aws_region(self, True)
Expand Down
67 changes: 67 additions & 0 deletions plugins/module_utils/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

from ansible.module_utils.basic import to_text
from ansible.module_utils.six.moves.urllib.parse import urlparse
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info


try:
import botocore
from botocore.exceptions import BotoCoreError, ClientError
except ImportError:
pass # Handled by the calling module
Expand Down Expand Up @@ -100,3 +106,64 @@ def validate_bucket_name(module, name):
if name[-1] not in string.ascii_lowercase + string.digits:
module.fail_json(msg='bucket names must begin and end with a letter or number')
return True


# Helpers for building an S3 connection that also copes with ceph, fakes3, dualstack, etc.
def is_fakes3(endpoint_url):
    """ Return True if endpoint_url has scheme fakes3:// """
    # No endpoint at all cannot be a fakes3 endpoint.
    if endpoint_url is None:
        return False
    return urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s')


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
    """Return a boto3 S3 client, handling Ceph RGW, fakes3 and plain AWS endpoints.

    :param module: AnsibleAWSModule-like object exposing ``params`` and ``boto3_conn()``
    :param aws_connect_kwargs: credential/connection kwargs forwarded to ``boto3_conn``
    :param location: AWS region for the client
    :param ceph: True when the endpoint is a Ceph RGW gateway
    :param endpoint_url: explicit endpoint URL; may use the pseudo schemes
        ``fakes3://`` / ``fakes3s://``
    :param sig_4: force Signature Version 4 for get/getstr/geturl modes
    """
    params = dict(
        conn_type='client',
        resource='s3',
        region=location,
        endpoint=endpoint_url,
        **aws_connect_kwargs
    )
    scheme = urlparse(endpoint_url).scheme if endpoint_url else ''
    if ceph:  # TODO - test this
        # Ceph RGW: derive SSL usage from the scheme of the supplied endpoint.
        params.update(dict(use_ssl=scheme == 'https'))
    elif scheme in ('fakes3', 'fakes3s'):
        # Map the pseudo schemes: fakes3:// -> http, fakes3s:// -> https,
        # with the conventional default port when none is given.
        fakes3 = urlparse(endpoint_url)
        secure = fakes3.scheme == 'fakes3s'
        protocol = "https" if secure else "http"
        port = fakes3.port or (443 if secure else 80)
        params.update(
            dict(
                endpoint=f"{protocol}://{fakes3.hostname}:{port}",
                use_ssl=secure,
            )
        )
    else:
        mode = module.params.get("mode")
        encryption_mode = module.params.get("encryption_mode")
        config = None
        # SigV4 is required for KMS-encrypted puts and can be forced for reads.
        # 'geturl' is included to match the historical s3_object behaviour.
        if (mode in ('get', 'getstr', 'geturl') and sig_4) or (mode == "put" and encryption_mode == "aws:kms"):
            config = botocore.client.Config(signature_version='s3v4')
        if module.params.get("dualstack"):
            use_dualstack = dict(use_dualstack_endpoint=True)
            if config is not None:
                # Config.merge() returns a NEW Config object; it does not
                # mutate in place, so the result must be assigned back or the
                # dualstack setting is silently dropped.
                config = config.merge(botocore.client.Config(s3=use_dualstack))
            else:
                config = botocore.client.Config(s3=use_dualstack)
        if config:
            params.update(dict(config=config))
    return module.boto3_conn(**params)
46 changes: 1 addition & 45 deletions plugins/modules/s3_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,15 +401,13 @@
except ImportError:
pass # Handled by AnsibleAWSModule

from ansible.module_utils.basic import to_text
from ansible.module_utils.basic import to_native
from ansible.module_utils.six.moves.urllib.parse import urlparse

from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_message
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import ansible_dict_to_boto3_tag_list
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
Expand Down Expand Up @@ -835,48 +833,6 @@ def copy_object_to_bucket(module, s3, bucket, obj, encrypt, metadata, validate,
module.fail_json_aws(e, msg="Failed while copying object %s from bucket %s." % (obj, module.params['copy_src'].get('Bucket')))


def is_fakes3(endpoint_url):
    """ Return True if endpoint_url has scheme fakes3:// """
    # An absent endpoint is never a fakes3 endpoint; otherwise the decision
    # rests solely on the URL scheme.
    return bool(endpoint_url is not None
                and urlparse(endpoint_url).scheme in ('fakes3', 'fakes3s'))


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
    """Build a boto3 S3 client connection for plain AWS, Ceph RGW or fakes3 endpoints."""
    if ceph:  # TODO - test this
        parsed = urlparse(endpoint_url)
        params = dict(module=module, conn_type='client', resource='s3',
                      use_ssl=parsed.scheme == 'https', region=location,
                      endpoint=endpoint_url, **aws_connect_kwargs)
    elif is_fakes3(endpoint_url):
        parsed = urlparse(endpoint_url)
        secure = parsed.scheme == 'fakes3s'
        # fakes3s:// means https (default port 443); fakes3:// means http (80).
        endpoint = "%s://%s:%s" % ("https" if secure else "http",
                                   parsed.hostname,
                                   to_text(parsed.port or (443 if secure else 80)))
        params = dict(module=module, conn_type='client', resource='s3',
                      region=location, endpoint=endpoint, use_ssl=secure,
                      **aws_connect_kwargs)
    else:
        params = dict(module=module, conn_type='client', resource='s3',
                      region=location, endpoint=endpoint_url, **aws_connect_kwargs)
        # SigV4 is mandatory for KMS-encrypted puts and may be forced for reads.
        needs_sig_v4 = ((module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms')
                        or (module.params['mode'] in ('get', 'getstr', 'geturl') and sig_4))
        if needs_sig_v4:
            params['config'] = botocore.client.Config(signature_version='s3v4')
        if module.params['dualstack']:
            dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
            # merge() returns a new Config, so re-assign the merged result.
            params['config'] = params['config'].merge(dualconf) if 'config' in params else dualconf
    return boto3_conn(**params)


def get_current_object_tags_dict(s3, bucket, obj, version=None):
try:
if version:
Expand Down
50 changes: 2 additions & 48 deletions plugins/modules/s3_object_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,16 +436,13 @@
except ImportError:
pass # Handled by AnsibleAWSModule

from ansible.module_utils.basic import to_text
from ansible.module_utils.six.moves.urllib.parse import urlparse

from ansible_collections.amazon.aws.plugins.module_utils.core import AnsibleAWSModule
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import AWSRetry
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import camel_dict_to_snake_dict
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_tag_list_to_ansible_dict
from ansible_collections.amazon.aws.plugins.module_utils.core import is_boto3_error_code
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import get_aws_connection_info
from ansible_collections.amazon.aws.plugins.module_utils.ec2 import boto3_conn
from ansible_collections.amazon.aws.plugins.module_utils.s3 import get_s3_connection


def describe_s3_object_acl(connection, bucket_name, object_name):
Expand Down Expand Up @@ -666,49 +663,6 @@ def object_check(connection, module, bucket_name, object_name):
module.fail_json_aws(e, msg="The object %s does not exist or is missing access permissions." % object_name)


# Helpers for building an S3 connection that also copes with ceph, fakes3, dualstack, etc.
def is_fakes3(endpoint_url):
    """ Return True if endpoint_url has scheme fakes3:// """
    if endpoint_url is None:
        # No endpoint supplied, so it cannot be a fakes3 one.
        return False
    scheme = urlparse(endpoint_url).scheme
    return scheme == 'fakes3' or scheme == 'fakes3s'


def get_s3_connection(module, aws_connect_kwargs, location, ceph, endpoint_url, sig_4=False):
    """Return a boto3 S3 client, covering Ceph RGW, fakes3 and regular AWS endpoints."""
    if ceph:  # TODO - test this
        ceph_url = urlparse(endpoint_url)
        return boto3_conn(module=module, conn_type='client', resource='s3',
                          use_ssl=ceph_url.scheme == 'https', region=location,
                          endpoint=endpoint_url, **aws_connect_kwargs)
    if is_fakes3(endpoint_url):
        fakes3 = urlparse(endpoint_url)
        # fakes3s:// maps to https (default 443); fakes3:// maps to http (80).
        if fakes3.scheme == 'fakes3s':
            protocol, port = "https", fakes3.port or 443
        else:
            protocol, port = "http", fakes3.port or 80
        return boto3_conn(module=module, conn_type='client', resource='s3', region=location,
                          endpoint="%s://%s:%s" % (protocol, fakes3.hostname, to_text(port)),
                          use_ssl=fakes3.scheme == 'fakes3s', **aws_connect_kwargs)
    params = dict(module=module, conn_type='client', resource='s3',
                  region=location, endpoint=endpoint_url, **aws_connect_kwargs)
    # KMS-encrypted puts require SigV4; reads may request it explicitly.
    if module.params['mode'] == 'put' and module.params['encryption_mode'] == 'aws:kms':
        params['config'] = botocore.client.Config(signature_version='s3v4')
    elif module.params['mode'] in ('get', 'getstr') and sig_4:
        params['config'] = botocore.client.Config(signature_version='s3v4')
    if module.params['dualstack']:
        dualconf = botocore.client.Config(s3={'use_dualstack_endpoint': True})
        if 'config' in params:
            # merge() returns a new Config object; keep the merged result.
            params['config'] = params['config'].merge(dualconf)
        else:
            params['config'] = dualconf
    return boto3_conn(**params)


def main():

argument_spec = dict(
Expand All @@ -726,7 +680,7 @@ def main():
),
bucket_name=dict(required=True, type='str'),
object_name=dict(type='str'),
dualstack=dict(default='no', type='bool'),
dualstack=dict(default=False, type='bool'),
ceph=dict(default=False, type='bool', aliases=['rgw']),
)

Expand Down
60 changes: 58 additions & 2 deletions tests/unit/module_utils/test_s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,14 @@
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

from __future__ import (absolute_import, division, print_function)
from operator import mod

__metaclass__ = type

from ansible_collections.amazon.aws.tests.unit.compat.mock import MagicMock
from ansible_collections.amazon.aws.plugins.module_utils import s3
from ansible.module_utils.basic import AnsibleModule

import pytest
from unittest.mock import MagicMock, Mock, patch, ANY, call


class FakeAnsibleModule(AnsibleModule):
Expand Down Expand Up @@ -84,3 +84,59 @@ def test_validate_bucket_name():
assert module.fail_json.called
s3.validate_bucket_name(module, "my")
assert module.fail_json.called


def test_is_fakes3_with_none_arg():
    # A missing endpoint URL must never be classified as fakes3.
    assert s3.is_fakes3(None) is False


def test_is_fakes3_with_valid_protocol():
    # A regular https:// endpoint is a perfectly valid S3 URL but is NOT a
    # fakes3 endpoint: is_fakes3() only recognises the fakes3/fakes3s schemes,
    # so it must return False here. The original assertion asserted truthiness
    # and therefore always failed.
    assert not s3.is_fakes3("https://test-s3.amazon.com")


def test_is_fakes3_with_fakes3_protocol():
    # The fakes3:// pseudo scheme must be recognised.
    assert s3.is_fakes3("fakes3://test-s3.amazon.com") is True


def test_is_fakes3_with_fakes3s_protocol():
    # The TLS variant fakes3s:// must be recognised as well.
    assert s3.is_fakes3("fakes3s://test-s3.amazon.com") is True


def test_get_s3_connection_ceph_with_https():
    """For a Ceph endpoint, the client must be built with use_ssl derived
    from the endpoint's URL scheme (https -> True)."""
    aws_connect = dict(
        aws_access_key_id="ACCESS012345",
        aws_secret_access_key="SECRET123",
    )
    region = "us-east-1"
    s3_url = "https://test.ceph-s3.domain-name.com:8080"

    module = MagicMock()
    module.params = dict()

    s3.get_s3_connection(
        module=module,
        aws_connect_kwargs=aws_connect,
        location=region,
        ceph=True,
        endpoint_url=s3_url,
    )

    # NOTE: assert_called_with() raises AssertionError on mismatch and returns
    # None on success, so it must NOT be wrapped in a bare `assert` — the
    # original `assert mock.assert_called_with(...)` evaluated `assert None`
    # and failed unconditionally.
    module.boto3_conn.assert_called_with(
        conn_type='client',
        resource='s3',
        region=region,
        endpoint=s3_url,
        use_ssl=True,
        **aws_connect
    )
29 changes: 0 additions & 29 deletions tests/unit/plugins/modules/test_s3_object.py

This file was deleted.

0 comments on commit 795bad0

Please sign in to comment.