-
Notifications
You must be signed in to change notification settings - Fork 2.9k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add facet results to search page iterator (#10182)
* refactor search paging * add facet results to search page iterator * py2 syntax * use method for facet results * pylint * use facet method on item paged as well
- Loading branch information
Showing
11 changed files
with
483 additions
and
101 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,106 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
from typing import TYPE_CHECKING | ||
|
||
import base64 | ||
import itertools | ||
import json | ||
|
||
from azure.core.paging import ItemPaged, PageIterator, ReturnType | ||
from ._generated.models import SearchRequest | ||
|
||
if TYPE_CHECKING: | ||
# pylint:disable=unused-import,ungrouped-imports | ||
from typing import Any, Union | ||
|
||
|
||
def convert_search_result(result): | ||
ret = result.additional_properties | ||
ret["@search.score"] = result.score | ||
ret["@search.highlights"] = result.highlights | ||
return ret | ||
|
||
|
||
def pack_continuation_token(response): | ||
if response.next_page_parameters is not None: | ||
return base64.b64encode( | ||
json.dumps( | ||
[response.next_link, response.next_page_parameters.serialize()] | ||
).encode("utf-8") | ||
) | ||
return None | ||
|
||
|
||
def unpack_continuation_token(token): | ||
next_link, next_page_parameters = json.loads(base64.b64decode(token)) | ||
next_page_request = SearchRequest.deserialize(next_page_parameters) | ||
return next_link, next_page_request | ||
|
||
|
||
class SearchItemPaged(ItemPaged[ReturnType]): | ||
def __init__(self, *args, **kwargs): | ||
super(SearchItemPaged, self).__init__(*args, **kwargs) | ||
self._first_page_iterator_instance = None | ||
|
||
def __next__(self): | ||
# type: () -> ReturnType | ||
if self._page_iterator is None: | ||
first_iterator = self._first_iterator_instance() | ||
self._page_iterator = itertools.chain.from_iterable(first_iterator) | ||
return next(self._page_iterator) | ||
|
||
def _first_iterator_instance(self): | ||
if self._first_page_iterator_instance is None: | ||
self._first_page_iterator_instance = self.by_page() | ||
return self._first_page_iterator_instance | ||
|
||
def get_facets(self): | ||
# type: () -> Union[dict, None] | ||
"""Return any facet results if faceting was requested. | ||
""" | ||
return self._first_iterator_instance().get_facets() | ||
|
||
|
||
class SearchPageIterator(PageIterator): | ||
def __init__(self, client, initial_query, kwargs, continuation_token=None): | ||
super(SearchPageIterator, self).__init__( | ||
get_next=self._get_next_cb, | ||
extract_data=self._extract_data_cb, | ||
continuation_token=continuation_token, | ||
) | ||
self._client = client | ||
self._initial_query = initial_query | ||
self._kwargs = kwargs | ||
self._facets = None | ||
|
||
def _get_next_cb(self, continuation_token): | ||
if continuation_token is None: | ||
return self._client.documents.search_post( | ||
search_request=self._initial_query.request, **self._kwargs | ||
) | ||
|
||
_next_link, next_page_request = unpack_continuation_token(continuation_token) | ||
|
||
return self._client.documents.search_post(search_request=next_page_request) | ||
|
||
def _extract_data_cb(self, response): # pylint:disable=no-self-use | ||
continuation_token = pack_continuation_token(response) | ||
facets = response.facets | ||
if facets is not None: | ||
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()} | ||
|
||
results = [convert_search_result(r) for r in response.results] | ||
|
||
return continuation_token, results | ||
|
||
def get_facets(self): | ||
if self._current_page is None: | ||
self._response = self._get_next(self.continuation_token) | ||
self.continuation_token, self._current_page = self._extract_data( | ||
self._response | ||
) | ||
return self._facets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
90 changes: 90 additions & 0 deletions
90
sdk/search/azure-search/azure/search/_index/aio/_paging.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
# ------------------------------------------------------------------------- | ||
# Copyright (c) Microsoft Corporation. All rights reserved. | ||
# Licensed under the MIT License. See License.txt in the project root for | ||
# license information. | ||
# -------------------------------------------------------------------------- | ||
from typing import Union | ||
|
||
from azure.core.async_paging import AsyncItemPaged, AsyncPageIterator, ReturnType | ||
from .._generated.models import SearchRequest | ||
from .._paging import ( | ||
convert_search_result, | ||
pack_continuation_token, | ||
unpack_continuation_token, | ||
) | ||
|
||
|
||
class AsyncSearchItemPaged(AsyncItemPaged[ReturnType]): | ||
def __init__(self, *args, **kwargs): | ||
super(AsyncSearchItemPaged, self).__init__(*args, **kwargs) | ||
self._first_page_iterator_instance = None | ||
|
||
async def __anext__(self) -> ReturnType: | ||
if self._page_iterator is None: | ||
self._page_iterator = self.by_page() | ||
self._first_page_iterator_instance = self._page_iterator | ||
return await self.__anext__() | ||
if self._page is None: | ||
# Let it raise StopAsyncIteration | ||
self._page = await self._page_iterator.__anext__() | ||
return await self.__anext__() | ||
try: | ||
return await self._page.__anext__() | ||
except StopAsyncIteration: | ||
self._page = None | ||
return await self.__anext__() | ||
|
||
def _first_iterator_instance(self): | ||
if self._first_page_iterator_instance is None: | ||
self._page_iterator = self.by_page() | ||
self._first_page_iterator_instance = self._page_iterator | ||
return self._first_page_iterator_instance | ||
|
||
async def get_facets(self) -> Union[dict, None]: | ||
"""Return any facet results if faceting was requested. | ||
""" | ||
return await self._first_iterator_instance().get_facets() | ||
|
||
|
||
class AsyncSearchPageIterator(AsyncPageIterator[ReturnType]): | ||
def __init__(self, client, initial_query, kwargs, continuation_token=None): | ||
super(AsyncSearchPageIterator, self).__init__( | ||
get_next=self._get_next_cb, | ||
extract_data=self._extract_data_cb, | ||
continuation_token=continuation_token, | ||
) | ||
self._client = client | ||
self._initial_query = initial_query | ||
self._kwargs = kwargs | ||
self._facets = None | ||
|
||
async def _get_next_cb(self, continuation_token): | ||
if continuation_token is None: | ||
return await self._client.documents.search_post( | ||
search_request=self._initial_query.request, **self._kwargs | ||
) | ||
|
||
_next_link, next_page_request = unpack_continuation_token(continuation_token) | ||
|
||
return await self._client.documents.search_post( | ||
search_request=next_page_request | ||
) | ||
|
||
async def _extract_data_cb(self, response): # pylint:disable=no-self-use | ||
continuation_token = pack_continuation_token(response) | ||
facets = response.facets | ||
if facets is not None: | ||
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()} | ||
|
||
results = [convert_search_result(r) for r in response.results] | ||
|
||
return continuation_token, results | ||
|
||
async def get_facets(self): | ||
if self._current_page is None: | ||
self._response = await self._get_next(self.continuation_token) | ||
self.continuation_token, self._current_page = await self._extract_data( | ||
self._response | ||
) | ||
return self._facets |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.