Skip to content

Commit

Permalink
Cleanup CassandraKnowledgeGraph code
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jun 11, 2024
1 parent a11fa5b commit 5a9a321
Show file tree
Hide file tree
Showing 11 changed files with 166 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ def _node(node: LangChainNode) -> Node:
for node in document.nodes:
yield _node(node)
for edge in document.relationships:
yield Relation(source=_node(edge.source), target=_node(edge.target), type=edge.type)
yield Relation(
source=_node(edge.source), target=_node(edge.target), type=edge.type
)


class CassandraGraphStore(GraphStore):
Expand Down Expand Up @@ -56,15 +58,28 @@ def add_graph_documents(
def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]:
raise ValueError("Querying Cassandra should use `as_runnable`.")

def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = []) -> Runnable:
@property
def get_schema(self) -> str:
"""Schema introspection is not supported; always raises NotImplementedError."""
raise NotImplementedError

@property
def get_structured_schema(self) -> Dict[str, Any]:
"""Structured schema introspection is not supported; always raises NotImplementedError."""
raise NotImplementedError

def refresh_schema(self) -> None:
"""Schema refresh is not supported; always raises NotImplementedError."""
raise NotImplementedError

def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = ()) -> Runnable:
    """
    Return a runnable that retrieves the sub-graph near the input entity or entities.

    Parameters:
    - steps: The maximum distance to follow from the starting points.
    - edge_filters: Predicates to use for filtering the edges.
    """
    # The immutable tuple default avoids the shared-mutable-default pitfall;
    # `or []` converts the empty tuple back to the list the traversal expects.
    return RunnableLambda(func=self.graph.traverse, afunc=self.graph.atraverse).bind(
        steps=steps,
        edge_filters=edge_filters or [],
    )
21 changes: 16 additions & 5 deletions libs/knowledge-graph/ragstack_knowledge_graph/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ def __init__(
self,
llm: BaseChatModel,
schema: KnowledgeSchema,
examples: Sequence[Example] = [],
examples: Sequence[Example] = (),
strict: bool = False,
) -> None:
self._validator = KnowledgeSchemaValidator(schema)
self.strict = strict

messages = [
SystemMessagePromptTemplate(
prompt=load_template("extraction.md", knowledge_schema_yaml=schema.to_yaml_str())
prompt=load_template(
"extraction.md", knowledge_schema_yaml=schema.to_yaml_str()
)
)
]

Expand All @@ -66,14 +68,20 @@ def _process_response(
self, document: Document, response: Union[Dict, BaseModel]
) -> GraphDocument:
raw_graph = cast(_Graph, response)
nodes = [map_to_base_node(node) for node in raw_graph.nodes] if raw_graph.nodes else []
nodes = (
[map_to_base_node(node) for node in raw_graph.nodes]
if raw_graph.nodes
else []
)
relationships = (
[map_to_base_relationship(rel) for rel in raw_graph.relationships]
if raw_graph.relationships
else []
)

document = GraphDocument(nodes=nodes, relationships=relationships, source=document)
document = GraphDocument(
nodes=nodes, relationships=relationships, source=document
)

if self.strict:
self._validator.validate_graph_document(document)
def extract(self, documents: List[Document]) -> List[GraphDocument]:
    """Extract a knowledge graph from each of the given documents.

    Args:
        documents: The documents to extract graph documents from.

    Returns:
        The graph document extracted from each input document.
    """
    # `batch_as_completed` yields (index, response) pairs as they finish,
    # so `documents[idx]` re-associates each response with its source doc.
    responses = self._chain.batch_as_completed(
        [{"input": doc.page_content} for doc in documents]
    )
    return [
        self._process_response(documents[idx], response)
        for idx, response in responses
    ]
79 changes: 48 additions & 31 deletions libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py
Original file line number Diff line number Diff line change
def _parse_node(row) -> Node:
    """Build a `Node` from a Cassandra result row.

    Properties are stored serialized in `row.properties_json` (decoded via
    `_deserialize_md_dict`); an absent/empty value maps to an empty dict.
    """
    return Node(
        name=row.name,
        type=row.type,
        properties=_deserialize_md_dict(row.properties_json)
        if row.properties_json
        else {},
    )


