diff --git a/js/sdk/src/models.tsx b/js/sdk/src/models.tsx index d2b74b61d..1d135b204 100644 --- a/js/sdk/src/models.tsx +++ b/js/sdk/src/models.tsx @@ -60,6 +60,26 @@ export interface KGSearchSettings { max_llm_queries_for_global_search?: number; local_search_limits?: Record; } + +export interface KGLocalSearchResult { + query: string; + entities: Record; + relationships: Record; + communities: Record; +} + +export interface KGGlobalSearchResult { + query: string; + search_result: string[]; +} + +export interface KGSearchResult { + local_result?: KGLocalSearchResult; + global_result?: KGGlobalSearchResult; +} + + + export interface Message { role: string; content: string; diff --git a/py/cli/commands/server.py b/py/cli/commands/server.py index 5355b4e8f..5276270f9 100644 --- a/py/cli/commands/server.py +++ b/py/cli/commands/server.py @@ -218,6 +218,7 @@ def serve( ).replace(":", "") if docker: + run_docker_serve( client, host, @@ -227,6 +228,8 @@ def serve( exclude_postgres, project_name, image, + config_name, + config_path, ) if ( "pytest" in sys.modules diff --git a/py/cli/utils/docker_utils.py b/py/cli/utils/docker_utils.py index 4d6e48824..289bf3a1a 100644 --- a/py/cli/utils/docker_utils.py +++ b/py/cli/utils/docker_utils.py @@ -110,19 +110,20 @@ def run_docker_serve( exclude_postgres: bool, project_name: str, image: str, + config_name: Optional[str] = None, config_path: Optional[str] = None, ): check_set_docker_env_vars(exclude_neo4j, exclude_postgres) set_ollama_api_base(exclude_ollama) + if config_path and config_name: + raise ValueError("Cannot specify both config_path and config_name") + if config_path: config = R2RConfig.from_toml(config_path) else: - if hasattr(client, "config_name"): - config_name = client.config_name - else: + if not config_name: config_name = "default" - config = R2RConfig.from_toml(R2RBuilder.CONFIG_OPTIONS[config_name]) if config.parsing.provider == "unstructured" and not image: diff --git a/py/core/__init__.py b/py/core/__init__.py index 28f50891d..fbd510c8a 100644 --- a/py/core/__init__.py +++ b/py/core/__init__.py @@ -73,6 +73,8 @@ "Prompt", # Search abstractions "AggregateSearchResult", + "KGLocalSearchResult", + "KGGlobalSearchResult", "KGSearchResult", "KGSearchSettings", "VectorSearchResult", diff --git a/py/core/base/__init__.py b/py/core/base/__init__.py index 5fdb09128..48ac9dbdb 100644 --- a/py/core/base/__init__.py +++ b/py/core/base/__init__.py @@ -43,6 +43,8 @@ "Prompt", # Search abstractions "AggregateSearchResult", + "KGLocalSearchResult", + "KGGlobalSearchResult", "KGSearchResult", "KGSearchSettings", "VectorSearchResult", diff --git a/py/core/base/abstractions/__init__.py b/py/core/base/abstractions/__init__.py index 34f64ad71..971ad4c1c 100644 --- a/py/core/base/abstractions/__init__.py +++ b/py/core/base/abstractions/__init__.py @@ -31,6 +31,8 @@ from .restructure import KGEnrichmentSettings from .search import ( AggregateSearchResult, + KGLocalSearchResult, + KGGlobalSearchResult, KGSearchResult, KGSearchSettings, VectorSearchResult, @@ -78,6 +80,8 @@ "Prompt", # Search abstractions "AggregateSearchResult", + "KGLocalSearchResult", + "KGGlobalSearchResult", "KGSearchResult", "KGSearchSettings", "VectorSearchResult", diff --git a/py/core/base/abstractions/search.py b/py/core/base/abstractions/search.py index 565728794..d8e3bca0d 100644 --- a/py/core/base/abstractions/search.py +++ b/py/core/base/abstractions/search.py @@ -54,14 +54,12 @@ class Config: }, } - class KGLocalSearchResult(BaseModel): """Result of a local knowledge graph search operation.""" - query: str - entities: list[dict[str, Any]] - relationships: list[dict[str, Any]] - communities: list[dict[str, Any]] + entities: dict[str, Any] + relationships: dict[str, Any] + communities: dict[str, Any] def __str__(self) -> str: return f"LocalSearchResult(query={self.query}, search_result={self.search_result})" @@ -72,35 +70,44 @@ def __repr__(self) -> str: class KGGlobalSearchResult(BaseModel): """Result of a global knowledge graph search operation.""" - query: str - search_result: list[Dict[str, Any]] + search_result: list[str] def __str__(self) -> str: - return f"GlobalSearchResult(query={self.query}, search_result={self.search_result})" + return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})" def __repr__(self) -> str: return self.__str__() + def dict(self) -> dict: + return { + "query": self.query, + "search_result": self.search_result + } class KGSearchResult(BaseModel): """Result of a knowledge graph search operation.""" - local_result: Optional[KGLocalSearchResult] = None global_result: Optional[KGGlobalSearchResult] = None - + def __str__(self) -> str: return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})" def __repr__(self) -> str: return self.__str__() + + def dict(self) -> dict: + return { + "local_result": self.local_result.dict() if self.local_result else None, + "global_result": self.global_result.dict() if self.global_result else None + } class AggregateSearchResult(BaseModel): """Result of an aggregate search operation.""" vector_search_results: Optional[list[VectorSearchResult]] - kg_search_results: Optional[KGSearchResult] = None + kg_search_results: Optional[list[KGSearchResult]] = None def __str__(self) -> str: return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})" diff --git a/py/core/base/api/models/retrieval/responses.py b/py/core/base/api/models/retrieval/responses.py index 6bb086971..e5c87a091 100644 --- a/py/core/base/api/models/retrieval/responses.py +++ b/py/core/base/api/models/retrieval/responses.py @@ -11,7 +11,7 @@ class SearchResponse(BaseModel): ..., description="List of vector search results", ) - kg_search_results: Optional[KGSearchResult] = Field( + kg_search_results: Optional[list[KGSearchResult]] = Field( None, description="Knowledge graph search results, if applicable", ) diff --git a/py/core/configs/neo4j_kg.toml b/py/core/configs/neo4j_kg.toml index cbaa13a0d..ee177aff6 100644 --- a/py/core/configs/neo4j_kg.toml +++ b/py/core/configs/neo4j_kg.toml @@ -28,7 +28,7 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot" [kg.kg_extraction_config] model = "gpt-4o-mini" - [kg.kg_enrichment_config] + [kg.kg_enrichment_settings] max_knowledge_triples = 100 generation_config = { model = "gpt-4o-mini" } # and other params leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py diff --git a/py/core/pipes/retrieval/kg_search_search_pipe.py b/py/core/pipes/retrieval/kg_search_search_pipe.py index fcd37889a..e81c0eb51 100644 --- a/py/core/pipes/retrieval/kg_search_search_pipe.py +++ b/py/core/pipes/retrieval/kg_search_search_pipe.py @@ -13,6 +13,7 @@ PipeType, PromptProvider, RunLoggingSingleton, + R2RException ) from core.base.abstractions.search import ( KGGlobalSearchResult, @@ -24,7 +25,6 @@ logger = logging.getLogger(__name__) - class KGSearchSearchPipe(GeneratorPipe): """ Embeds and stores documents using a specified embedding model and database. @@ -132,12 +132,11 @@ async def local_search( ) all_search_results.append(search_result) - yield KGLocalSearchResult( - query=message, - entities=all_search_results[0], - relationships=all_search_results[1], - communities=all_search_results[2], - ) + + if len(all_search_results[0])==0: + raise R2RException("No search results found. Please make sure you have run the KG enrichment step before running the search: r2r enrich-graph", 400) + + yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2]) async def global_search( self, @@ -217,11 +216,9 @@ async def process_community(merged_report): generation_config=kg_search_settings.kg_search_generation_config, ) - output = output.choices[0].message.content + output = [output.choices[0].message.content] - yield KGGlobalSearchResult( - query=message, search_result=output, citations=None - ) + yield KGGlobalSearchResult(query=message, search_result=output) async def _run_logic( self, diff --git a/py/core/pipes/retrieval/vector_search_pipe.py b/py/core/pipes/retrieval/vector_search_pipe.py index d01317f5f..a14f1fa46 100644 --- a/py/core/pipes/retrieval/vector_search_pipe.py +++ b/py/core/pipes/retrieval/vector_search_pipe.py @@ -91,6 +91,7 @@ async def search( result.metadata["associatedQuery"] = message results.append(result) yield result + await self.enqueue_log( run_id=run_id, key="search_results", diff --git a/py/sdk/models.py b/py/sdk/models.py index 571659dfb..69d481a89 100644 --- a/py/sdk/models.py +++ b/py/sdk/models.py @@ -187,27 +187,79 @@ class Config: }, } +class KGLocalSearchResult(BaseModel): + query: str + entities: list[dict[str, Any]] + relationships: list[dict[str, Any]] + communities: list[dict[str, Any]] -class KGSearchResult(BaseModel): + def __str__(self) -> str: + return f"KGLocalSearchResult(query={self.query}, entities={self.entities}, relationships={self.relationships}, communities={self.communities})" + + def dict(self) -> dict: + return { + "query": self.query, + "entities": self.entities, + "relationships": self.relationships, + "communities": self.communities + } + +class KGGlobalSearchResult(BaseModel): query: str - results: list[Dict[str, Any]] + search_result: list[str] + + def __str__(self) -> str: + return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})" + + def __repr__(self) -> str: + return self.__str__() + + def dict(self) -> dict: + return { + "query": self.query, + "search_result": self.search_result + } + + +class KGSearchResult(BaseModel): + local_result: Optional[KGLocalSearchResult] = None + global_result: Optional[KGGlobalSearchResult] = None + + def __str__(self) -> str: + return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})" + + def __repr__(self) -> str: + return self.__str__() + + def dict(self) -> dict: + return { + "local_result": self.local_result.dict() if self.local_result else None, + "global_result": self.global_result.dict() if self.global_result else None + } class Config: json_schema_extra = { "example": { - "query": "What is the capital of France?", - "results": [ - { + "local_result": { + "query": "What is the capital of France?", + "entities": { "Paris": { "name": "Paris", - "description": "Paris is the capital of France.", + "description": "Paris is the capital of France." } - } - ], + }, + "relationships": {}, + "communities": {}, + }, + "global_result": { + "query": "What is the capital of France?", + "search_result": [ + "Paris is the capital and most populous city of France." + ] + } } } - class R2RException(Exception): def __init__( self, message: str, status_code: int, detail: Optional[Any] = None @@ -419,7 +471,7 @@ class SearchResponse(BaseModel): ..., description="List of vector search results", ) - kg_search_results: Optional[KGSearchResult] = Field( + kg_search_results: Optional[list[KGSearchResult]] = Field( None, description="Knowledge graph search results, if applicable", )