
Update final nodes output #1247

Open
wants to merge 12 commits into incremental_indexing/main
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241003185355991586.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Update crete final nodes"
}
228 changes: 225 additions & 3 deletions graphrag/index/update/dataframes.py
@@ -5,6 +5,11 @@

import os
from dataclasses import dataclass
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable


import numpy as np
import pandas as pd
@@ -120,15 +125,28 @@ async def update_dataframe_outputs(
)
delta_text_units = dataframe_dict["create_final_text_units"]

- merged_text_units = _update_and_merge_text_units(
+ merged_text_units_df = _update_and_merge_text_units(
old_text_units, delta_text_units, entity_id_mapping
)

# TODO: Using _new in the meantime, to compare outputs without overwriting the original
await storage.set(
"create_final_text_units_new.parquet", merged_text_units.to_parquet()
"create_final_text_units_new.parquet", merged_text_units_df.to_parquet()
)

# Update final nodes
old_nodes = await _load_table_from_storage("create_final_nodes.parquet", storage)
delta_nodes = dataframe_dict["create_final_nodes"]

merged_nodes = _merge_and_update_nodes(
old_nodes,
delta_nodes,
merged_entities_df,
merged_relationships_df,
)

await storage.set("create_final_nodes_new.parquet", merged_nodes.to_parquet())


async def _concat_dataframes(name, dataframe_dict, storage):
"""Concatenate the dataframes.
@@ -229,7 +247,8 @@ def _group_and_resolve_entities(


def _update_and_merge_relationships(
- old_relationships: pd.DataFrame, delta_relationships: pd.DataFrame
+ old_relationships: pd.DataFrame,
+ delta_relationships: pd.DataFrame,
) -> pd.DataFrame:
"""Update and merge relationships.

@@ -309,3 +328,206 @@ def _update_and_merge_text_units(

# Merge the final text units
return pd.concat([old_text_units, delta_text_units], ignore_index=True, copy=False)


def _merge_and_update_nodes(
old_nodes: pd.DataFrame,
delta_nodes: pd.DataFrame,
merged_entities_df: pd.DataFrame,
merged_relationships_df: pd.DataFrame,
community_count_threshold: int = 2,
) -> pd.DataFrame:
"""Merge and update nodes.

Parameters
----------
old_nodes : pd.DataFrame
The old nodes.
delta_nodes : pd.DataFrame
The delta nodes.
merged_entities_df : pd.DataFrame
The merged entities.
merged_relationships_df : pd.DataFrame
The merged relationships.
community_count_threshold : int, optional
The community count threshold, by default 2.
If a node has enough relationships to a community, it will be assigned to that community.

Returns
-------
pd.DataFrame
The updated nodes.
"""
# Increment all community ids by the max of the old nodes
old_max_community_id = old_nodes["community"].fillna(0).astype(int).max()

# Merge delta_nodes with merged_entities_df to get the new human_readable_id
delta_nodes = delta_nodes.merge(
merged_entities_df[["name", "human_readable_id"]],
left_on="title",
right_on="name",
how="left",
suffixes=("", "_new"),
)

# Replace existing human_readable_id with the new one from merged_entities_df
delta_nodes["human_readable_id"] = delta_nodes.loc[
:, "human_readable_id_new"
].combine_first(delta_nodes.loc[:, "human_readable_id"])

# Drop the auxiliary column from the merge
delta_nodes.drop(columns=["name", "human_readable_id_new"], inplace=True)

# Increment only the non-NaN values in delta_nodes["community"]
delta_nodes["community"] = delta_nodes["community"].where(
delta_nodes["community"].isna(),
delta_nodes["community"].fillna(0).astype(int) + old_max_community_id + 1,
)

# Set index for comparison
old_nodes_index = old_nodes.set_index(["level", "title"]).index
delta_nodes_index = delta_nodes.set_index(["level", "title"]).index

# Get all delta nodes that are not in the old nodes
new_delta_nodes_df = delta_nodes.loc[
~delta_nodes_index.isin(old_nodes_index)
].reset_index(drop=True)

# Get all delta nodes that are in the old nodes
existing_delta_nodes_df = delta_nodes[
delta_nodes_index.isin(old_nodes_index)
].reset_index(drop=True)

# Concat the DataFrames
concat_nodes = pd.concat([old_nodes, existing_delta_nodes_df], ignore_index=True)
columns_to_agg: dict[str, str | Callable] = {
col: "first"
for col in concat_nodes.columns
if col not in ["description", "source_id", "level", "title"]
}

# Specify custom aggregation for description and source_id
columns_to_agg.update({
"description": lambda x: os.linesep.join(x.astype(str)),
"source_id": lambda x: ",".join(str(i) for i in x.tolist()),
})

old_nodes = (
concat_nodes.groupby(["level", "title"]).agg(columns_to_agg).reset_index()
)

old_nodes["community"] = old_nodes["community"].astype("Int64")

new_delta_nodes_df = _assign_communities(
new_delta_nodes_df,
merged_relationships_df,
old_nodes,
community_count_threshold,
)

# Concatenate the old nodes with the new delta nodes
merged_final_nodes = pd.concat(
[old_nodes, new_delta_nodes_df], ignore_index=True, copy=False
)

# Merge both source and target degrees
merged_final_nodes = merged_final_nodes.merge(
merged_relationships_df[["source", "source_degree"]],
how="left",
left_on="title",
right_on="source",
).merge(
merged_relationships_df[["target", "target_degree"]],
how="left",
left_on="title",
right_on="target",
)

# Assign 'source_degree' to 'size' and 'degree'
merged_final_nodes["size"] = merged_final_nodes["source_degree"]

# Fill NaN values in 'size' and 'degree' with target_degree
merged_final_nodes["size"] = (
merged_final_nodes["size"]
.fillna(merged_final_nodes["target_degree"])
.astype("Int64")
)
merged_final_nodes["degree"] = merged_final_nodes["size"]

# Drop duplicates and the auxiliary 'source', 'target', 'source_degree' and 'target_degree' columns
return merged_final_nodes.drop(
columns=["source", "source_degree", "target", "target_degree"]
).drop_duplicates()


def _assign_communities(
new_delta_nodes_df: pd.DataFrame,
merged_relationships_df: pd.DataFrame,
old_nodes: pd.DataFrame,
community_count_threshold: int = 2,
) -> pd.DataFrame:
"""Assign communities to new delta nodes based on the most common community of related nodes.

Parameters
----------
new_delta_nodes_df : pd.DataFrame
The new delta nodes.
merged_relationships_df : pd.DataFrame
The merged relationships.
old_nodes : pd.DataFrame
The old nodes.
community_count_threshold : int, optional
Contributor:
Should we have a parameter to define the percentage of existing relationships vs. new relationships? E.g., if a new node has 2 old neighbors and 10 new neighbors, should we consider putting it in a new community rather than an old one?

Contributor Author:
At the moment I'm doing a simple majority, and not even fine-tuning it, haha. I totally agree it should be a percentage rather than an int, so it's independent of the dataset size. My plan is to refine this value further once we have community reports generated, but I'll address the change to a percentage right now — a ratio-based sketch follows after this file's diff.

The community count threshold, by default 2.
If a node has enough relationships to a community, it will be assigned to that community.
"""
# Find all relationships for the new delta nodes
node_relationships = merged_relationships_df[
merged_relationships_df["source"].isin(new_delta_nodes_df["title"])
| merged_relationships_df["target"].isin(new_delta_nodes_df["title"])
]

# Find old nodes that are related to these relationships
related_communities = old_nodes.loc[
old_nodes.loc[:, "title"].isin(node_relationships["source"])
| old_nodes.loc[:, "title"].isin(node_relationships["target"])
]

# Merge with new_delta_nodes_df to get the level and community info
related_communities = related_communities.merge(
new_delta_nodes_df[["level", "title"]], on=["level", "title"], how="inner"
)

# Count the communities for each (level, title) pair
Contributor:
Maybe we should remove all this logic for calculating community_counts too?

community_counts = (
related_communities.groupby(["level", "title"])["community"]
.value_counts()
.reset_index(name="count")
)

# Filter by community threshold and select the most common community for each node
most_common_communities = community_counts[
community_counts["count"] >= community_count_threshold
]
most_common_communities = (
most_common_communities.groupby(["level", "title"]).first().reset_index()
)

# Merge the most common community information back into new_delta_nodes_df
new_delta_nodes_df = new_delta_nodes_df.merge(
most_common_communities[["level", "title", "community"]],
on=["level", "title"],
how="left",
suffixes=("", "_new"),
)

# Update the community in new_delta_nodes_df if a common community was found
new_delta_nodes_df["community"] = (
new_delta_nodes_df.loc[:, "community_new"]
.combine_first(new_delta_nodes_df.loc[:, "community"])
.astype("Int64")
)

# Drop the auxiliary column used for merging
new_delta_nodes_df.drop(columns=["community_new"], inplace=True)

return new_delta_nodes_df
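The review thread above debates replacing the absolute community_count_threshold with a percentage so the assignment heuristic does not depend on dataset size. The PR as shown keeps the integer threshold; the snippet below is only a minimal sketch of what a ratio-based variant could look like. The function name assign_communities_by_ratio and the community_ratio_threshold parameter are hypothetical, and the level dimension that _assign_communities groups on is ignored here to keep the example short.

import pandas as pd


def assign_communities_by_ratio(
    new_delta_nodes_df: pd.DataFrame,
    merged_relationships_df: pd.DataFrame,
    old_nodes: pd.DataFrame,
    community_ratio_threshold: float = 0.5,
) -> pd.DataFrame:
    """Assign a new node to an old community only if that community covers
    at least community_ratio_threshold of the node's old neighbors."""
    # Gather (new node, neighbor) pairs from both edge directions.
    as_source = merged_relationships_df.loc[
        merged_relationships_df["source"].isin(new_delta_nodes_df["title"]),
        ["source", "target"],
    ].rename(columns={"source": "title", "target": "neighbor"})
    as_target = merged_relationships_df.loc[
        merged_relationships_df["target"].isin(new_delta_nodes_df["title"]),
        ["target", "source"],
    ].rename(columns={"target": "title", "source": "neighbor"})
    neighbors = pd.concat([as_source, as_target], ignore_index=True)

    # Keep only neighbors that already exist in the old index, with their community.
    neighbors = neighbors.merge(
        old_nodes[["title", "community"]].rename(columns={"title": "neighbor"}),
        on="neighbor",
        how="inner",
    ).dropna(subset=["community"])

    # Share of each candidate community among a node's old neighbors.
    counts = (
        neighbors.groupby(["title", "community"]).size().reset_index(name="count")
    )
    counts["ratio"] = counts["count"] / counts.groupby("title")["count"].transform("sum")

    # Dominant community per node, kept only when it clears the ratio threshold.
    winners = (
        counts[counts["ratio"] >= community_ratio_threshold]
        .sort_values("ratio", ascending=False)
        .drop_duplicates("title")
    )

    # Nodes without a qualifying old community keep the community from the delta run.
    out = new_delta_nodes_df.merge(
        winners[["title", "community"]], on="title", how="left", suffixes=("", "_old")
    )
    out["community"] = out["community_old"].combine_first(out["community"])
    return out.drop(columns=["community_old"])

With a threshold of 0.5, a new node is folded into an old community only when that community covers at least half of its already-indexed neighbors; otherwise it keeps whatever community the delta run assigned, mirroring the fallback behavior in _assign_communities.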
24 changes: 9 additions & 15 deletions tests/verbs/test_create_base_entity_graph.py
@@ -18,11 +18,9 @@


async def test_create_base_entity_graph():
- input_tables = load_input_tables(
-     [
-         "workflow:create_summarized_entities",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_summarized_entities",
+ ])
expected = load_expected(workflow_name)

storage = MemoryPipelineStorage()
@@ -60,11 +58,9 @@ async def test_create_base_entity_graph():


async def test_create_base_entity_graph_with_embeddings():
- input_tables = load_input_tables(
-     [
-         "workflow:create_summarized_entities",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_summarized_entities",
+ ])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)
@@ -87,11 +83,9 @@ async def test_create_base_entity_graph_with_embeddings():


async def test_create_base_entity_graph_with_snapshots():
- input_tables = load_input_tables(
-     [
-         "workflow:create_summarized_entities",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_summarized_entities",
+ ])
expected = load_expected(workflow_name)

storage = MemoryPipelineStorage()
24 changes: 10 additions & 14 deletions tests/verbs/test_create_final_community_reports.py
@@ -16,13 +16,11 @@


async def test_create_final_community_reports():
- input_tables = load_input_tables(
-     [
-         "workflow:create_final_nodes",
-         "workflow:create_final_covariates",
-         "workflow:create_final_relationships",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_final_nodes",
+     "workflow:create_final_covariates",
+     "workflow:create_final_relationships",
+ ])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)
@@ -50,13 +48,11 @@ async def test_create_final_community_reports():


async def test_create_final_community_reports_with_embeddings():
- input_tables = load_input_tables(
-     [
-         "workflow:create_final_nodes",
-         "workflow:create_final_covariates",
-         "workflow:create_final_relationships",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_final_nodes",
+     "workflow:create_final_covariates",
+     "workflow:create_final_relationships",
+ ])
expected = load_expected(workflow_name)

config = get_config_for_workflow(workflow_name)
16 changes: 6 additions & 10 deletions tests/verbs/test_create_summarized_entities.py
@@ -18,11 +18,9 @@


async def test_create_summarized_entities():
- input_tables = load_input_tables(
-     [
-         "workflow:create_base_extracted_entities",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_base_extracted_entities",
+ ])
expected = load_expected(workflow_name)

storage = MemoryPipelineStorage()
@@ -69,11 +67,9 @@ async def test_create_summarized_entities():


async def test_create_summarized_entities_with_snapshots():
- input_tables = load_input_tables(
-     [
-         "workflow:create_base_extracted_entities",
-     ]
- )
+ input_tables = load_input_tables([
+     "workflow:create_base_extracted_entities",
+ ])
expected = load_expected(workflow_name)

storage = MemoryPipelineStorage()
Expand Down