diff --git a/botocore/exceptions.py b/botocore/exceptions.py
index f629d597ec..90d0a3abb6 100644
--- a/botocore/exceptions.py
+++ b/botocore/exceptions.py
@@ -219,3 +219,8 @@ class IncompleteReadError(BotoCoreError):
     """HTTP response did not return expected number of bytes."""
     fmt = ('{actual_bytes} read, but total bytes '
            'expected is {expected_bytes}.')
+
+
+class InvalidExpressionError(BotoCoreError):
+    """Expression is either invalid or too complex."""
+    fmt = 'Invalid expression {expression}: Only dotted lookups are supported.'
diff --git a/botocore/paginate.py b/botocore/paginate.py
index 1a68ff5540..50a739a72d 100644
--- a/botocore/paginate.py
+++ b/botocore/paginate.py
@@ -26,6 +26,7 @@
 import jmespath
 
 from botocore.exceptions import PaginationError
+from botocore.utils import set_value_from_jmespath
 
 
 class Paginator(object):
@@ -66,6 +67,7 @@ def _get_result_keys(self, config):
         if result_key is not None:
             if not isinstance(result_key, list):
                 result_key = [result_key]
+            result_key = [jmespath.compile(rk) for rk in result_key]
             return result_key
 
     def paginate(self, endpoint, **kwargs):
@@ -145,7 +147,10 @@ def __iter__(self):
                     starting_truncation = self._handle_first_request(
                         parsed, primary_result_key, starting_truncation)
                 first_request = False
-            num_current_response = len(parsed.get(primary_result_key, []))
+            current_response = primary_result_key.search(parsed)
+            if current_response is None:
+                current_response = []
+            num_current_response = len(current_response)
             truncate_amount = 0
             if self._max_items is not None:
                 truncate_amount = (total_items + num_current_response) \
