Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Shreyas/kg runtime cfg #913

Merged
merged 10 commits into from
Aug 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion py/cli/commands/restructure.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@ def enrich_graph(client):
Perform graph enrichment over the entire graph.
"""
with timer():
response = client.restructure()
response = client.enrich_graph()

click.echo(response)
2 changes: 2 additions & 0 deletions py/core/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
# Restructure abstractions
"KGEnrichmentSettings",
# User abstractions
"Token",
"TokenData",
Expand Down
3 changes: 3 additions & 0 deletions py/core/base/abstractions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
VectorSearchResult,
VectorSearchSettings,
)
from .restructure import KGEnrichmentSettings
from .user import Token, TokenData, UserStats
from .vector import Vector, VectorEntry, VectorType

Expand Down Expand Up @@ -81,6 +82,8 @@
"KGSearchSettings",
"VectorSearchResult",
"VectorSearchSettings",
# Restructure abstractions
"KGEnrichmentSettings",
# User abstractions
"Token",
"TokenData",
Expand Down
23 changes: 23 additions & 0 deletions py/core/base/abstractions/restructure.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from pydantic import BaseModel, Field

from .llm import GenerationConfig


class KGEnrichmentSettings(BaseModel):
"""Settings for knowledge graph enrichment."""

max_knowledge_triples: int = Field(
default=100,
description="The maximum number of knowledge triples to extract from each chunk.",
)

generation_config: GenerationConfig = Field(
default_factory=GenerationConfig,
description="Configuration for text generation during graph enrichment.",
)
leiden_params: dict = Field(
default_factory=dict,
description="Parameters for the Leiden algorithm.",
)
4 changes: 3 additions & 1 deletion py/core/base/providers/kg.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from ...base.utils.base_utils import RelationshipType
from ..abstractions.graph import Entity, KGExtraction, Triple
from ..abstractions.llm import GenerationConfig
from ..abstractions.restructure import KGEnrichmentSettings
from .base import ProviderConfig

logger = logging.getLogger(__name__)
Expand All @@ -20,8 +21,9 @@ class KGConfig(ProviderConfig):
kg_extraction_prompt: Optional[str] = "few_shot_ner_kg_extraction"
kg_search_prompt: Optional[str] = "kg_search"
kg_extraction_config: Optional[GenerationConfig] = None
kg_search_config: Optional[GenerationConfig] = None
kg_store_path: Optional[str] = None
max_knowledge_triples: Optional[int] = 100
kg_enrichment_settings: Optional[KGEnrichmentSettings] = KGEnrichmentSettings()

def validate(self) -> None:
if self.provider not in self.supported_providers:
Expand Down
20 changes: 14 additions & 6 deletions py/core/configs/neo4j_kg.toml
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,20 @@ kg_extraction_prompt = "graphrag_triplet_extraction_zero_shot"


[kg.kg_extraction_config]
model = "gpt-4o-mini"
temperature = 1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[kg.kg_search_config]
model = "gpt-4o-mini"
temperature = 0.1
top_p = 1
max_tokens_to_sample = 1_024
stream = false
add_generation_kwargs = { }

[database]
provider = "postgres"
Expand Down
9 changes: 7 additions & 2 deletions py/core/main/api/routes/restructure/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from core.base import KGEnrichmentSettings
from core.main.api.routes.base_router import BaseRouter
from core.main.engine import R2REngine
from fastapi import Depends

from typing import Union
from fastapi import Body, Depends

class RestructureRouter(BaseRouter):
def __init__(self, engine: R2REngine):
Expand All @@ -12,6 +13,10 @@ def setup_routes(self):
@self.router.post("/enrich_graph")
@self.base_endpoint
async def enrich_graph(
KGEnrichmentSettings: Union[dict, KGEnrichmentSettings] = Body(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The variable name KGEnrichmentSettings shadows the class name and is misleading. Consider renaming it to follow Python naming conventions, such as enrichment_settings.

...,
description="Settings for knowledge graph enrichment",
),
auth_user=(
Depends(self.engine.providers.auth.auth_wrapper)
if self.engine.providers.auth
Expand Down
1 change: 1 addition & 0 deletions py/core/main/assembly/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
self.provider_factory_override: Optional[Type[R2RProviderFactory]] = (
None
)

self.pipe_factory_override: Optional[R2RPipeFactory] = None
self.pipeline_factory_override: Optional[R2RPipelineFactory] = None

Expand Down
5 changes: 4 additions & 1 deletion py/core/main/services/restructure_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from typing import Any, Dict, List

from core.base import R2RException, RunLoggingSingleton, RunManager
from core.base.abstractions import KGEnrichmentSettings

from ..abstractions import R2RAgents, R2RPipelines, R2RProviders
from ..assembly.config import R2RConfig
Expand Down Expand Up @@ -30,7 +31,9 @@ def __init__(
logging_connection,
)

async def enrich_graph(self) -> Dict[str, Any]:
async def enrich_graph(
self, enrich_graph_settings: KGEnrichmentSettings = KGEnrichmentSettings()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The enrich_graph_settings parameter is not used in this method. Consider using it or removing it if it's unnecessary.

) -> Dict[str, Any]:
"""
Perform graph enrichment.

