diff --git a/src/marqo/index.py b/src/marqo/index.py
index a08aba2a..c2d16d34 100644
--- a/src/marqo/index.py
+++ b/src/marqo/index.py
@@ -209,7 +209,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
                context: Optional[dict] = None, score_modifiers: Optional[dict] = None,
                model_auth: Optional[dict] = None,
                ef_search: Optional[int] = None, approximate: Optional[bool] = None,
-               text_query_prefix: Optional[str] = None,
+               text_query_prefix: Optional[str] = None, hybrid_parameters: Optional[dict] = None,
                ) -> Dict[str, Any]:
         """Search the index.
 
@@ -273,6 +273,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
             "reRanker": reranker,
             "boost": boost,
             "textQueryPrefix": text_query_prefix,
+            "hybridParameters": hybrid_parameters,
         }
 
         body = {k: v for k, v in body.items() if v is not None}
diff --git a/src/marqo/models/search_models.py b/src/marqo/models/search_models.py
index 8c390530..e5e47bcc 100644
--- a/src/marqo/models/search_models.py
+++ b/src/marqo/models/search_models.py
@@ -1,5 +1,9 @@
 from typing import Dict, List, Optional, Union
 from marqo.models.marqo_models import StrictBaseModel
+from abc import ABC
+from enum import Enum
+
+from pydantic import validator, BaseModel, root_validator
 
 
 class SearchBody(StrictBaseModel):
@@ -26,3 +30,18 @@ class BulkSearchBody(SearchBody):
 
 class BulkSearchQuery(StrictBaseModel):
     queries: List[BulkSearchBody]
+
+
+class RetrievalMethod(str, Enum):
+    """Retrieval strategy names; values mirror the strings passed as `retrievalMethod`."""
+    Disjunction = 'disjunction'
+    Tensor = 'tensor'
+    Lexical = 'lexical'
+
+
+class RankingMethod(str, Enum):
+    """Ranking strategy names; values mirror the strings passed as `rankingMethod`."""
+    RRF = 'rrf'
+    NormalizeLinear = 'normalize_linear'
+    Tensor = 'tensor'
+    Lexical = 'lexical'
diff --git a/tests/v2_tests/test_hybrid_search.py b/tests/v2_tests/test_hybrid_search.py
new file mode 100644
index 00000000..6bba7587
--- /dev/null
+++ b/tests/v2_tests/test_hybrid_search.py
@@ -0,0 +1,203 @@
+import copy
+import marqo
+from marqo import enums
+from unittest import mock
+import requests
+import random
+import math
+import time
+from tests.marqo_test import MarqoTestCase, CloudTestIndex
+from marqo.errors import MarqoWebError
+from pytest import mark
+
+
+@mark.fixed
+class TestHybridSearch(MarqoTestCase):
+    """Exercises search(search_method="HYBRID") with different hybrid_parameters."""
+
+    @staticmethod
+    def strip_marqo_fields(doc, strip_id=True):
+        """Strips Marqo fields from a returned doc to get the original doc"""
+        copied = copy.deepcopy(doc)
+
+        strip_fields = ["_highlights", "_score"]
+        if strip_id:
+            strip_fields += ["_id"]
+
+        for to_strip in strip_fields:
+            # pop() instead of del: a hit that lacks one of these keys must not raise KeyError
+            copied.pop(to_strip, None)
+
+        return copied
+
+    def setUp(self):
+        """Builds the documents that every test adds to its index."""
+        super().setUp()
+        self.docs_list = [
+            # TODO: add score modifiers
+            # similar semantics to dogs
+            {"_id": "doc1", "text_field_1": "dogs"},
+            {"_id": "doc2", "text_field_1": "puppies"},
+            {"_id": "doc3", "text_field_1": "canines"},
+            {"_id": "doc4", "text_field_1": "huskies"},
+            {"_id": "doc5", "text_field_1": "four-legged animals"},
+
+            # shares lexical token with dogs
+            {"_id": "doc6", "text_field_1": "hot dogs"},
+            {"_id": "doc7", "text_field_1": "dogs is a word"},
+            {"_id": "doc8", "text_field_1": "something something dogs"},
+            {"_id": "doc9", "text_field_1": "dogs random words"},
+            {"_id": "doc10", "text_field_1": "dogs dogs dogs"},
+
+            {"_id": "doc11", "text_field_2": "dogs but wrong field"},
+            {"_id": "doc12", "text_field_2": "puppies puppies"},
+            {"_id": "doc13", "text_field_2": "canines canines"},
+        ]
+
+    def test_hybrid_search_searchable_attributes(self):
+        """
+        Tests that searchable attributes work as expected for all methods
+        """
+
+        index_test_cases = [
+            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
+        ]
+        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
+            test_index_name = self.get_test_index_name(
+                cloud_test_index_to_use=cloud_test_index_to_use,
+                open_source_test_index_name=open_source_test_index_name
+            )
+            self.client.index(test_index_name).add_documents(self.docs_list)
+
+            with self.subTest("retrieval: disjunction, ranking: rrf"):
+                hybrid_res = self.client.index(test_index_name).search(
+                    "puppies",
+                    search_method="HYBRID",
+                    hybrid_parameters={
+                        "retrievalMethod": "disjunction",
+                        "rankingMethod": "rrf",
+                        "alpha": 0.5,
+                        "searchableAttributesLexical": ["text_field_2"],
+                        "searchableAttributesTensor": ["text_field_2"]
+                    },
+                    limit=10
+                )
+                self.assertEqual(len(hybrid_res["hits"]), 3)  # Only 3 documents have text_field_2 at all
+                self.assertEqual(hybrid_res["hits"][0]["_id"], "doc12")  # puppies puppies in text field 2
+                self.assertEqual(hybrid_res["hits"][1]["_id"], "doc13")
+                self.assertEqual(hybrid_res["hits"][2]["_id"], "doc11")
+
+            with self.subTest("retrieval: lexical, ranking: tensor"):
+                hybrid_res = self.client.index(test_index_name).search(
+                    "puppies",
+                    search_method="HYBRID",
+                    hybrid_parameters={
+                        "retrievalMethod": "lexical",
+                        "rankingMethod": "tensor",
+                        "searchableAttributesLexical": ["text_field_2"]
+                    },
+                    limit=10
+                )
+                # Only 1 document has puppies in text_field_2. Lexical retrieval will only get this one.
+                self.assertEqual(len(hybrid_res["hits"]), 1)
+                self.assertEqual(hybrid_res["hits"][0]["_id"], "doc12")
+
+            with self.subTest("retrieval: tensor, ranking: lexical"):
+                hybrid_res = self.client.index(test_index_name).search(
+                    "puppies",
+                    search_method="HYBRID",
+                    hybrid_parameters={
+                        "retrievalMethod": "tensor",
+                        "rankingMethod": "lexical",
+                        "searchableAttributesTensor": ["text_field_2"]
+                    },
+                    limit=10
+                )
+                # Only 3 documents have text field 2. Tensor retrieval will get them all.
+                self.assertEqual(len(hybrid_res["hits"]), 3)
+                self.assertEqual(hybrid_res["hits"][0]["_id"], "doc12")
+                self.assertEqual(hybrid_res["hits"][1]["_id"], "doc11")
+                self.assertEqual(hybrid_res["hits"][2]["_id"], "doc13")
+
+    def test_hybrid_search_same_retrieval_and_ranking_matches_original_method(self):
+        """
+        Tests that hybrid search with:
+        retrievalMethod = "lexical", rankingMethod = "lexical" and
+        retrievalMethod = "tensor", rankingMethod = "tensor"
+
+        Results must be the same as lexical search and tensor search respectively.
+        """
+
+        index_test_cases = [
+            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
+        ]
+        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
+            test_index_name = self.get_test_index_name(
+                cloud_test_index_to_use=cloud_test_index_to_use,
+                open_source_test_index_name=open_source_test_index_name
+            )
+            self.client.index(test_index_name).add_documents(self.docs_list)
+
+            test_cases = [
+                ("lexical", "lexical"),
+                ("tensor", "tensor")
+            ]
+
+            for retrieval_method, ranking_method in test_cases:
+                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
+                    hybrid_res = self.client.index(test_index_name).search(
+                        "dogs",
+                        search_method="HYBRID",
+                        hybrid_parameters={
+                            "retrievalMethod": retrieval_method,
+                            "rankingMethod": ranking_method
+                        },
+                        limit=10
+                    )
+
+                    base_res = self.client.index(test_index_name).search(
+                        "dogs",
+                        search_method=retrieval_method,  # will be either lexical or tensor
+                        limit=10
+                    )
+
+                    self.assertEqual(len(hybrid_res["hits"]), len(base_res["hits"]))
+                    for hybrid_hit, base_hit in zip(hybrid_res["hits"], base_res["hits"]):
+                        self.assertEqual(hybrid_hit["_id"], base_hit["_id"])
+
+    def test_hybrid_search_with_filter(self):
+        """
+        Tests that filter is applied correctly in hybrid search.
+        """
+
+        index_test_cases = [
+            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
+        ]
+        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
+            test_index_name = self.get_test_index_name(
+                cloud_test_index_to_use=cloud_test_index_to_use,
+                open_source_test_index_name=open_source_test_index_name
+            )
+            self.client.index(test_index_name).add_documents(self.docs_list)
+
+            test_cases = [
+                ("disjunction", "rrf"),
+                ("lexical", "lexical"),
+                ("tensor", "tensor")
+            ]
+
+            for retrieval_method, ranking_method in test_cases:
+                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
+                    hybrid_res = self.client.index(test_index_name).search(
+                        "dogs",
+                        search_method="HYBRID",
+                        filter_string="text_field_1:(something something dogs)",
+                        hybrid_parameters={
+                            "retrievalMethod": retrieval_method,
+                            "rankingMethod": ranking_method
+                        },
+                        limit=10
+                    )
+
+                    self.assertEqual(len(hybrid_res["hits"]), 1)
+                    self.assertEqual(hybrid_res["hits"][0]["_id"], "doc8")