diff --git a/py/core/base/abstractions/search.py b/py/core/base/abstractions/search.py index a9405b915..affc9cb9f 100644 --- a/py/core/base/abstractions/search.py +++ b/py/core/base/abstractions/search.py @@ -54,23 +54,42 @@ class Config: }, } +class KGLocalSearchResult(BaseModel): + """Result of a local knowledge graph search operation.""" + query: str + entities: list[dict[str, Any]] + relationships: list[dict[str, Any]] + communities: list[dict[str, Any]] -class KGSearchResult(BaseModel): - """Result of a knowledge graph search operation.""" + def __str__(self) -> str: + return f"LocalSearchResult(query={self.query}, search_result={self.search_result})" + + def __repr__(self) -> str: + return self.__str__() + + +class KGGlobalSearchResult(BaseModel): + """Result of a global knowledge graph search operation.""" query: str search_result: list[Dict[str, Any]] def __str__(self) -> str: - return f"KGSearchResult(query={self.query}, search_result={self.search_result})" + return f"GlobalSearchResult(query={self.query}, search_result={self.search_result})" def __repr__(self) -> str: return self.__str__() - def dict(self) -> dict: - return { - "query": self.query, - "search_result": self.search_result, - } + +class KGSearchResult(BaseModel): + """Result of a knowledge graph search operation.""" + local_result: Optional[KGLocalSearchResult] = None + global_result: Optional[KGGlobalSearchResult] = None + + def __str__(self) -> str: + return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})" + + def __repr__(self) -> str: + return self.__str__() class AggregateSearchResult(BaseModel): diff --git a/py/core/pipes/retrieval/kg_search_search_pipe.py b/py/core/pipes/retrieval/kg_search_search_pipe.py index d4d6ad02a..e5664fae5 100644 --- a/py/core/pipes/retrieval/kg_search_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_search_pipe.py @@ -15,7 +15,7 @@ RunLoggingSingleton, ) -from core.base.abstractions.search import KGSearchResult +from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult from ..abstractions.generator_pipe import GeneratorPipe logger = logging.getLogger(__name__) @@ -102,7 +102,7 @@ async def local_search( kg_search_settings: KGSearchSettings, *args: Any, **kwargs: Any, - ) -> KGSearchResult: + ) -> KGLocalSearchResult: # search over communities and # do 3 searches. One over entities, one over relationships, one over communities @@ -127,7 +127,7 @@ async def local_search( ) all_search_results.append(search_result) - yield KGSearchResult(query=message, search_result=all_search_results) + yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2]) async def global_search( self, @@ -137,7 +137,7 @@ async def global_search( kg_search_settings: KGSearchSettings, *args: Any, **kwargs: Any, - ) -> KGSearchResult: + ) -> KGGlobalSearchResult: # map reduce async for message in input.message: map_responses = [] @@ -209,7 +209,7 @@ async def process_community(merged_report): output = output.choices[0].message.content - yield KGSearchResult(query=message, search_result=[{"output": output}]) + yield KGGlobalSearchResult(query=message, search_result=output, citations=None) async def _run_logic( self, @@ -228,10 +228,10 @@ async def _run_logic( async for result in self.local_search( input, state, run_id, kg_search_settings ): - yield result + yield KGSearchResult(local_result=result) else: async for result in self.global_search( input, state, run_id, kg_search_settings ): - yield result + yield KGSearchResult(global_result=result)