Expand Down
32 changes: 16 additions & 16 deletions py/core/pipes/kg/clustering.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
PromptProvider,
RunLoggingSingleton,
Triple,
KGEnrichmentSettings,
)

logger = logging.getLogger(__name__)
Expand All @@ -38,9 +39,6 @@ def __init__(
llm_provider: CompletionProvider,
prompt_provider: PromptProvider,
embedding_provider: EmbeddingProvider,
cluster_batch_size: int = 100,
max_cluster_size: int = 10,
use_lcc: bool = True,
pipe_logger: Optional[RunLoggingSingleton] = None,
type: PipeType = PipeType.OTHER,
config: Optional[AsyncPipe.PipeConfig] = None,
Expand All @@ -57,23 +55,20 @@ def __init__(
)
self.kg_provider = kg_provider
self.llm_provider = llm_provider
self.cluster_batch_size = cluster_batch_size
self.max_cluster_size = max_cluster_size
self.use_lcc = use_lcc
self.prompt_provider = prompt_provider
self.embedding_provider = embedding_provider

def _compute_leiden_communities(
self,
graph: nx.Graph,
seed: int = 0xDEADBEEF,
settings: KGEnrichmentSettings,
) -> dict[int, dict[str, int]]:
"""Compute Leiden communities."""
try:
from graspologic.partition import hierarchical_leiden

community_mapping = hierarchical_leiden(
graph, max_cluster_size=self.max_cluster_size, random_seed=seed
graph, **settings.leiden_params
)
results: dict[int, dict[str, int]] = {}
for partition in community_mapping:
Expand All @@ -84,7 +79,9 @@ def _compute_leiden_communities(
except ImportError as e:
raise ImportError("Please install the graspologic package.") from e

async def cluster_kg(self, triples: list[Triple]) -> list[Community]:
async def cluster_kg(
self, triples: list[Triple], settings: KGEnrichmentSettings = KGEnrichmentSettings()
) -> list[Community]:
"""
Clusters the knowledge graph triples into communities using hierarchical Leiden algorithm.
"""
Expand All @@ -100,7 +97,9 @@ async def cluster_kg(self, triples: list[Triple]) -> list[Community]:
id=f"{triple.subject}->{triple.predicate}->{triple.object}",
)

hierarchical_communities = self._compute_leiden_communities(G)
hierarchical_communities = self._compute_leiden_communities(
G, settings=settings
)

community_details = {}

Expand Down Expand Up @@ -172,9 +171,7 @@ async def process_community(community_key, community):
"input_text": input_text,
},
),
generation_config=GenerationConfig(
model="gpt-4o-mini",
),
generation_config=settings.generation_config,
)

description = description.choices[0].message.content
Expand Down Expand Up @@ -202,8 +199,11 @@ async def process_community(community_key, community):
)
)

for completed_task in asyncio.as_completed(tasks):
yield await completed_task
total_tasks = len(tasks)
for i, completed_task in enumerate(asyncio.as_completed(tasks), 1):
result = await completed_task
logger.info(f"Progress: {i}/{total_tasks} communities completed ({i/total_tasks*100:.2f}%)")
yield result

