From c778846e073de45943377ad6250e8fb5bec52bb4 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Wed, 13 Mar 2024 19:36:03 -0700 Subject: [PATCH 1/2] feat: Add AsyncHTTPIterator to page_iterators --- google/api_core/page_iterator_async.py | 166 +++++++++++++- tests/asyncio/test_page_iterator_async.py | 268 +++++++++++++++++++++- 2 files changed, 432 insertions(+), 2 deletions(-) diff --git a/google/api_core/page_iterator_async.py b/google/api_core/page_iterator_async.py index c0725758..dede8412 100644 --- a/google/api_core/page_iterator_async.py +++ b/google/api_core/page_iterator_async.py @@ -69,7 +69,7 @@ import abc -from google.api_core.page_iterator import Page +from google.api_core.page_iterator import Page, _do_nothing_page_start def _item_to_value_identity(iterator, item): @@ -199,6 +199,170 @@ async def _next_page(self): raise NotImplementedError +class AsyncHTTPIterator(AsyncIterator): + """A generic class for iterating through HTTP/JSON API list responses for asynchronous I/O. + + To make an iterator work, you'll need to provide a way to convert a JSON + item returned from the API into the object of your choice (via + ``item_to_value``). You also may need to specify a custom ``items_key`` so + that a given response (containing a page of results) can be parsed into an + iterable page of the actual objects you want. + + Args: + client (google.cloud.client.Client): The API client. + api_request (Callable): The function to use to make API requests. + Generally, this will be a coroutine of + :meth:`google.cloud._http.JSONConnection.api_request`. + path (str): The method path to query for the list of items. + item_to_value (Callable[google.api_core.page_iterator.Iterator, Any]): + Callable to convert an item from the type in the JSON response into + a native object. Will be called with the iterator and a single + item. + items_key (str): The key in the API response where the list of items + can be found. 
+ page_token (str): A token identifying a page in a result set to start + fetching results from. + page_size (int): The maximum number of results to fetch per page + max_results (int): The maximum number of results to fetch + extra_params (dict): Extra query string parameters for the + API call. + page_start (Callable[ + google.api_core.page_iterator.Iterator, + google.api_core.page_iterator.Page, dict]): Callable to provide + any special behavior after a new page has been created. Assumed + signature takes the :class:`.Iterator` that started the page, + the :class:`.Page` that was started and the dictionary containing + the page response. + next_token (str): The name of the field used in the response for page + tokens. + + .. autoattribute:: pages + """ + + _DEFAULT_ITEMS_KEY = "items" + _PAGE_TOKEN = "pageToken" + _MAX_RESULTS = "maxResults" + _NEXT_TOKEN = "nextPageToken" + _RESERVED_PARAMS = frozenset([_PAGE_TOKEN]) + _HTTP_METHOD = "GET" + + def __init__( + self, + client, + api_request, + path, + item_to_value, + items_key=_DEFAULT_ITEMS_KEY, + page_token=None, + page_size=None, + max_results=None, + extra_params=None, + page_start=_do_nothing_page_start, + next_token=_NEXT_TOKEN, + ): + super().__init__( + client, item_to_value, page_token=page_token, max_results=max_results + ) + self.api_request = api_request + self.path = path + self._items_key = items_key + self.extra_params = extra_params + self._page_size = page_size + self._page_start = page_start + self._next_token = next_token + # Verify inputs / provide defaults. + if self.extra_params is None: + self.extra_params = {} + self._verify_params() + + def _verify_params(self): + """Verifies the parameters don't use any reserved parameter. + + Raises: + ValueError: If a reserved parameter is used. 
+ """ + reserved_in_use = self._RESERVED_PARAMS.intersection(self.extra_params) + if reserved_in_use: + raise ValueError("Using a reserved parameter", reserved_in_use) + + async def _next_page(self): + """Get the next page in the iterator. + + Returns: + Optional[Page]: The next page in the iterator or :data:`None` if + there are no pages left. + """ + if self._has_next_page(): + response = await self._get_next_page_response() + items = response.get(self._items_key, ()) + page = Page(self, items, self.item_to_value, raw_page=response) + self._page_start(self, page, response) + self.next_page_token = response.get(self._next_token) + return page + else: + return None + + def _has_next_page(self): + """Determines whether or not there are more pages with results. + + Returns: + bool: Whether the iterator has more pages. + """ + if self.page_number == 0: + return True + + if self.max_results is not None: + if self.num_results >= self.max_results: + return False + + return self.next_page_token is not None + + def _get_query_params(self): + """Getter for query parameters for the next request. + + Returns: + dict: A dictionary of query parameters. + """ + result = {} + if self.next_page_token is not None: + result[self._PAGE_TOKEN] = self.next_page_token + + page_size = None + if self.max_results is not None: + page_size = self.max_results - self.num_results + if self._page_size is not None: + page_size = min(page_size, self._page_size) + elif self._page_size is not None: + page_size = self._page_size + + if page_size is not None: + result[self._MAX_RESULTS] = page_size + + result.update(self.extra_params) + return result + + async def _get_next_page_response(self): + """Requests the next page from the path provided. + + Returns: + dict: The parsed JSON response of the next page's contents. + + Raises: + ValueError: If the HTTP method is not ``GET`` or ``POST``. 
+ """ + params = self._get_query_params() + if self._HTTP_METHOD == "GET": + return await self.api_request( + method=self._HTTP_METHOD, path=self.path, query_params=params + ) + elif self._HTTP_METHOD == "POST": + return await self.api_request( + method=self._HTTP_METHOD, path=self.path, data=params + ) + else: + raise ValueError("Unexpected HTTP method", self._HTTP_METHOD) + + class AsyncGRPCIterator(AsyncIterator): """A generic class for iterating through gRPC list responses. diff --git a/tests/asyncio/test_page_iterator_async.py b/tests/asyncio/test_page_iterator_async.py index 75f9e1cf..efb68fe0 100644 --- a/tests/asyncio/test_page_iterator_async.py +++ b/tests/asyncio/test_page_iterator_async.py @@ -16,8 +16,9 @@ import mock import pytest +import math -from google.api_core import page_iterator_async +from google.api_core import page_iterator_async, page_iterator class PageAsyncIteratorImpl(page_iterator_async.AsyncIterator): @@ -189,6 +190,271 @@ def test___aiter___restart_after_page(self): iterator.__aiter__() +class TestAsyncHTTPIterator(object): + def test_constructor(self): + client = mock.sentinel.client + path = "/foo" + iterator = page_iterator_async.AsyncHTTPIterator( + client, mock.sentinel.api_request, path, mock.sentinel.item_to_value + ) + + assert not iterator._started + assert iterator.client is client + assert iterator.path == path + assert iterator.item_to_value is mock.sentinel.item_to_value + assert iterator._items_key == "items" + assert iterator.max_results is None + assert iterator.extra_params == {} + assert iterator._page_start == page_iterator._do_nothing_page_start + # Changing attributes. 
+ assert iterator.page_number == 0 + assert iterator.next_page_token is None + assert iterator.num_results == 0 + assert iterator._page_size is None + + def test_constructor_w_extra_param_collision(self): + extra_params = {"pageToken": "val"} + + with pytest.raises(ValueError): + page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + extra_params=extra_params, + ) + + @pytest.mark.asyncio + async def test_iterate(self): + path = "/foo" + item1 = {"name": "1"} + item2 = {"name": "2"} + api_request = mock.AsyncMock(return_value={"items": [item1, item2]}) + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + ) + + assert iterator.num_results == 0 + + items = [] + async for item in iterator: + items.append(item) + + assert items == [{"name": "1"}, {"name": "2"}] + assert iterator.num_results == 2 + + api_request.assert_called_once_with(method="GET", path=path, query_params={}) + + def test__has_next_page_new(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + # The iterator should *always* indicate that it has a next page + # when created so that it can fetch the initial page. + assert iterator._has_next_page() + + def test__has_next_page_without_token(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + iterator.page_number = 1 + + # The iterator should not indicate that it has a new page if the + # initial page has been requested and there's no page token. 
+ assert not iterator._has_next_page() + + def test__has_next_page_w_number_w_token(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + iterator.page_number = 1 + iterator.next_page_token = mock.sentinel.token + + # The iterator should indicate that it has a new page if the + # initial page has been requested and there is a page token. + assert iterator._has_next_page() + + def test__has_next_page_w_max_results_not_done(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=3, + page_token=mock.sentinel.token, + ) + + iterator.page_number = 1 + + # The iterator should indicate that it has a new page if there + # is a page token and it has consumed fewer than max_results. + assert iterator.num_results < iterator.max_results + assert iterator._has_next_page() + + def test__has_next_page_w_max_results_done(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=3, + page_token=mock.sentinel.token, + ) + + iterator.page_number = 1 + iterator.num_results = 3 + + # The iterator should not indicate that it has a new page if + # it has consumed at least max_results. 
+ assert iterator.num_results == iterator.max_results + assert not iterator._has_next_page() + + def test__get_query_params_no_token(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + + assert iterator._get_query_params() == {} + + def test__get_query_params_w_token(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + iterator.next_page_token = "token" + + assert iterator._get_query_params() == {"pageToken": iterator.next_page_token} + + def test__get_query_params_w_max_results(self): + max_results = 3 + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + max_results=max_results, + ) + + iterator.num_results = 1 + local_max = max_results - iterator.num_results + + assert iterator._get_query_params() == {"maxResults": local_max} + + def test__get_query_params_extra_params(self): + extra_params = {"key": "val"} + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + extra_params=extra_params, + ) + + assert iterator._get_query_params() == extra_params + + @pytest.mark.asyncio + async def test__get_next_page_response_with_post(self): + path = "/foo" + page_response = {"items": ["one", "two"]} + api_request = mock.AsyncMock(return_value=page_response) + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + ) + iterator._HTTP_METHOD = "POST" + + response = await iterator._get_next_page_response() + + assert response == page_response + + api_request.assert_called_once_with(method="POST", path=path, data={}) + + 
@pytest.mark.asyncio + async def test__get_next_page_bad_http_method(self): + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + mock.sentinel.api_request, + mock.sentinel.path, + mock.sentinel.item_to_value, + ) + iterator._HTTP_METHOD = "NOT-A-VERB" + + with pytest.raises(ValueError): + await iterator._get_next_page_response() + + @pytest.mark.parametrize( + "page_size,max_results,pages", + [(3, None, False), (3, 8, False), (3, None, True), (3, 8, True)], + ) + @pytest.mark.asyncio + async def test_page_size_items(self, page_size, max_results, pages): + path = "/foo" + NITEMS = 10 + + n = [0] # blast you python 2! + + async def api_request(*args, **kw): + assert not args + query_params = dict( + maxResults=( + page_size + if max_results is None + else min(page_size, max_results - n[0]) + ) + ) + if n[0]: + query_params.update(pageToken="test") + assert kw == {"method": "GET", "path": "/foo", "query_params": query_params} + n_items = min(kw["query_params"]["maxResults"], NITEMS - n[0]) + items = [dict(name=str(i + n[0])) for i in range(n_items)] + n[0] += n_items + result = dict(items=items) + if n[0] < NITEMS: + result.update(nextPageToken="test") + return result + + iterator = page_iterator_async.AsyncHTTPIterator( + mock.sentinel.client, + api_request, + path=path, + item_to_value=page_iterator._item_to_value_identity, + page_size=page_size, + max_results=max_results, + ) + + assert iterator.num_results == 0 + + n_results = max_results if max_results is not None else NITEMS + + items = [] + async for item in iterator: + items.append(item) + + assert iterator.num_results == n_results + + class TestAsyncGRPCIterator(object): def test_constructor(self): client = mock.sentinel.client From 45d6ae6dbfa64ad8533577fc98a2cfc05eae4237 Mon Sep 17 00:00:00 2001 From: kiraksi Date: Thu, 14 Mar 2024 11:16:25 -0700 Subject: [PATCH 2/2] lint --- tests/asyncio/test_page_iterator_async.py | 1 - 1 file changed, 1 deletion(-) diff --git 
a/tests/asyncio/test_page_iterator_async.py b/tests/asyncio/test_page_iterator_async.py index efb68fe0..3d0bcb8d 100644 --- a/tests/asyncio/test_page_iterator_async.py +++ b/tests/asyncio/test_page_iterator_async.py @@ -16,7 +16,6 @@ import mock import pytest -import math from google.api_core import page_iterator_async, page_iterator