Expand All @@ -40,15 +42,16 @@ def __init__(
"""
Create a Cassandra Knowledge Graph.
Parameters:
- node_table: Name of the table containing nodes. Defaults to `"entities"`.
- edge_table: Name of the table containing edges. Defaults to `"relationships"`.
- text_embeddings: Name of the embeddings to use, if any.
- session: The Cassandra `Session` to use. If not specified, uses the default `cassio`
session, which requires `cassio.init` has been called.
- keyspace: The Cassandra keyspace to use. If not specified, uses the default `cassio`
keyspace, which requires `cassio.init` has been called.
- apply_schema: If true, the node table and edge table are created.
Args:
node_table: Name of the table containing nodes. Defaults to `"entities"`.
edge_table: Name of the table containing edges. Defaults to `
"relationships"`.
text_embeddings: Name of the embeddings to use, if any.
session: The Cassandra `Session` to use. If not specified, uses the default
`cassio` session, which requires `cassio.init` has been called.
keyspace: The Cassandra keyspace to use. If not specified, uses the default
`cassio` keyspace, which requires `cassio.init` has been called.
apply_schema: If true, the node table and edge table are created.
"""

session = check_resolve_session(session)
Expand All @@ -60,7 +63,9 @@ def __init__(
# > 0 to be created at all.
# > 1 to support cosine distance.
# So we default to 2.
len(text_embeddings.embed_query("test string")) if text_embeddings else 2
len(text_embeddings.embed_query("test string"))
if text_embeddings
else 2
)

self._session = session
Expand Down Expand Up @@ -107,7 +112,8 @@ def __init__(
def _apply_schema(self):
# Partition by `name` and cluster by `type`.
# Each `(name, type)` pair is a unique node.
# We can enumerate all `type` values for a given `name` to identify ambiguous terms.
# We can enumerate all `type` values for a given `name` to identify ambiguous
# terms.
self._session.execute(
f"""
CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} (
Expand Down Expand Up @@ -163,9 +169,9 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node
"""
For each node, return the nearest nodes in the table.
Parameters:
- nodes: The strings to search for in the list of nodes.
- k: The number of similar nodes to retrieve for each string.
Args:
nodes: The strings to search for in the list of nodes.
k: The number of similar nodes to retrieve for each string.
"""
if self._text_embeddings is None:
raise ValueError("Unable to query for nearest nodes without embeddings")
Expand All @@ -174,7 +180,9 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node
self._send_query_nearest_node(n, k) for n in nodes
]

nodes = {_parse_node(n) for node_future in node_futures for n in node_future.result()}
nodes = {
_parse_node(n) for node_future in node_futures for n in node_future.result()
}
return list(nodes)

# TODO: Introduce `ainsert` for async insertions.
Expand All @@ -184,6 +192,7 @@ def insert(
) -> None:
for batch in batched(elements, n=4):
from yaml import dump

text_embeddings = (
iter(
self._text_embeddings.embed_documents(
Expand All @@ -200,7 +209,12 @@ def insert(
properties_json = _serialize_md_dict(element.properties)
batch_statement.add(
self._insert_node,
(element.name, element.type, next(text_embeddings), properties_json),
(
element.name,
element.type,
next(text_embeddings),
properties_json,
),
)
elif isinstance(element, Relation):
batch_statement.add(
Expand Down Expand Up @@ -240,12 +254,13 @@ def subgraph(
# etc.

node_futures: Iterable[ResponseFuture] = [
self._session.execute_async(self._query_relationship, (n.name, n.type)) for n in nodes
self._session.execute_async(self._query_relationship, (n.name, n.type))
for n in nodes
]

nodes = [_parse_node(n) for future in node_futures for n in future.result()]

return (nodes, edges)
return nodes, edges

def traverse(
self,
Expand All @@ -254,15 +269,16 @@ def traverse(
steps: int = 3,
) -> Iterable[Relation]:
"""
Traverse the graph from the given starting nodes and return the resulting sub-graph.
Traverse the graph from the given starting nodes and return the resulting
sub-graph.
Parameters:
- start: The starting node or nodes.
- edge_filters: Filters to apply to the edges being traversed.
- steps: The number of steps of edges to follow from a start node.
Args:
start: The starting node or nodes.
edge_filters: Filters to apply to the edges being traversed.
steps: The number of steps of edges to follow from a start node.
Returns:
An iterable over relations in the traversed sub-graph.
An iterable over relations in the traversed sub-graph.
"""
return traverse(
start=start,
Expand All @@ -285,15 +301,16 @@ async def atraverse(
steps: int = 3,
) -> Iterable[Relation]:
"""
Traverse the graph from the given starting nodes and return the resulting sub-graph.
Traverse the graph from the given starting nodes and return the resulting
sub-graph.
Parameters:
- start: The starting node or nodes.
- edge_filters: Filters to apply to the edges being traversed.
- steps: The number of steps of edges to follow from a start node.
Args:
start: The starting node or nodes.
edge_filters: Filters to apply to the edges being traversed.
steps: The number of steps of edges to follow from a start node.
Returns:
An iterable over relations in the traversed sub-graph.
An iterable over relations in the traversed sub-graph.
"""
return await atraverse(
start=start,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def validate_graph_document(self, document: GraphDocument):
)
if relationship is None:
e.add_note(
f"No relationship allows ({r.source_id} -> {r.type} -> {r.target.type})"
"No relationship allows "
f"({r.source_id} -> {r.type} -> {r.target.type})"
)

if e.__notes__:
Expand Down
4 changes: 3 additions & 1 deletion libs/knowledge-graph/ragstack_knowledge_graph/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,9 @@ def _node_label(node: Node) -> str:
return f"{node.id} [{node.type}]"


def print_graph_documents(graph_documents: Union[GraphDocument, Iterable[GraphDocument]]):
def print_graph_documents(
graph_documents: Union[GraphDocument, Iterable[GraphDocument]]
):
if isinstance(graph_documents, GraphDocument):
graph_documents = [graph_documents]

Expand Down
20 changes: 12 additions & 8 deletions libs/knowledge-graph/ragstack_knowledge_graph/runnables.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ def extract_entities(
This will expect a dictionary containing the `"question"` to extract keywords from.
Parameters:
- llm: The LLM to use for extracting entities.
- node_types: List of node types to extract.
- keyword_extraction_prompt: The prompt to use for requesting entities.
This should include the `{question}` being asked as well as the `{format_instructions}`
which describe how to produce the output.
Args:
llm: The LLM to use for extracting entities.
node_types: List of node types to extract.
keyword_extraction_prompt: The prompt to use for requesting entities.
This should include the `{question}` being asked as well as the
`{format_instructions}` which describe how to produce the output.
"""
prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt])
assert "question" in prompt.input_variables
Expand All @@ -46,7 +46,9 @@ class SimpleNode(BaseModel):
"""Represents a node in a graph with associated properties."""

id: str = Field(description="Name or human-readable unique identifier.")
type: str = optional_enum_field(node_types, description="The type or label of the node.")
type: str = optional_enum_field(
node_types, description="The type or label of the node."
)

class SimpleNodeList(BaseModel):
"""Represents a list of simple nodes."""
Expand All @@ -61,5 +63,7 @@ class SimpleNodeList(BaseModel):
| ChatPromptTemplate.from_messages([keyword_extraction_prompt])
| llm
| output_parser
| RunnableLambda(lambda node_list: [Node(n["id"], n["type"]) for n in node_list["nodes"]])
| RunnableLambda(
lambda node_list: [Node(n["id"], n["type"]) for n in node_list["nodes"]]
)
)
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,20 @@ class KnowledgeSchemaInferer:
def __init__(self, llm: BaseChatModel) -> None:
    """Build the schema-inference chain around the given chat model.

    Args:
        llm: The chat model used to infer a `KnowledgeSchema` from text.
    """
    prompt = ChatPromptTemplate.from_messages(
        [
            SystemMessagePromptTemplate(
                prompt=load_template("schema_inference.md")
            ),
            HumanMessagePromptTemplate.from_template("Input: {input}"),
        ]
    )
    # TODO: Use "full" output so we can detect parsing errors?
    structured_llm = llm.with_structured_output(KnowledgeSchema)
    self._chain = prompt | structured_llm

def infer_schemas_from(
    self, documents: Sequence[Document]
) -> Sequence[KnowledgeSchema]:
    """Infer a knowledge schema from each of the given documents.

    Args:
        documents: The documents to run schema inference over.

    Returns:
        One inferred `KnowledgeSchema` per input document, in order.
    """
    responses = self._chain.batch(
        [{"input": doc.page_content} for doc in documents]
    )
    # The structured-output chain already returns KnowledgeSchema objects;
    # the cast only narrows the static type.
    return cast(Sequence[KnowledgeSchema], responses)
4 changes: 3 additions & 1 deletion libs/knowledge-graph/ragstack_knowledge_graph/templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@
TEMPLATE_PATH = path.join(path.dirname(__file__), "prompt_templates")


def load_template(filename: str, **kwargs: Union[str, Callable[[], str]]) -> PromptTemplate:
def load_template(
filename: str, **kwargs: Union[str, Callable[[], str]]
) -> PromptTemplate:
template = PromptTemplate.from_file(path.join(TEMPLATE_PATH, filename))
if kwargs:
template = template.partial(**kwargs)
Expand Down
Loading

0 comments on commit 5a9a321

Please sign in to comment.