@@ -196,23 +201,33 @@ def _handle_first_request(self, parsed, primary_result_key,
         # First we need to slice into the array and only return
         # the truncated amount.
         starting_truncation = self._parse_starting_token()[1]
-        parsed[primary_result_key] = parsed[
-            primary_result_key][starting_truncation:]
+        all_data = primary_result_key.search(parsed)
+        set_value_from_jmespath(
+            parsed,
+            primary_result_key.expression,
+            all_data[starting_truncation:]
+        )
         # We also need to truncate any secondary result keys
         # because they were not truncated in the previous last
         # response.
         for token in self.result_keys:
             if token == primary_result_key:
                 continue
-            parsed[token] = []
+            set_value_from_jmespath(parsed, token.expression, [])
         return starting_truncation
 
     def _truncate_response(self, parsed, primary_result_key, truncate_amount,
                            starting_truncation, next_token):
-        original = parsed.get(primary_result_key, [])
+        original = primary_result_key.search(parsed)
+        if original is None:
+            original = []
         amount_to_keep = len(original) - truncate_amount
         truncated = original[:amount_to_keep]
-        parsed[primary_result_key] = truncated
+        set_value_from_jmespath(
+            parsed,
+            primary_result_key.expression,
+            truncated
+        )
         # The issue here is that even though we know how much we've truncated
         # we need to account for this globally including any starting
         # left truncation. For example:
@@ -246,11 +261,13 @@ def build_full_result(self):
         response = {}
         key_names = [i.result_key for i in iterators]
         for key in key_names:
-            response[key] = []
+            set_value_from_jmespath(response, key.expression, [])
         for vals in zip_longest(*iterators):
             for k, val in zip(key_names, vals):
                 if val is not None:
-                    response[k].append(val)
+                    # Every key was seeded with a list above, so search()
+                    # finds that same (possibly nested) list to append to.
+                    k.search(response).append(val)
         if self.resume_token is not None:
             response['NextToken'] = self.resume_token
         return response
@@ -283,5 +300,8 @@ def __init__(self, pages_iterator, result_key):
 
     def __iter__(self):
         for _, page in self._pages_iterator:
-            for result in page.get(self.result_key, []):
+            results = self.result_key.search(page)
+            if results is None:
+                results = []
+            for result in results:
                 yield result
diff --git a/botocore/utils.py b/botocore/utils.py
index 73058128dc..7fd2222f6d 100644
--- a/botocore/utils.py
+++ b/botocore/utils.py
@@ -11,6 +11,8 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 
+from .exceptions import InvalidExpressionError
+
 
 def normalize_url_path(path):
     if not path:
@@ -57,3 +59,47 @@ def remove_dot_segments(url):
                 output.append(url[:next_slash])
                 url = url[next_slash:]
     return ''.join(output)
+
+
+def validate_jmespath_for_set(expression):
+    # Validates a limited jmespath expression to determine if we can set a value
+    # based on it. Only works with dotted paths.
+    if not expression or expression == '.':
+        raise InvalidExpressionError(expression=expression)
+
+    for invalid in ['[', ']', '*']:
+        if invalid in expression:
+            raise InvalidExpressionError(expression=expression)
+
+
+def set_value_from_jmespath(source, expression, value, is_first=True):
+    # This takes a (limited) jmespath-like expression & can set a value based
+    # on it.
+    # Limitations:
+    # * Only handles dotted lookups
+    # * No offsets/wildcards/slices/etc.
+    if is_first:
+        validate_jmespath_for_set(expression)
+
+    bits = expression.split('.', 1)
+    current_key, remainder = bits[0], bits[1] if len(bits) > 1 else ''
+
+    if not current_key:
+        raise InvalidExpressionError(expression=expression)
+
+    if remainder:
+        if current_key not in source:
+            # We've got something in the expression that's not present in the
+            # source (new key). If there's any more bits, we'll set the key with
+            # an empty dictionary.
+            source[current_key] = {}
+
+        return set_value_from_jmespath(
+            source[current_key],
+            remainder,
+            value,
+            is_first=False
+        )
+
+    # If we're down to a single key, set it.
+    source[current_key] = value
diff --git a/tests/unit/test_paginate.py b/tests/unit/test_paginate.py
index b74a4f7fe8..c0367d61ff 100644
--- a/tests/unit/test_paginate.py
+++ b/tests/unit/test_paginate.py
@@ -31,7 +31,10 @@ def setUp(self):
         self.paginator = Paginator(self.operation)
 
     def test_result_key_available(self):
-        self.assertEqual(self.paginator.result_keys, ['Foo'])
+        self.assertEqual(
+            [rk.expression for rk in self.paginator.result_keys],
+            ['Foo']
+        )
 
     def test_no_next_token(self):
         response = {'not_the_next_token': 'foobar'}
@@ -506,11 +509,58 @@ def test_resume_with_multiple_input_keys(self):
             [mock.call(None, InMarker1='m1', InMarker2='m2'),])
 
     def test_result_key_exposed_on_paginator(self):
-        self.assertEqual(self.paginator.result_keys, ['Users', 'Groups'])
+        self.assertEqual(
+            [rk.expression for rk in self.paginator.result_keys],
+            ['Users', 'Groups']
+        )
 
     def test_result_key_exposed_on_page_iterator(self):
         pages = self.paginator.paginate(None, max_items=3)
