Skip to content

Commit

Permalink
Shreyas/kgsearchresult model (#957)
Browse files Browse the repository at this point in the history
* return type to kg_search_result

* add model

* local and global results

* modify config

* add models

* up

* fix config path

* fix models
  • Loading branch information
shreyaspimpalgaonkar authored Aug 23, 2024
1 parent 5654f7d commit 513c648
Show file tree
Hide file tree
Showing 12 changed files with 127 additions and 38 deletions.
20 changes: 20 additions & 0 deletions js/sdk/src/models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ export interface KGSearchSettings {
max_llm_queries_for_global_search?: number;
local_search_limits?: Record<string, number>;
}

/**
 * Shape of a local knowledge graph search result.
 *
 * Carries the query together with three string-keyed maps: entities,
 * relationships and communities. The map values are untyped (`any`).
 */
export interface KGLocalSearchResult {
// Query text associated with this result.
query: string;
// Entity data keyed by string; value structure is not constrained here.
entities: Record<string, any>;
// Relationship data keyed by string; value structure is not constrained here.
relationships: Record<string, any>;
// Community data keyed by string; value structure is not constrained here.
communities: Record<string, any>;
}

/**
 * Shape of a global knowledge graph search result: the query plus a list of
 * result strings.
 */
export interface KGGlobalSearchResult {
// Query text associated with this result.
query: string;
// Result texts for the query, one string per entry.
search_result: string[];
}

/**
 * Combined knowledge graph search result.
 *
 * Both members are optional, so a value may carry a local result, a global
 * result, both, or neither.
 */
export interface KGSearchResult {
// Local search result, when present.
local_result?: KGLocalSearchResult;
// Global search result, when present.
global_result?: KGGlobalSearchResult;
}



export interface Message {
role: string;
content: string;
Expand Down
3 changes: 3 additions & 0 deletions py/cli/commands/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,7 @@ def serve(
).replace(":", "")

if docker:

run_docker_serve(
client,
host,
Expand All @@ -226,6 +227,8 @@ def serve(
exclude_postgres,
project_name,
image,
config_name,
config_path,
)
if (
"pytest" in sys.modules
Expand Down
9 changes: 5 additions & 4 deletions py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ def run_docker_serve(
exclude_postgres: bool,
project_name: str,
image: str,
config_name: Optional[str] = None,
config_path: Optional[str] = None,
):
check_set_docker_env_vars(exclude_neo4j, exclude_postgres)
set_ollama_api_base(exclude_ollama)

if config_path and config_name:
raise ValueError("Cannot specify both config_path and config_name")

if config_path:
config = R2RConfig.from_toml(config_path)
else:
if hasattr(client, "config_name"):
config_name = client.config_name
else:
if not config_name:
config_name = "default"

config = R2RConfig.from_toml(R2RBuilder.CONFIG_OPTIONS[config_name])

if config.parsing.provider == "unstructured" and not image:
Expand Down
2 changes: 2 additions & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
4 changes: 4 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from .restructure import KGEnrichmentSettings
from .search import (
AggregateSearchResult,
KGLocalSearchResult,
KGGlobalSearchResult,
KGSearchResult,
KGSearchSettings,
VectorSearchResult,
Expand Down Expand Up @@ -78,6 +80,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
29 changes: 18 additions & 11 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""

query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
communities: list[dict[str, Any]]
entities: dict[str, Any]
relationships: dict[str, Any]
communities: dict[str, Any]

def __str__(self) -> str:
    """Return a readable one-line summary of this local search result.

    The summary is built from the fields this model actually declares
    (``query``, ``entities``, ``relationships``, ``communities``). The model
    has no ``search_result`` attribute, so referencing it here would raise
    ``AttributeError`` whenever the object is printed or logged.
    """
    return (
        f"KGLocalSearchResult(query={self.query}, "
        f"entities={self.entities}, "
        f"relationships={self.relationships}, "
        f"communities={self.communities})"
    )
Expand All @@ -72,35 +70,44 @@ def __repr__(self) -> str:

class KGGlobalSearchResult(BaseModel):
    """Result of a global knowledge graph search operation."""

    # Query text this result belongs to.
    query: str
    # Result texts for the query, one string per entry. Declared exactly once:
    # a second, conflicting annotation would silently override the first.
    search_result: list[str]

    def __str__(self) -> str:
        """Return a readable one-line summary using the class's own name."""
        return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__`` so logs and debuggers agree."""
        return self.__str__()

    def dict(self) -> dict:
        """Return the result as a plain dictionary.

        Returns:
            A dict with exactly the keys ``"query"`` and ``"search_result"``.
        """
        return {
            "query": self.query,
            "search_result": self.search_result,
        }

class KGSearchResult(BaseModel):
    """Pairs the optional local and global knowledge graph search results."""

    # Either part may be missing; both default to None.
    local_result: Optional[KGLocalSearchResult] = None
    global_result: Optional[KGGlobalSearchResult] = None

    def __str__(self) -> str:
        """Return a one-line summary showing both sub-results."""
        return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return str(self)

    def dict(self) -> dict:
        """Return a plain dictionary of both sub-results.

        Each sub-result is converted with its own ``dict()``; a missing
        sub-result is reported as ``None``.
        """
        local_part = self.local_result
        global_part = self.global_result
        return {
            "local_result": None if not local_part else local_part.dict(),
            "global_result": None if not global_part else global_part.dict(),
        }


class AggregateSearchResult(BaseModel):
"""Result of an aggregate search operation."""

vector_search_results: Optional[list[VectorSearchResult]]
kg_search_results: Optional[KGSearchResult] = None
kg_search_results: Optional[list[KGSearchResult]] = None

def __str__(self) -> str:
return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
Expand Down
2 changes: 1 addition & 1 deletion py/core/base/api/models/retrieval/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SearchResponse(BaseModel):
...,
description="List of vector search results",
)
kg_search_results: Optional[KGSearchResult] = Field(
kg_search_results: Optional[list[KGSearchResult]] = Field(
None,
description="Knowledge graph search results, if applicable",
)
Expand Down
2 changes: 1 addition & 1 deletion py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"
[kg.kg_extraction_config]
model = "gpt-4o-mini"

[kg.kg_enrichment_config]
[kg.kg_enrichment_settings]
max_knowledge_triples = 100
generation_config = { model = "gpt-4o-mini" } # and other params
leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py
Expand Down
19 changes: 8 additions & 11 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PipeType,
PromptProvider,
RunLoggingSingleton,
R2RException
)
from core.base.abstractions.search import (
KGGlobalSearchResult,
Expand All @@ -24,7 +25,6 @@

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -132,12 +132,11 @@ async def local_search(
)
all_search_results.append(search_result)

yield KGLocalSearchResult(
query=message,
entities=all_search_results[0],
relationships=all_search_results[1],
communities=all_search_results[2],
)

if len(all_search_results[0])==0:
raise R2RException("No search results found. Please make sure you have run the KG enrichment step before running the search: r2r enrich-graph", 400)

yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])

async def global_search(
self,
Expand Down Expand Up @@ -217,11 +216,9 @@ async def process_community(merged_report):
generation_config=kg_search_settings.kg_search_generation_config,
)

output = output.choices[0].message.content
output = [output.choices[0].message.content]

yield KGGlobalSearchResult(
query=message, search_result=output, citations=None
)
yield KGGlobalSearchResult(query=message, search_result=output)

async def _run_logic(
self,
Expand Down
1 change: 1 addition & 0 deletions py/core/pipes/retrieval/vector_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ async def search(
result.metadata["associatedQuery"] = message
results.append(result)
yield result

await self.enqueue_log(
run_id=run_id,
key="search_results",
Expand Down
72 changes: 62 additions & 10 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,27 +187,79 @@ class Config:
},
}

class KGLocalSearchResult(BaseModel):
    """SDK model for the result of a local knowledge graph search."""

    # Query text this result belongs to.
    query: str
    # NOTE(review): these three fields are declared as lists of dicts here,
    # while the schema example on the SDK KGSearchResult shows mappings keyed
    # by name -- confirm which shape the server actually returns.
    entities: list[dict[str, Any]]
    relationships: list[dict[str, Any]]
    communities: list[dict[str, Any]]

    def __str__(self) -> str:
        """Return a readable one-line summary of all four fields."""
        return f"KGLocalSearchResult(query={self.query}, entities={self.entities}, relationships={self.relationships}, communities={self.communities})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``, as the sibling result models do."""
        return self.__str__()

    def dict(self) -> dict:
        """Return the result as a plain dictionary.

        Returns:
            A dict with the keys ``"query"``, ``"entities"``,
            ``"relationships"`` and ``"communities"``.
        """
        return {
            "query": self.query,
            "entities": self.entities,
            "relationships": self.relationships,
            "communities": self.communities,
        }

class KGGlobalSearchResult(BaseModel):
    """SDK model for the result of a global knowledge graph search."""

    # Query text this result belongs to.
    query: str
    # Result texts for the query, one string per entry. This is the only
    # payload field: an extra required ``results`` field would make
    # validation fail for payloads that carry just ``query`` and
    # ``search_result``.
    search_result: list[str]

    def __str__(self) -> str:
        """Return a readable one-line summary of the query and its results."""
        return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()

    def dict(self) -> dict:
        """Return the result as a plain dictionary.

        Returns:
            A dict with exactly the keys ``"query"`` and ``"search_result"``.
        """
        return {
            "query": self.query,
            "search_result": self.search_result,
        }


class KGSearchResult(BaseModel):
    """SDK model pairing the optional local and global KG search results."""

    # Either part may be missing; both default to None.
    local_result: Optional[KGLocalSearchResult] = None
    global_result: Optional[KGGlobalSearchResult] = None

    def __str__(self) -> str:
        """Return a one-line summary showing both sub-results."""
        return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

    def __repr__(self) -> str:
        """Return the same text as ``__str__``."""
        return self.__str__()

    def dict(self) -> dict:
        """Return a plain dictionary of both sub-results.

        Each sub-result is converted with its own ``dict()``; a missing
        sub-result is reported as ``None``.
        """
        return {
            "local_result": self.local_result.dict() if self.local_result else None,
            "global_result": self.global_result.dict() if self.global_result else None,
        }

    class Config:
        # Example payload attached to the generated JSON schema. It is a
        # single well-formed literal following the local/global layout of the
        # fields above.
        # NOTE(review): "entities" is shown as a mapping keyed by name, while
        # the SDK KGLocalSearchResult declares list[dict[str, Any]] -- confirm
        # which shape is intended.
        json_schema_extra = {
            "example": {
                "local_result": {
                    "query": "What is the capital of France?",
                    "entities": {
                        "Paris": {
                            "name": "Paris",
                            "description": "Paris is the capital of France.",
                        }
                    },
                    "relationships": {},
                    "communities": {},
                },
                "global_result": {
                    "query": "What is the capital of France?",
                    "search_result": [
                        "Paris is the capital and most populous city of France."
                    ],
                },
            }
        }


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down Expand Up @@ -419,7 +471,7 @@ class SearchResponse(BaseModel):
...,
description="List of vector search results",
)
kg_search_results: Optional[KGSearchResult] = Field(
kg_search_results: Optional[list[KGSearchResult]] = Field(
None,
description="Knowledge graph search results, if applicable",
)
Expand Down

0 comments on commit 513c648

Please sign in to comment.