Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add hybrid_parameters to search #235

Merged
merged 10 commits into from
Jul 10, 2024
3 changes: 2 additions & 1 deletion src/marqo/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
context: Optional[dict] = None, score_modifiers: Optional[dict] = None,
model_auth: Optional[dict] = None,
ef_search: Optional[int] = None, approximate: Optional[bool] = None,
text_query_prefix: Optional[str] = None,
text_query_prefix: Optional[str] = None, hybrid_parameters: Optional[dict] = None
) -> Dict[str, Any]:
"""Search the index.

Expand Down Expand Up @@ -273,6 +273,7 @@ def search(self, q: Optional[Union[str, dict]] = None, searchable_attributes: Op
"reRanker": reranker,
"boost": boost,
"textQueryPrefix": text_query_prefix,
"hybridParameters": hybrid_parameters
}

body = {k: v for k, v in body.items() if v is not None}
Expand Down
16 changes: 16 additions & 0 deletions src/marqo/models/search_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
from typing import Dict, List, Optional, Union
from marqo.models.marqo_models import StrictBaseModel
from abc import ABC
from enum import Enum

from pydantic import validator, BaseModel, root_validator
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: Let's clean up the imports here and remove the ones that are unused.



class SearchBody(StrictBaseModel):
Expand All @@ -26,3 +30,15 @@ class BulkSearchBody(SearchBody):
class BulkSearchQuery(StrictBaseModel):
    """Model with a single field, ``queries``: a list of ``BulkSearchBody``."""

    queries: List[BulkSearchBody]


class RetrievalMethod(str, Enum):
    """String-valued names of the retrieval methods for hybrid search.

    Members also subclass ``str``, so each one compares equal to its
    plain string value (e.g. ``RetrievalMethod.Tensor == "tensor"``).
    """

    Disjunction = "disjunction"
    Tensor = "tensor"
    Lexical = "lexical"

vicilliar marked this conversation as resolved.
Show resolved Hide resolved

class RankingMethod(str, Enum):
    """String-valued names of the ranking methods for hybrid search.

    Members also subclass ``str``, so each one compares equal to its
    plain string value (e.g. ``RankingMethod.RRF == "rrf"``).
    """

    RRF = "rrf"
    NormalizeLinear = "normalize_linear"
    Tensor = "tensor"
    Lexical = "lexical"
199 changes: 199 additions & 0 deletions tests/v2_tests/test_hybrid_search.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
import copy
import marqo
from marqo import enums
from unittest import mock
import requests
import random
import math
import time
from tests.marqo_test import MarqoTestCase, CloudTestIndex
from marqo.errors import MarqoWebError
from pytest import mark


@mark.fixed
class TestHybridSearch(MarqoTestCase):
    """Tests for ``search(..., search_method="HYBRID", hybrid_parameters=...)``.

    Each test adds ``self.docs_list`` to the structured text test index and
    then asserts on the ``_id`` values (and their order) of the returned hits.
    """

    @staticmethod
    def strip_marqo_fields(doc, strip_id=True):
        """Strips Marqo fields from a returned doc to get the original doc.

        Args:
            doc: A hit dict as returned in a search result.
            strip_id: When true, ``_id`` is removed in addition to
                ``_highlights`` and ``_score``.

        Returns:
            A deep copy of ``doc`` without the stripped fields; ``doc`` itself
            is left untouched.
        """
        copied = copy.deepcopy(doc)

        strip_fields = ["_highlights", "_score"]
        if strip_id:
            strip_fields.append("_id")

        for to_strip in strip_fields:
            # Tolerate hits that lack one of these fields instead of raising KeyError.
            copied.pop(to_strip, None)

        return copied

    @staticmethod
    def _hit_ids(search_result):
        """Return the ``_id`` of every hit in ``search_result``, in the order returned."""
        return [hit["_id"] for hit in search_result["hits"]]

    def _populated_test_index_names(self):
        """Yield each index name under test after adding ``self.docs_list`` to it."""
        index_test_cases = [
            (CloudTestIndex.structured_text, self.structured_index_name)  # TODO: add unstructured when supported
        ]
        for cloud_test_index_to_use, open_source_test_index_name in index_test_cases:
            test_index_name = self.get_test_index_name(
                cloud_test_index_to_use=cloud_test_index_to_use,
                open_source_test_index_name=open_source_test_index_name
            )
            self.client.index(test_index_name).add_documents(self.docs_list)
            yield test_index_name

    def setUp(self):
        """Build the shared document fixture used by every test."""
        super().setUp()
        self.docs_list = [
            # TODO: add score modifiers
            # similar semantics to dogs
            {"_id": "doc1", "text_field_1": "dogs"},
            {"_id": "doc2", "text_field_1": "puppies"},
            {"_id": "doc3", "text_field_1": "canines"},
            {"_id": "doc4", "text_field_1": "huskies"},
            {"_id": "doc5", "text_field_1": "four-legged animals"},

            # shares lexical token with dogs
            {"_id": "doc6", "text_field_1": "hot dogs"},
            {"_id": "doc7", "text_field_1": "dogs is a word"},
            {"_id": "doc8", "text_field_1": "something something dogs"},
            {"_id": "doc9", "text_field_1": "dogs random words"},
            {"_id": "doc10", "text_field_1": "dogs dogs dogs"},

            {"_id": "doc11", "text_field_2": "dogs but wrong field"},
            {"_id": "doc12", "text_field_2": "puppies puppies"},
            {"_id": "doc13", "text_field_2": "canines canines"},
        ]

    def test_hybrid_search_searchable_attributes(self):
        """
        Tests that searchable attributes work as expected for all methods
        """

        # (retrievalMethod, rankingMethod, extra hybrid parameters, expected hit ids in order)
        test_cases = [
            (
                "disjunction", "rrf",
                {
                    "alpha": 0.5,
                    "searchableAttributesLexical": ["text_field_2"],
                    "searchableAttributesTensor": ["text_field_2"],
                },
                # Only 3 documents have text_field_2 at all; doc12 is "puppies puppies".
                ["doc12", "doc13", "doc11"],
            ),
            (
                "lexical", "tensor",
                {"searchableAttributesLexical": ["text_field_2"]},
                # Only 1 document has puppies in text_field_2. Lexical retrieval will only get this one.
                ["doc12"],
            ),
            (
                "tensor", "lexical",
                {"searchableAttributesTensor": ["text_field_2"]},
                # Only 3 documents have text field 2. Tensor retrieval will get them all.
                ["doc12", "doc11", "doc13"],
            ),
        ]

        for test_index_name in self._populated_test_index_names():
            for retrieval_method, ranking_method, extra_parameters, expected_ids in test_cases:
                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
                    hybrid_res = self.client.index(test_index_name).search(
                        "puppies",
                        search_method="HYBRID",
                        hybrid_parameters={
                            "retrievalMethod": retrieval_method,
                            "rankingMethod": ranking_method,
                            **extra_parameters,
                        },
                        limit=10
                    )
                    # One ordered comparison checks both the hit count and the ranking.
                    self.assertEqual(self._hit_ids(hybrid_res), expected_ids)

    def test_hybrid_search_same_retrieval_and_ranking_matches_original_method(self):
        """
        Tests that hybrid search with:
        retrievalMethod = "lexical", rankingMethod = "lexical" and
        retrievalMethod = "tensor", rankingMethod = "tensor"

        Results must be the same as lexical search and tensor search respectively.
        """

        test_cases = [
            ("lexical", "lexical"),
            ("tensor", "tensor")
        ]

        for test_index_name in self._populated_test_index_names():
            for retrieval_method, ranking_method in test_cases:
                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
                    hybrid_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method="HYBRID",
                        hybrid_parameters={
                            "retrievalMethod": retrieval_method,
                            "rankingMethod": ranking_method
                        },
                        limit=10
                    )

                    base_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method=retrieval_method,  # will be either lexical or tensor
                        limit=10
                    )

                    # Same ids in the same order implies the same number of hits too.
                    self.assertEqual(self._hit_ids(hybrid_res), self._hit_ids(base_res))

    def test_hybrid_search_with_filter(self):
        """
        Tests that filter is applied correctly in hybrid search.
        """

        test_cases = [
            ("disjunction", "rrf"),
            ("lexical", "lexical"),
            ("tensor", "tensor")
        ]

        for test_index_name in self._populated_test_index_names():
            for retrieval_method, ranking_method in test_cases:
                with self.subTest(retrieval=retrieval_method, ranking=ranking_method):
                    hybrid_res = self.client.index(test_index_name).search(
                        "dogs",
                        search_method="HYBRID",
                        filter_string="text_field_1:(something something dogs)",
                        hybrid_parameters={
                            "retrievalMethod": retrieval_method,
                            "rankingMethod": ranking_method
                        },
                        limit=10
                    )

                    # The filter matches exactly one document's text_field_1 (doc8).
                    self.assertEqual(self._hit_ids(hybrid_res), ["doc8"])
Loading