Skip to content

Commit

Permalink
Shreyas/KG Search Result model (#937)
Browse files Browse the repository at this point in the history
* return type to kg_search_result

* add model

* local and global results

* modify config
  • Loading branch information
shreyaspimpalgaonkar authored Aug 23, 2024
1 parent 0debda8 commit 0af19cb
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 23 deletions.
36 changes: 35 additions & 1 deletion py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,42 @@ class Config:
},
}

class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""
query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
communities: list[dict[str, Any]]

KGSearchResult = list[Tuple[str, list[Dict[str, Any]]]]
def __str__(self) -> str:
return f"LocalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()


class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""
query: str
search_result: list[Dict[str, Any]]

def __str__(self) -> str:
return f"GlobalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()


class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""
local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

def __str__(self) -> str:
return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

def __repr__(self) -> str:
return self.__str__()


class AggregateSearchResult(BaseModel):
Expand Down
16 changes: 5 additions & 11 deletions py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,22 +25,16 @@ provider = "neo4j"
batch_size = 256
kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"


[kg.kg_extraction_config]
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[kg.kg_enrichment_config]
max_knowledge_triples = 100
generation_config = { model = "gpt-4o-mini" } # and other params
leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py

[kg.kg_search_config]
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[database]
provider = "postgres"
Expand Down
20 changes: 10 additions & 10 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
RunLoggingSingleton,
)

from core.base.abstractions.search import KGLocalSearchResult, KGGlobalSearchResult, KGSearchResult
from ..abstractions.generator_pipe import GeneratorPipe

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -102,7 +102,7 @@ async def local_search(
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
):
) -> KGLocalSearchResult:
# search over communities and
# do 3 searches. One over entities, one over relationships, one over communities

Expand All @@ -127,7 +127,7 @@ async def local_search(
)
all_search_results.append(search_result)

yield message, all_search_results
yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])

async def global_search(
self,
Expand All @@ -137,7 +137,7 @@ async def global_search(
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
):
) -> KGGlobalSearchResult:
# map reduce
async for message in input.message:
map_responses = []
Expand Down Expand Up @@ -209,7 +209,7 @@ async def process_community(merged_report):

output = output.choices[0].message.content

yield message, [{"output": output}]
yield KGGlobalSearchResult(query=message, search_result=output, citations=None)

async def _run_logic(
self,
Expand All @@ -219,19 +219,19 @@ async def _run_logic(
kg_search_settings: KGSearchSettings,
*args: Any,
**kwargs: Any,
):
) -> KGSearchResult:

logger.info("Performing global search")
kg_search_type = kg_search_settings.kg_search_type

if kg_search_type == "local":
async for query, result in self.local_search(
async for result in self.local_search(
input, state, run_id, kg_search_settings
):
yield (query, result)
yield KGSearchResult(local_result=result)

else:
async for query, result in self.global_search(
async for result in self.global_search(
input, state, run_id, kg_search_settings
):
yield (query, result)
yield KGSearchResult(global_result=result)
18 changes: 17 additions & 1 deletion py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,24 @@ class Config:
}


KGSearchResult = list[Tuple[str, list[Dict[str, Any]]]]
class KGSearchResult(BaseModel):
query: str
results: list[Dict[str, Any]]

class Config:
json_schema_extra = {
"example": {
"query": "What is the capital of France?",
"results": [
{
"Paris": {
"name": "Paris",
"description": "Paris is the capital of France."
}
}
]
}
}

class R2RException(Exception):
def __init__(
Expand Down

0 comments on commit 0af19cb

Please sign in to comment.