
Commit

up
emrgnt-cmplxty committed Jan 8, 2025
1 parent 0caece6 commit 84d167e
Showing 2 changed files with 54 additions and 48 deletions.
14 changes: 14 additions & 0 deletions py/core/database/users.py
@@ -109,6 +109,20 @@ async def create_tables(self):
await self.connection_manager.execute_query(user_table_query)
await self.connection_manager.execute_query(api_keys_table_query)

+        # (New) Code snippet for adding columns if missing
+        # Postgres >= 9.6 supports "ADD COLUMN IF NOT EXISTS"
+        check_columns_query = f"""
+            ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS metadata JSONB;
+            ALTER TABLE {self._get_table_name(self.TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS limits_overrides JSONB;
+            ALTER TABLE {self._get_table_name(self.API_KEYS_TABLE_NAME)}
+            ADD COLUMN IF NOT EXISTS description TEXT;
+        """
+        await self.connection_manager.execute_query(check_columns_query)
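Note: the IF NOT EXISTS guard above requires Postgres 9.6 or newer. For older servers, a minimal sketch of the same idempotent migration via information_schema; the helper name and the asyncpg-style connection API are assumptions for illustration, not this repository's actual interface:

    import asyncpg

    # Sketch: emulate ADD COLUMN IF NOT EXISTS on Postgres < 9.6.
    # `add_column_if_missing` and the asyncpg connection are illustrative.
    async def add_column_if_missing(
        conn: asyncpg.Connection, table: str, column: str, col_type: str
    ) -> None:
        exists = await conn.fetchval(
            "SELECT 1 FROM information_schema.columns "
            "WHERE table_name = $1 AND column_name = $2",
            table,
            column,
        )
        if exists is None:
            # Identifiers cannot be bound as query parameters, so table and
            # column must come from trusted code, never from user input.
            await conn.execute(f"ALTER TABLE {table} ADD COLUMN {column} {col_type}")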

async def get_user_by_id(self, id: UUID) -> User:
query, _ = (
QueryBuilder(self._get_table_name("users"))
88 changes: 40 additions & 48 deletions py/core/pipes/kg/extraction.py
@@ -34,7 +34,7 @@ class ClientError(Exception):

class GraphExtractionPipe(AsyncPipe[dict]):
"""
-    Extracts knowledge graph information from document extractions.
+    Extracts knowledge graph information from document chunks.
"""

# TODO - Apply correct type hints to storage messages
@@ -65,7 +65,7 @@ def __init__(

async def extract_kg(
self,
-        extractions: list[DocumentChunk],
+        chunks: list[DocumentChunk],
generation_config: GenerationConfig,
max_knowledge_relationships: int,
entity_types: list[str],
@@ -76,16 +76,16 @@ async def extract_kg(
total_tasks: Optional[int] = None,
) -> KGExtraction:
"""
-        Extracts NER relationships from a extraction with retries.
+        Extracts NER relationships from a chunk with retries.
"""

-        # combine all extractions into a single string
-        combined_extraction: str = " ".join([extraction.data for extraction in extractions])  # type: ignore
+        # combine all chunks into a single string
+        combined_chunks: str = " ".join([chunk.data for chunk in chunks])  # type: ignore

messages = await self.database_provider.prompts_handler.get_message_payload(
task_prompt_name=self.database_provider.config.graph_creation_settings.graphrag_relationships_extraction_few_shot,
task_inputs={
"input": combined_extraction,
"input": combined_chunks,
"max_knowledge_relationships": max_knowledge_relationships,
"entity_types": "\n".join(entity_types),
"relation_types": "\n".join(relation_types),
@@ -103,7 +103,7 @@ async def extract_kg(

if not kg_extraction:
raise R2RException(
"No knowledge graph extraction found in the response string, the selected LLM likely failed to format it's response correctly.",
"No knowledge graph chunk found in the response string, the selected LLM likely failed to format it's response correctly.",
400,
)

@@ -139,10 +139,8 @@ def parse_fn(response_str: str) -> Any:
category=entity_category,
description=entity_description,
name=entity_value,
-                        parent_id=extractions[0].document_id,
-                        chunk_ids=[
-                            extraction.id for extraction in extractions
-                        ],
+                        parent_id=chunks[0].document_id,
+                        chunk_ids=[chunk.id for chunk in chunks],
attributes={},
)
)
@@ -163,10 +161,8 @@ def parse_fn(response_str: str) -> Any:
object=object,
description=description,
weight=weight,
-                        parent_id=extractions[0].document_id,
-                        chunk_ids=[
-                            extraction.id for extraction in extractions
-                        ],
+                        parent_id=chunks[0].document_id,
+                        chunk_ids=[chunk.id for chunk in chunks],
attributes={},
)
)
@@ -190,13 +186,13 @@ def parse_fn(response_str: str) -> Any:
await asyncio.sleep(delay)
else:
logger.error(
f"Failed after retries with for chunk {extractions[0].id} of document {extractions[0].document_id}: {e}"
f"Failed after retries with for chunk {chunks[0].id} of document {chunks[0].document_id}: {e}"
)
                # raise e  # TODO: re-raise so failures are not silently swallowed
# add metadata to entities and relationships

logger.info(
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {extractions[0].document_id}",
f"GraphExtractionPipe: Completed task number {task_id} of {total_tasks} for document {chunks[0].document_id}",
)

return KGExtraction(
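(The enclosing retry loop is folded out of this diff. A minimal sketch of the backoff shape implied by the `await asyncio.sleep(delay)` branch above, with the retry count and delay policy assumed rather than taken from the file:)

    import asyncio

    # Sketch of retry-with-backoff around a coroutine such as the LLM call;
    # max_retries and base_delay are illustrative defaults, not the repo's.
    async def with_retries(fn, max_retries: int = 3, base_delay: float = 2.0):
        for attempt in range(max_retries):
            try:
                return await fn()
            except Exception:
                if attempt < max_retries - 1:
                    await asyncio.sleep(base_delay * 2**attempt)  # back off, then retry
                else:
                    raise  # surface the final failure instead of swallowing it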
@@ -230,20 +226,20 @@ async def _run_logic( # type: ignore
logger = input.message.get("logger", logging.getLogger())

logger.info(
f"GraphExtractionPipe: Processing document {document_id} for KG extraction",
f"GraphExtractionPipe: Processing document {document_id} for KG chunk",
)

-        # Then create the extractions from the results
-        extractions = [
+        # Then create the chunks from the results
+        chunks = [
DocumentChunk(
-                id=extraction["id"],
-                document_id=extraction["document_id"],
-                owner_id=extraction["owner_id"],
-                collection_ids=extraction["collection_ids"],
-                data=extraction["text"],
-                metadata=extraction["metadata"],
+                id=chunk["id"],
+                document_id=chunk["document_id"],
+                owner_id=chunk["owner_id"],
+                collection_ids=chunk["collection_ids"],
+                data=chunk["text"],
+                metadata=chunk["metadata"],
)
-            for extraction in (
+            for chunk in (
await self.database_provider.documents_handler.list_document_chunks( # FIXME: This was using the pagination defaults from before... We need to review if this is as intended.
document_id=document_id,
offset=0,
@@ -252,41 +248,37 @@ async def _run_logic( # type: ignore
)["results"]
]

-        logger.info(
-            f"Found {len(extractions)} extractions for document {document_id}"
-        )
+        logger.info(f"Found {len(chunks)} chunks for document {document_id}")

if filter_out_existing_chunks:
existing_chunk_ids = await self.database_provider.graphs_handler.get_existing_document_entity_chunk_ids(
document_id=document_id
)
-        extractions = [
-            extraction
-            for extraction in extractions
-            if extraction.id not in existing_chunk_ids
+        chunks = [
+            chunk for chunk in chunks if chunk.id not in existing_chunk_ids
]
logger.info(
f"Filtered out {len(existing_chunk_ids)} existing extractions, remaining {len(extractions)} extractions for document {document_id}"
f"Filtered out {len(existing_chunk_ids)} existing chunks, remaining {len(chunks)} chunks for document {document_id}"
)

-        if len(extractions) == 0:
-            logger.info(f"No extractions left for document {document_id}")
+        if len(chunks) == 0:
+            logger.info(f"No chunks left for document {document_id}")
return

logger.info(
f"GraphExtractionPipe: Obtained {len(extractions)} extractions to process, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Obtained {len(chunks)} chunks to process, time from start: {time.time() - start_time:.2f} seconds",
)

-        # sort the extractions accroding to chunk_order field in metadata in ascending order
-        extractions = sorted(
-            extractions,
+        # sort the chunks according to the chunk_order metadata field, in ascending order
+        chunks = sorted(
+            chunks,
key=lambda x: x.metadata.get("chunk_order", float("inf")),
)

-        # group these extractions into groups of chunk_merge_count
+        # group these chunks into groups of chunk_merge_count
extractions_groups = [
-            extractions[i : i + chunk_merge_count]
-            for i in range(0, len(extractions), chunk_merge_count)
+            chunks[i : i + chunk_merge_count]
+            for i in range(0, len(chunks), chunk_merge_count)
]
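(For illustration, the slicing above yields fixed-size batches with a short tail; the chunk_merge_count value below is assumed, the real default is configured elsewhere:)

    # Toy illustration of the grouping above.
    chunks = ["c1", "c2", "c3", "c4", "c5"]
    chunk_merge_count = 2  # assumed value for the example
    groups = [
        chunks[i : i + chunk_merge_count]
        for i in range(0, len(chunks), chunk_merge_count)
    ]
    assert groups == [["c1", "c2"], ["c3", "c4"], ["c5"]]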

logger.info(
@@ -296,7 +288,7 @@ async def _run_logic( # type: ignore
tasks = [
asyncio.create_task(
self.extract_kg(
-                    extractions=extractions_group,
+                    chunks=extractions_group,
generation_config=generation_config,
max_knowledge_relationships=max_knowledge_relationships,
entity_types=entity_types,
Expand All @@ -312,7 +304,7 @@ async def _run_logic( # type: ignore
total_tasks = len(tasks)

logger.info(
f"GraphExtractionPipe: Waiting for {total_tasks} KG extraction tasks to complete",
f"GraphExtractionPipe: Waiting for {total_tasks} KG chunk tasks to complete",
)

for completed_task in asyncio.as_completed(tasks):
Expand All @@ -321,7 +313,7 @@ async def _run_logic( # type: ignore
completed_tasks += 1
if completed_tasks % 100 == 0:
logger.info(
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG chunk tasks",
)
except Exception as e:
logger.error(f"Error in Extracting KG Relationships: {e}")
@@ -331,5 +323,5 @@ async def _run_logic( # type: ignore
)

logger.info(
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG extraction tasks, time from start: {time.time() - start_time:.2f} seconds",
f"GraphExtractionPipe: Completed {completed_tasks}/{total_tasks} KG chunk tasks, time from start: {time.time() - start_time:.2f} seconds",
)
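(The progress reporting in this hunk follows the standard asyncio.as_completed pattern; a self-contained sketch under made-up task bodies and counts:)

    import asyncio

    # Standalone sketch of the as_completed progress pattern used in _run_logic.
    async def main() -> None:
        async def work(i: int) -> int:
            await asyncio.sleep(0.01)  # stand-in for one extract_kg task
            return i

        tasks = [asyncio.create_task(work(i)) for i in range(250)]
        completed = 0
        for fut in asyncio.as_completed(tasks):
            await fut  # re-raises if the underlying task failed
            completed += 1
            if completed % 100 == 0:
                print(f"Completed {completed}/{len(tasks)} tasks")

    asyncio.run(main())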