-        self.assertEqual(pages.result_keys, ['Users', 'Groups'])
+        self.assertEqual(
+            [rk.expression for rk in pages.result_keys],
+            ['Users', 'Groups']
+        )
+
+
+class TestExpressionKeyIterators(unittest.TestCase):
+    def setUp(self):
+        self.operation = mock.Mock()
+        # This is something like what we'd see in RDS.
+        self.paginate_config = {
+            "py_input_token": "Marker",
+            "output_token": "Marker",
+            "limit_key": "MaxRecords",
+            "result_key": "EngineDefaults.Parameters"
+        }
+        self.operation.pagination = self.paginate_config
+        self.paginator = Paginator(self.operation)
+        self.responses = [
+            (None, {
+                "EngineDefaults": {"Parameters": ["One", "Two"]
+            }, "Marker": "m1"}),
+            (None, {
+                "EngineDefaults": {"Parameters": ["Three", "Four"]
+            }, "Marker": "m2"}),
+            (None, {"EngineDefaults": {"Parameters": ["Five"]}}),
+        ]
+
+    def test_result_key_iters(self):
+        self.operation.call.side_effect = self.responses
+        pages = self.paginator.paginate(None)
+        iterators = pages.result_key_iters()
+        self.assertEqual(len(iterators), 1)
+        self.assertEqual(list(iterators[0]),
+                         ['One', 'Two', 'Three', 'Four', 'Five'])
+
+    def test_build_full_result_with_single_key(self):
+        self.operation.call.side_effect = self.responses
+        pages = self.paginator.paginate(None)
+        complete = pages.build_full_result()
+        self.assertEqual(complete, {
+            'EngineDefaults': {
+                'Parameters': ['One', 'Two', 'Three', 'Four', 'Five']
+            },
+        })
 
 
 if __name__ == '__main__':
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 360f5d6512..d568c043e4 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -11,11 +11,14 @@
 # ANY KIND, either express or implied. See the License for the specific
 # language governing permissions and limitations under the License.
 
-import unittest
+from tests import unittest
 
 from botocore import xform_name
+from botocore.exceptions import InvalidExpressionError
 from botocore.utils import remove_dot_segments
 from botocore.utils import normalize_url_path
+from botocore.utils import validate_jmespath_for_set
+from botocore.utils import set_value_from_jmespath
 
 
 class TestURINormalization(unittest.TestCase):
@@ -69,5 +72,64 @@ def test_special_cases(self):
         self.assertEqual(xform_name('CreateStorediSCSIVolume', '-'),
                          'create-stored-iscsi-volume')
 
+class TestValidateJMESPathForSet(unittest.TestCase):
+    def setUp(self):
+        super(TestValidateJMESPathForSet, self).setUp()
+        self.data = {
+            'Response': {
+                'Thing': {
+                    'Id': 1,
+                    'Name': 'Thing #1',
+                }
+            },
+            'Marker': 'some-token'
+        }
+
+    def test_invalid_exp(self):
+        with self.assertRaises(InvalidExpressionError):
+            validate_jmespath_for_set('Response.*.Name')
+
+        with self.assertRaises(InvalidExpressionError):
+            validate_jmespath_for_set('Response.Things[0]')
+
+        with self.assertRaises(InvalidExpressionError):
+            validate_jmespath_for_set('')
+
+        with self.assertRaises(InvalidExpressionError):
+            validate_jmespath_for_set('.')
+
+
+class TestSetValueFromJMESPath(unittest.TestCase):
+    def setUp(self):
+        super(TestSetValueFromJMESPath, self).setUp()
+        self.data = {
+            'Response': {
+                'Thing': {
+                    'Id': 1,
+                    'Name': 'Thing #1',
+                }
+            },
+            'Marker': 'some-token'
+        }
+
+    def test_single_depth_existing(self):
+        set_value_from_jmespath(self.data, 'Marker', 'new-token')
+        self.assertEqual(self.data['Marker'], 'new-token')
+
+    def test_single_depth_new(self):
+        self.assertFalse('Limit' in self.data)
+        set_value_from_jmespath(self.data, 'Limit', 100)
+        self.assertEqual(self.data['Limit'], 100)
+
+    def test_multiple_depth_existing(self):
+        set_value_from_jmespath(self.data, 'Response.Thing.Name', 'New Name')
+        self.assertEqual(self.data['Response']['Thing']['Name'], 'New Name')
+
+    def test_multiple_depth_new(self):
+        self.assertFalse('Brand' in self.data)
+        set_value_from_jmespath(self.data, 'Brand.New', {'abc': 123})
+        self.assertEqual(self.data['Brand']['New']['abc'], 123)
+
+
 if __name__ == '__main__':
     unittest.main()