Skip to content

Commit

Permalink
Add facet results to search page iterator (#10182)
Browse files Browse the repository at this point in the history
* refactor search paging

* add facet results to search page iterator

* py2 syntax

* use method for facet results

* pylint

* use facet method on item paged as well
  • Loading branch information
bryevdv authored Mar 9, 2020
1 parent 15fdf4c commit 6b6ada1
Show file tree
Hide file tree
Showing 11 changed files with 483 additions and 101 deletions.
106 changes: 106 additions & 0 deletions sdk/search/azure-search/azure/search/_index/_paging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import TYPE_CHECKING

import base64
import itertools
import json

from azure.core.paging import ItemPaged, PageIterator, ReturnType
from ._generated.models import SearchRequest

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Union


def convert_search_result(result):
ret = result.additional_properties
ret["@search.score"] = result.score
ret["@search.highlights"] = result.highlights
return ret


def pack_continuation_token(response):
if response.next_page_parameters is not None:
return base64.b64encode(
json.dumps(
[response.next_link, response.next_page_parameters.serialize()]
).encode("utf-8")
)
return None


def unpack_continuation_token(token):
next_link, next_page_parameters = json.loads(base64.b64decode(token))
next_page_request = SearchRequest.deserialize(next_page_parameters)
return next_link, next_page_request


class SearchItemPaged(ItemPaged[ReturnType]):
def __init__(self, *args, **kwargs):
super(SearchItemPaged, self).__init__(*args, **kwargs)
self._first_page_iterator_instance = None

def __next__(self):
# type: () -> ReturnType
if self._page_iterator is None:
first_iterator = self._first_iterator_instance()
self._page_iterator = itertools.chain.from_iterable(first_iterator)
return next(self._page_iterator)

def _first_iterator_instance(self):
if self._first_page_iterator_instance is None:
self._first_page_iterator_instance = self.by_page()
return self._first_page_iterator_instance

def get_facets(self):
# type: () -> Union[dict, None]
"""Return any facet results if faceting was requested.
"""
return self._first_iterator_instance().get_facets()


class SearchPageIterator(PageIterator):
def __init__(self, client, initial_query, kwargs, continuation_token=None):
super(SearchPageIterator, self).__init__(
get_next=self._get_next_cb,
extract_data=self._extract_data_cb,
continuation_token=continuation_token,
)
self._client = client
self._initial_query = initial_query
self._kwargs = kwargs
self._facets = None

def _get_next_cb(self, continuation_token):
if continuation_token is None:
return self._client.documents.search_post(
search_request=self._initial_query.request, **self._kwargs
)

_next_link, next_page_request = unpack_continuation_token(continuation_token)

return self._client.documents.search_post(search_request=next_page_request)

def _extract_data_cb(self, response): # pylint:disable=no-self-use
continuation_token = pack_continuation_token(response)
facets = response.facets
if facets is not None:
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()}

results = [convert_search_result(r) for r in response.results]

return continuation_token, results

def get_facets(self):
if self._current_page is None:
self._response = self._get_next(self.continuation_token)
self.continuation_token, self._current_page = self._extract_data(
self._response
)
return self._facets
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,14 @@
# --------------------------------------------------------------------------
from typing import cast, List, TYPE_CHECKING

import base64
import json
import six

from azure.core.paging import ItemPaged, PageIterator
from azure.core.pipeline.policies import HeadersPolicy
from azure.core.tracing.decorator import distributed_trace
from ._generated import SearchIndexClient as _SearchIndexClient
from ._generated.models import IndexBatch, IndexingResult, SearchRequest
from ._generated.models import IndexBatch, IndexingResult
from ._index_documents_batch import IndexDocumentsBatch
from ._paging import SearchItemPaged, SearchPageIterator
from ._queries import AutocompleteQuery, SearchQuery, SuggestQuery

if TYPE_CHECKING:
Expand Down Expand Up @@ -51,58 +49,6 @@ def odata(statement, **kwargs):
return statement.format(**kw)


def convert_search_result(result):
ret = result.additional_properties
ret["@search.score"] = result.score
ret["@search.highlights"] = result.highlights
return ret


def pack_continuation_token(response):
if response.next_page_parameters is not None:
return base64.b64encode(
json.dumps(
[response.next_link, response.next_page_parameters.serialize()]
).encode("utf-8")
)
return None


def unpack_continuation_token(token):
next_link, next_page_parameters = json.loads(base64.b64decode(token))
next_page_request = SearchRequest.deserialize(next_page_parameters)
return next_link, next_page_request


class _SearchDocumentsPaged(PageIterator):
def __init__(self, client, initial_query, kwargs, continuation_token=None):
super(_SearchDocumentsPaged, self).__init__(
get_next=self._get_next_cb,
extract_data=self._extract_data_cb,
continuation_token=continuation_token,
)
self._client = client
self._initial_query = initial_query
self._kwargs = kwargs

def _get_next_cb(self, continuation_token):
if continuation_token is None:
return self._client.documents.search_post(
search_request=self._initial_query.request, **self._kwargs
)

_next_link, next_page_request = unpack_continuation_token(continuation_token)

return self._client.documents.search_post(search_request=next_page_request)

def _extract_data_cb(self, response): # pylint:disable=no-self-use
continuation_token = pack_continuation_token(response)

results = [convert_search_result(r) for r in response.results]

return continuation_token, results


class SearchIndexClient(object):
"""A client to interact with an existing Azure search index.
Expand Down Expand Up @@ -184,7 +130,7 @@ def get_document(self, key, selected_fields=None, **kwargs):

@distributed_trace
def search(self, query, **kwargs):
# type: (Union[str, SearchQuery], **Any) -> ItemPaged[dict]
# type: (Union[str, SearchQuery], **Any) -> SearchItemPaged[dict]
"""Search the Azure search index for documents.
:param query: An query for searching the index
Expand Down Expand Up @@ -218,8 +164,8 @@ def search(self, query, **kwargs):
)
)