async def _run_logic(
self,
Expand All @@ -230,5 +230,5 @@ async def _run_logic(

triples = self.kg_provider.get_triples()

async for community in self.cluster_kg(triples):
async for community in self.cluster_kg(triples, self.kg_provider.config.kg_enrichment_settings):
yield community
2 changes: 1 addition & 1 deletion py/core/pipes/kg/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ async def extract_kg(

task_inputs = {"input": fragment.data}
task_inputs["max_knowledge_triples"] = (
self.kg_provider.config.max_knowledge_triples
self.kg_provider.config.kg_enrichment_settings.max_knowledge_triples
)

messages = self.prompt_provider._get_message_payload(
Expand Down
7 changes: 5 additions & 2 deletions py/core/providers/chunking/r2r_chunking.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,20 @@ def __init__(self, config: ChunkingConfig):
)

def _initialize_text_splitter(self) -> TextSplitter:
logger.info(
f"Initializing text splitter with method: {self.config.method}"
) # Debug log
if self.config.method == Method.RECURSIVE:
return RecursiveCharacterTextSplitter(
chunk_size=self.config.chunk_size,
chunk_overlap=self.config.chunk_overlap,
)
elif self.config.method == Method.BASIC:
# Implement basic method
raise NotImplementedError("Basic method not implemented yet")
pass
elif self.config.method == Method.BY_TITLE:
# Implement by_title method
raise NotImplementedError("By_title method not implemented yet")
pass
else:
raise ValueError(f"Unsupported method type: {self.config.method}")

Expand Down
19 changes: 11 additions & 8 deletions py/core/providers/parsing/unstructured_parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,6 @@

class UnstructuredParsingProvider(ParsingProvider):
def __init__(self, use_api, config):
try:
from unstructured.partition.auto import partition

self.partition = partition
except ImportError as e:
raise ImportError(
"Please install the unstructured package to use the unstructured parsing provider."
) from e
if config.excluded_parsers:
logger.warning(
"Excluded parsers are not supported by the unstructured parsing provider."
Expand Down Expand Up @@ -57,6 +49,17 @@ def __init__(self, use_api, config):
self.operations = operations
self.dict_to_elements = dict_to_elements

else:
try:
from unstructured.partition.auto import partition

self.partition = partition

except ImportError:
raise ImportError(
"Please install the unstructured package to use the unstructured parsing provider."
)

super().__init__(config)

async def parse(
Expand Down
27 changes: 27 additions & 0 deletions py/sdk/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,17 @@ class KGSearchSettings(BaseModel):
}


class KGEnrichmentSettings(BaseModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The KGEnrichmentSettings class is defined twice in this file. Consider removing one of the definitions to avoid redundancy.

leiden_params: dict = Field(
default_factory=dict,
description="Parameters for the Leiden algorithm.",
)
generation_config: GenerationConfig = Field(
default_factory=GenerationConfig,
description="Configuration for text generation during graph enrichment.",
)


class ProviderConfig(BaseModel, ABC):
"""A base provider configuration class"""

Expand Down Expand Up @@ -237,7 +248,23 @@ def model_dump(self, *args, **kwargs):
str(uuid) for uuid in dump["selected_group_ids"]
]
return dump

class KGEnrichmentSettings(BaseModel):
max_knowledge_triples: int = Field(
default=100,
description="The maximum number of knowledge triples to extract from each chunk.",
)
generation_config: GenerationConfig = Field(
default_factory=GenerationConfig,
description="The generation configuration for the KG enrichment.",
)
leiden_params: dict = Field(
default_factory=dict,
description="The parameters for the Leiden algorithm.",
)

class KGEnrichementResponse(BaseModel):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The class name KGEnrichementResponse contains a typo. Consider renaming it to KGEnrichmentResponse.

enriched_content: Dict[str, Any]

class UserResponse(BaseModel):
id: UUID
Expand Down
17 changes: 12 additions & 5 deletions py/sdk/restructure.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Optional, Union

from core.base import Document
from core.base.api.models import KGEnrichementResponse

from .models import KGEnrichmentSettings, KGEnrichementResponse

class RestructureMethods:
@staticmethod
async def enrich_graph(client) -> KGEnrichementResponse:
async def enrich_graph(
client, KGEnrichmentSettings: KGEnrichmentSettings = KGEnrichmentSettings()
) -> KGEnrichementResponse:
"""
Perform graph enrichment over the entire graph.

Returns:
KGEnrichementResponse: Results of the graph enrichment process.
"""
return await client._make_request("POST", "enrich_graph")
if not isinstance(KGEnrichmentSettings, dict):
KGEnrichmentSettings = KGEnrichmentSettings.model_dump()

data = {
"KGEnrichmentSettings": KGEnrichmentSettings,
}
return await client._make_request("POST", "enrich_graph", json=data)
Loading
Loading