Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shreyas/kgsearchresult model #957

Merged
merged 10 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions js/sdk/src/models.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,26 @@ export interface KGSearchSettings {
max_llm_queries_for_global_search?: number;
local_search_limits?: Record<string, number>;
}

export interface KGLocalSearchResult {
query: string;
entities: Record<string, any>;
relationships: Record<string, any>;
communities: Record<string, any>;
}

export interface KGGlobalSearchResult {
query: string;
search_result: string[];
}

export interface KGSearchResult {
local_result?: KGLocalSearchResult;
global_result?: KGGlobalSearchResult;
}



export interface Message {
role: string;
content: string;
Expand Down
3 changes: 3 additions & 0 deletions py/cli/commands/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ def serve(
).replace(":", "")

if docker:

run_docker_serve(
client,
host,
Expand All @@ -227,6 +228,8 @@ def serve(
exclude_postgres,
project_name,
image,
config_name,
config_path,
)
if (
"pytest" in sys.modules
Expand Down
9 changes: 5 additions & 4 deletions py/cli/utils/docker_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,19 +110,20 @@ def run_docker_serve(
exclude_postgres: bool,
project_name: str,
image: str,
config_name: Optional[str] = None,
config_path: Optional[str] = None,
):
check_set_docker_env_vars(exclude_neo4j, exclude_postgres)
set_ollama_api_base(exclude_ollama)

if config_path and config_name:
raise ValueError("Cannot specify both config_path and config_name")

if config_path:
config = R2RConfig.from_toml(config_path)
else:
if hasattr(client, "config_name"):
config_name = client.config_name
else:
if not config_name:
config_name = "default"

config = R2RConfig.from_toml(R2RBuilder.CONFIG_OPTIONS[config_name])

if config.parsing.provider == "unstructured" and not image:
Expand Down
2 changes: 2 additions & 0 deletions py/core/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
2 changes: 2 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
4 changes: 4 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from .restructure import KGEnrichmentSettings
from .search import (
AggregateSearchResult,
KGLocalSearchResult,
KGGlobalSearchResult,
KGSearchResult,
KGSearchSettings,
VectorSearchResult,
Expand Down Expand Up @@ -78,6 +80,8 @@
"Prompt",
# Search abstractions
"AggregateSearchResult",
"KGLocalSearchResult",
"KGGlobalSearchResult",
"KGSearchResult",
"KGSearchSettings",
"VectorSearchResult",
Expand Down
29 changes: 18 additions & 11 deletions py/core/base/abstractions/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,12 @@ class Config:
},
}


class KGLocalSearchResult(BaseModel):
"""Result of a local knowledge graph search operation."""

query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
communities: list[dict[str, Any]]
entities: dict[str, Any]
relationships: dict[str, Any]
communities: dict[str, Any]

def __str__(self) -> str:
return f"LocalSearchResult(query={self.query}, search_result={self.search_result})"
Expand All @@ -72,35 +70,44 @@ def __repr__(self) -> str:

class KGGlobalSearchResult(BaseModel):
"""Result of a global knowledge graph search operation."""

query: str
search_result: list[Dict[str, Any]]
search_result: list[str]

def __str__(self) -> str:
return f"GlobalSearchResult(query={self.query}, search_result={self.search_result})"
return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The dict method is already provided by Pydantic's BaseModel, so there's no need to redefine it. This applies to other instances in this file and elsewhere in the codebase.

return {
"query": self.query,
"search_result": self.search_result
}

class KGSearchResult(BaseModel):
"""Result of a knowledge graph search operation."""

local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

def __str__(self) -> str:
return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"local_result": self.local_result.dict() if self.local_result else None,
"global_result": self.global_result.dict() if self.global_result else None
}


class AggregateSearchResult(BaseModel):
"""Result of an aggregate search operation."""

vector_search_results: Optional[list[VectorSearchResult]]
kg_search_results: Optional[KGSearchResult] = None
kg_search_results: Optional[list[KGSearchResult]] = None

def __str__(self) -> str:
return f"AggregateSearchResult(vector_search_results={self.vector_search_results}, kg_search_results={self.kg_search_results})"
Expand Down
2 changes: 1 addition & 1 deletion py/core/base/api/models/retrieval/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ class SearchResponse(BaseModel):
...,
description="List of vector search results",
)
kg_search_results: Optional[KGSearchResult] = Field(
kg_search_results: Optional[list[KGSearchResult]] = Field(
None,
description="Knowledge graph search results, if applicable",
)
Expand Down
2 changes: 1 addition & 1 deletion py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"
[kg.kg_extraction_config]
model = "gpt-4o-mini"

[kg.kg_enrichment_config]
[kg.kg_enrichment_settings]
max_knowledge_triples = 100
generation_config = { model = "gpt-4o-mini" } # and other params
leiden_params = { max_cluster_size = 1000 } # more params in graspologic/partition/leiden.py
Expand Down
19 changes: 8 additions & 11 deletions py/core/pipes/retrieval/kg_search_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
PipeType,
PromptProvider,
RunLoggingSingleton,
R2RException
)
from core.base.abstractions.search import (
KGGlobalSearchResult,
Expand All @@ -24,7 +25,6 @@

logger = logging.getLogger(__name__)


class KGSearchSearchPipe(GeneratorPipe):
"""
Embeds and stores documents using a specified embedding model and database.
Expand Down Expand Up @@ -132,12 +132,11 @@ async def local_search(
)
all_search_results.append(search_result)

yield KGLocalSearchResult(
query=message,
entities=all_search_results[0],
relationships=all_search_results[1],
communities=all_search_results[2],
)

if len(all_search_results[0])==0:
raise R2RException("No search results found. Please make sure you have run the KG enrichment step before running the search: r2r enrich-graph", 400)

yield KGLocalSearchResult(query=message, entities=all_search_results[0], relationships=all_search_results[1], communities=all_search_results[2])

async def global_search(
self,
Expand Down Expand Up @@ -217,11 +216,9 @@ async def process_community(merged_report):
generation_config=kg_search_settings.kg_search_generation_config,
)

output = output.choices[0].message.content
output = [output.choices[0].message.content]

yield KGGlobalSearchResult(
query=message, search_result=output, citations=None
)
yield KGGlobalSearchResult(query=message, search_result=output)

async def _run_logic(
self,
Expand Down
1 change: 1 addition & 0 deletions py/core/pipes/retrieval/vector_search_pipe.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ async def search(
result.metadata["associatedQuery"] = message
results.append(result)
yield result

await self.enqueue_log(
run_id=run_id,
key="search_results",
Expand Down
72 changes: 62 additions & 10 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,27 +187,79 @@ class Config:
},
}

class KGLocalSearchResult(BaseModel):
query: str
entities: list[dict[str, Any]]
relationships: list[dict[str, Any]]
communities: list[dict[str, Any]]

class KGSearchResult(BaseModel):
def __str__(self) -> str:
return f"KGLocalSearchResult(query={self.query}, entities={self.entities}, relationships={self.relationships}, communities={self.communities})"

def dict(self) -> dict:
return {
"query": self.query,
"entities": self.entities,
"relationships": self.relationships,
"communities": self.communities
}

class KGGlobalSearchResult(BaseModel):
query: str
results: list[Dict[str, Any]]
search_result: list[str]

def __str__(self) -> str:
return f"KGGlobalSearchResult(query={self.query}, search_result={self.search_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"query": self.query,
"search_result": self.search_result
}


class KGSearchResult(BaseModel):
local_result: Optional[KGLocalSearchResult] = None
global_result: Optional[KGGlobalSearchResult] = None

def __str__(self) -> str:
return f"KGSearchResult(local_result={self.local_result}, global_result={self.global_result})"

def __repr__(self) -> str:
return self.__str__()

def dict(self) -> dict:
return {
"local_result": self.local_result.dict() if self.local_result else None,
"global_result": self.global_result.dict() if self.global_result else None
}

class Config:
json_schema_extra = {
"example": {
"query": "What is the capital of France?",
"results": [
{
"local_result": {
"query": "What is the capital of France?",
"entities": {
"Paris": {
"name": "Paris",
"description": "Paris is the capital of France.",
"description": "Paris is the capital of France."
}
}
],
},
"relationships": {},
"communities": {},
},
"global_result": {
"query": "What is the capital of France?",
"search_result": [
"Paris is the capital and most populous city of France."
]
}
}
}


class R2RException(Exception):
def __init__(
self, message: str, status_code: int, detail: Optional[Any] = None
Expand Down Expand Up @@ -419,7 +471,7 @@ class SearchResponse(BaseModel):
...,
description="List of vector search results",
)
kg_search_results: Optional[KGSearchResult] = Field(
kg_search_results: Optional[list[KGSearchResult]] = Field(
None,
description="Knowledge graph search results, if applicable",
)
Expand Down
Loading