From 12675e8a504710e88480f7931adbc2c60de1a70e Mon Sep 17 00:00:00 2001
From: "Nikolay S."
Date: Fri, 9 Aug 2024 14:03:38 +0200
Subject: [PATCH] Update elasticsearch_backend.py, added support 'track_total_hits'

Support track_total_hits Handling in ElasticsearchSearchBackend

- Implemented handling for `track_total_hits` in the Elasticsearch search backend.
- Supports `bool` values (True/False) and `int` values.
- Added a warning for invalid `track_total_hits` values.
- Ensured the parameter is only added to the kwargs returned by
  `build_search_kwargs` when valid.
---
 haystack/backends/elasticsearch_backend.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+)

diff --git a/haystack/backends/elasticsearch_backend.py b/haystack/backends/elasticsearch_backend.py
index e8febf9d3..2dfec9de2 100644
--- a/haystack/backends/elasticsearch_backend.py
+++ b/haystack/backends/elasticsearch_backend.py
@@ -15,6 +15,7 @@
     FUZZY_MAX_EXPANSIONS,
     FUZZY_MIN_SIM,
     ID,
+    TRACK_TOTAL_HITS,
 )
 from haystack.exceptions import MissingDependency, MoreLikeThisError, SkipDocument
 from haystack.inputs import Clean, Exact, PythonData, Raw
@@ -545,6 +546,19 @@ def build_search_kwargs(
         if extra_kwargs:
             kwargs.update(extra_kwargs)

+        # A falsy TRACK_TOTAL_HITS (False, 0, None) leaves the parameter out
+        # of the returned kwargs entirely.
+        if TRACK_TOTAL_HITS:
+            # bool must be tested before int: bool is a subclass of int.
+            if isinstance(TRACK_TOTAL_HITS, bool):
+                kwargs["track_total_hits"] = "true"
+            elif isinstance(TRACK_TOTAL_HITS, int):
+                kwargs["track_total_hits"] = TRACK_TOTAL_HITS
+            else:
+                warnings.warn(
+                    "Wrong value of HAYSTACK_TRACK_TOTAL_HITS is provided. Valid options are `bool` or `int`."
+                )
+
         return kwargs

     @log_query