-
Notifications
You must be signed in to change notification settings - Fork 339
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Shreyas/kg runtime cfg #913
Changes from all commits
778a7fc
c44a78a
8567d5a
62154e0
761d8b6
a67de21
2187b63
bbb0ffd
dbd7114
85ba0cb
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
from typing import Optional | ||
|
||
from pydantic import BaseModel, Field | ||
|
||
from .llm import GenerationConfig | ||
|
||
|
||
class KGEnrichmentSettings(BaseModel): | ||
"""Settings for knowledge graph enrichment.""" | ||
|
||
max_knowledge_triples: int = Field( | ||
default=100, | ||
description="The maximum number of knowledge triples to extract from each chunk.", | ||
) | ||
|
||
generation_config: GenerationConfig = Field( | ||
default_factory=GenerationConfig, | ||
description="Configuration for text generation during graph enrichment.", | ||
) | ||
leiden_params: dict = Field( | ||
default_factory=dict, | ||
description="Parameters for the Leiden algorithm.", | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,6 +2,7 @@ | |
from typing import Any, Dict, List | ||
|
||
from core.base import R2RException, RunLoggingSingleton, RunManager | ||
from core.base.abstractions import KGEnrichmentSettings | ||
|
||
from ..abstractions import R2RAgents, R2RPipelines, R2RProviders | ||
from ..assembly.config import R2RConfig | ||
|
@@ -30,7 +31,9 @@ def __init__( | |
logging_connection, | ||
) | ||
|
||
async def enrich_graph(self) -> Dict[str, Any]: | ||
async def enrich_graph( | ||
self, enrich_graph_settings: KGEnrichmentSettings = KGEnrichmentSettings() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
) -> Dict[str, Any]: | ||
""" | ||
Perform graph enrichment. | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -88,6 +88,17 @@ class KGSearchSettings(BaseModel): | |
} | ||
|
||
|
||
class KGEnrichmentSettings(BaseModel): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
leiden_params: dict = Field( | ||
default_factory=dict, | ||
description="Parameters for the Leiden algorithm.", | ||
) | ||
generation_config: GenerationConfig = Field( | ||
default_factory=GenerationConfig, | ||
description="Configuration for text generation during graph enrichment.", | ||
) | ||
|
||
|
||
class ProviderConfig(BaseModel, ABC): | ||
"""A base provider configuration class""" | ||
|
||
|
@@ -237,7 +248,23 @@ def model_dump(self, *args, **kwargs): | |
str(uuid) for uuid in dump["selected_group_ids"] | ||
] | ||
return dump | ||
|
||
class KGEnrichmentSettings(BaseModel): | ||
max_knowledge_triples: int = Field( | ||
default=100, | ||
description="The maximum number of knowledge triples to extract from each chunk.", | ||
) | ||
generation_config: GenerationConfig = Field( | ||
default_factory=GenerationConfig, | ||
description="The generation configuration for the KG enrichment.", | ||
) | ||
leiden_params: dict = Field( | ||
default_factory=dict, | ||
description="The parameters for the Leiden algorithm.", | ||
) | ||
|
||
class KGEnrichementResponse(BaseModel): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The class name |
||
enriched_content: Dict[str, Any] | ||
|
||
class UserResponse(BaseModel): | ||
id: UUID | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,16 +1,23 @@ | ||
from typing import Any, Dict, List, Optional | ||
from typing import Any, Dict, List, Optional, Union | ||
|
||
from core.base import Document | ||
from core.base.api.models import KGEnrichementResponse | ||
|
||
from .models import KGEnrichmentSettings, KGEnrichementResponse | ||
|
||
class RestructureMethods: | ||
@staticmethod | ||
async def enrich_graph(client) -> KGEnrichementResponse: | ||
async def enrich_graph( | ||
client, KGEnrichmentSettings: KGEnrichmentSettings = KGEnrichmentSettings() | ||
) -> KGEnrichementResponse: | ||
""" | ||
Perform graph enrichment over the entire graph. | ||
|
||
Returns: | ||
KGEnrichementResponse: Results of the graph enrichment process. | ||
""" | ||
return await client._make_request("POST", "enrich_graph") | ||
if not isinstance(KGEnrichmentSettings, dict): | ||
KGEnrichmentSettings = KGEnrichmentSettings.model_dump() | ||
|
||
data = { | ||
"KGEnrichmentSettings": KGEnrichmentSettings, | ||
} | ||
return await client._make_request("POST", "enrich_graph", json=data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The variable name
KGEnrichmentSettings
shadows the class name and is misleading. Consider renaming it to follow Python naming conventions, such asenrichment_settings
.