Skip to content

Commit

Permalink
local and global results
Browse files Browse the repository at this point in the history
  • Loading branch information
shreyaspimpalgaonkar committed Aug 22, 2024
1 parent af481ae commit 75eba79
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 15 deletions.
35 changes: 27 additions & 8 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,23 +54,42 @@ class Config:
},
}

class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""
query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
communities: list[dict[str, Any]]

class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""
def __str__(self) -> str:
return f"LocalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()


class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""
query: str
search_result: list[Dict[str, Any]]

def __str__(self) -> str:
return f"KGSearchResult(query={self.query}, search_result={self.search_result})"
return f"GlobalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"query": self.query,
"search_result": self.search_result,
}

class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""
local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

def __str__(self) -> str:
return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

def __repr__(self) -> str:
return self.__str__()


class AggregateSearchResult(BaseModel):
Expand Down
14 changes: 7 additions & 7 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
RunLoggingSingleton,
)

from core.base.abstractions.search import KGSearchResult
from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult
from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -102,7 +102,7 @@ async def local_search(
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
) -> KGSearchResult:
) -> KGLocalSearchResult:
# search over communities and
# do 3 searches. One over entities, one over relationships, one over communities

Expand All @@ -127,7 +127,7 @@ async def local_search(
)
all_search_results.append(search_result)

yield KGSearchResult(query=message, search_result=all_search_results)
yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])

async def global_search(
self,
Expand All @@ -137,7 +137,7 @@ async def global_search(
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
) -> KGSearchResult:
) -> KGGlobalSearchResult:
# map reduce
async for message in input.message:
map_responses = []
Expand Down Expand Up @@ -209,7 +209,7 @@ async def process_community(merged_report):

output = output.choices[0].message.content

yield KGSearchResult(query=message, search_result=[{"output": output}])
yield KGGlobalSearchResult(query=message, search_result=output, citations=None)

async def _run_logic(
self,
Expand All @@ -228,10 +228,10 @@ async def _run_logic(
async for result in self.local_search(
input, state, run_id, kg_search_settings
):
yield result
yield KGSearchResult(local_result=result)

else:
async for result in self.global_search(
input, state, run_id, kg_search_settings
):
yield result
yield KGSearchResult(global_result=result)

0 comments on commit 75eba79

Please sign in to comment.