return ItemPaged(
self._client, query, kwargs, page_iterator_class=_SearchDocumentsPaged
return SearchItemPaged(
self._client, query, kwargs, page_iterator_class=SearchPageIterator
)

@distributed_trace
Expand Down
90 changes: 90 additions & 0 deletions sdk/search/azure-search/azure/search/_index/aio/_paging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------
from typing import Union

from azure.core.async_paging import AsyncItemPaged, AsyncPageIterator, ReturnType
from .._generated.models import SearchRequest
from .._paging import (
convert_search_result,
pack_continuation_token,
unpack_continuation_token,
)


class AsyncSearchItemPaged(AsyncItemPaged[ReturnType]):
def __init__(self, *args, **kwargs):
super(AsyncSearchItemPaged, self).__init__(*args, **kwargs)
self._first_page_iterator_instance = None

async def __anext__(self) -> ReturnType:
if self._page_iterator is None:
self._page_iterator = self.by_page()
self._first_page_iterator_instance = self._page_iterator
return await self.__anext__()
if self._page is None:
# Let it raise StopAsyncIteration
self._page = await self._page_iterator.__anext__()
return await self.__anext__()
try:
return await self._page.__anext__()
except StopAsyncIteration:
self._page = None
return await self.__anext__()

def _first_iterator_instance(self):
if self._first_page_iterator_instance is None:
self._page_iterator = self.by_page()
self._first_page_iterator_instance = self._page_iterator
return self._first_page_iterator_instance

async def get_facets(self) -> Union[dict, None]:
"""Return any facet results if faceting was requested.
"""
return await self._first_iterator_instance().get_facets()


class AsyncSearchPageIterator(AsyncPageIterator[ReturnType]):
def __init__(self, client, initial_query, kwargs, continuation_token=None):
super(AsyncSearchPageIterator, self).__init__(
get_next=self._get_next_cb,
extract_data=self._extract_data_cb,
continuation_token=continuation_token,
)
self._client = client
self._initial_query = initial_query
self._kwargs = kwargs
self._facets = None

async def _get_next_cb(self, continuation_token):
if continuation_token is None:
return await self._client.documents.search_post(
search_request=self._initial_query.request, **self._kwargs
)

_next_link, next_page_request = unpack_continuation_token(continuation_token)

return await self._client.documents.search_post(
search_request=next_page_request
)

async def _extract_data_cb(self, response): # pylint:disable=no-self-use
continuation_token = pack_continuation_token(response)
facets = response.facets
if facets is not None:
self._facets = {k: [x.as_dict() for x in v] for k, v in facets.items()}

results = [convert_search_result(r) for r in response.results]

return continuation_token, results

async def get_facets(self):
if self._current_page is None:
self._response = await self._get_next(self.continuation_token)
self.continuation_token, self._current_page = await self._extract_data(
self._response
)
return self._facets
Original file line number Diff line number Diff line change
Expand Up @@ -7,56 +7,20 @@

import six

from azure.core.async_paging import AsyncItemPaged, AsyncPageIterator
from azure.core.pipeline.policies import HeadersPolicy
from azure.core.tracing.decorator_async import distributed_trace_async
from ._paging import AsyncSearchItemPaged, AsyncSearchPageIterator
from .._generated.aio import SearchIndexClient as _SearchIndexClient
from .._generated.models import IndexBatch, IndexingResult, SearchRequest
from .._index_documents_batch import IndexDocumentsBatch
from .._queries import AutocompleteQuery, SearchQuery, SuggestQuery
from .._search_index_client import (
convert_search_result,
pack_continuation_token,
unpack_continuation_token,
)

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any, Union
from .._credential import SearchApiKeyCredential


class _SearchDocumentsPagedAsync(AsyncPageIterator):
def __init__(self, client, initial_query, kwargs, continuation_token=None):
super(_SearchDocumentsPagedAsync, self).__init__(
get_next=self._get_next_cb,
extract_data=self._extract_data_cb,
continuation_token=continuation_token,
)
self._client = client
self._initial_query = initial_query
self._kwargs = kwargs

async def _get_next_cb(self, continuation_token):
if continuation_token is None:
return await self._client.documents.search_post(
search_request=self._initial_query.request, **self._kwargs
)

_next_link, next_page_request = unpack_continuation_token(continuation_token)

return await self._client.documents.search_post(
search_request=next_page_request
)

async def _extract_data_cb(self, response): # pylint:disable=no-self-use
continuation_token = pack_continuation_token(response)

results = [convert_search_result(r) for r in response.results]

return continuation_token, results


class SearchIndexClient(object):
"""A client to interact with an existing Azure search index.
Expand Down Expand Up @@ -139,7 +103,7 @@ async def get_document(self, key, selected_fields=None, **kwargs):

@distributed_trace_async
async def search(self, query, **kwargs):
# type: (Union[str, SearchQuery], **Any) -> AsyncItemPaged[dict]
# type: (Union[str, SearchQuery], **Any) -> AsyncSearchItemPaged[dict]
"""Search the Azure search index for documents.
:param query: An query for searching the index
Expand Down Expand Up @@ -173,8 +137,8 @@ async def search(self, query, **kwargs):
)
)

return AsyncItemPaged(
self._client, query, kwargs, page_iterator_class=_SearchDocumentsPagedAsync
return AsyncSearchItemPaged(
self._client, query, kwargs, page_iterator_class=AsyncSearchPageIterator
)

@distributed_trace_async
Expand Down
Loading

0 comments on commit 6b6ada1

Please sign in to comment.