From 5a9a321c807285a93cf3389af789db71653694e8 Mon Sep 17 00:00:00 2001 From: Christophe Bornet Date: Tue, 11 Jun 2024 15:17:38 +0200 Subject: [PATCH] CLeanup CassandraKnowledgeGraph code --- .../cassandra_graph_store.py | 23 ++++- .../ragstack_knowledge_graph/extraction.py | 21 +++-- .../knowledge_graph.py | 79 ++++++++++------- .../knowledge_schema.py | 3 +- .../ragstack_knowledge_graph/render.py | 4 +- .../ragstack_knowledge_graph/runnables.py | 20 +++-- .../schema_inference.py | 12 ++- .../ragstack_knowledge_graph/templates.py | 4 +- .../ragstack_knowledge_graph/traverse.py | 85 ++++++++++--------- libs/knowledge-graph/tests/test_extraction.py | 4 +- .../tests/test_schema_inference.py | 8 +- 11 files changed, 166 insertions(+), 97 deletions(-) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py b/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py index bfa1ff779..386f2ad97 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/cassandra_graph_store.py @@ -20,7 +20,9 @@ def _node(node: LangChainNode) -> Node: for node in document.nodes: yield _node(node) for edge in document.relationships: - yield Relation(source=_node(edge.source), target=_node(edge.target), type=edge.type) + yield Relation( + source=_node(edge.source), target=_node(edge.target), type=edge.type + ) class CassandraGraphStore(GraphStore): @@ -56,7 +58,18 @@ def add_graph_documents( def query(self, query: str, params: dict = {}) -> List[Dict[str, Any]]: raise ValueError("Querying Cassandra should use `as_runnable`.") - def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = []) -> Runnable: + @property + def get_schema(self) -> str: + raise NotImplementedError + + @property + def get_structured_schema(self) -> Dict[str, Any]: + raise NotImplementedError + + def refresh_schema(self) -> None: + raise NotImplementedError + + def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = ()) -> Runnable: """ Return a runnable that retrieves the sub-graph near the input entity or entities. @@ -64,7 +77,9 @@ def as_runnable(self, steps: int = 3, edge_filters: Sequence[str] = []) -> Runna - steps: The maximum distance to follow from the starting points. - edge_filters: Predicates to use for filtering the edges. """ - return RunnableLambda(func=self.graph.traverse, afunc=self.graph.atraverse).bind( + return RunnableLambda( + func=self.graph.traverse, afunc=self.graph.atraverse + ).bind( steps=steps, - edge_filters=edge_filters, + edge_filters=edge_filters or [], ) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py b/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py index 0625480d2..a4b9da164 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/extraction.py @@ -35,7 +35,7 @@ def __init__( self, llm: BaseChatModel, schema: KnowledgeSchema, - examples: Sequence[Example] = [], + examples: Sequence[Example] = (), strict: bool = False, ) -> None: self._validator = KnowledgeSchemaValidator(schema) @@ -43,7 +43,9 @@ def __init__( messages = [ SystemMessagePromptTemplate( - prompt=load_template("extraction.md", knowledge_schema_yaml=schema.to_yaml_str()) + prompt=load_template( + "extraction.md", knowledge_schema_yaml=schema.to_yaml_str() + ) ) ] @@ -66,14 +68,20 @@ def _process_response( self, document: Document, response: Union[Dict, BaseModel] ) -> GraphDocument: raw_graph = cast(_Graph, response) - nodes = [map_to_base_node(node) for node in raw_graph.nodes] if raw_graph.nodes else [] + nodes = ( + [map_to_base_node(node) for node in raw_graph.nodes] + if raw_graph.nodes + else [] + ) relationships = ( [map_to_base_relationship(rel) for rel in raw_graph.relationships] if raw_graph.relationships else [] ) - document = GraphDocument(nodes=nodes, relationships=relationships, source=document) + document = GraphDocument( + nodes=nodes, relationships=relationships, source=document + ) if self.strict: self._validator.validate_graph_document(document) @@ -85,4 +93,7 @@ def extract(self, documents: List[Document]) -> List[GraphDocument]: responses = self._chain.batch_as_completed( [{"input": doc.page_content} for doc in documents] ) - return [self._process_response(documents[idx], response) for idx, response in responses] + return [ + self._process_response(documents[idx], response) + for idx, response in responses + ] diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py index 0dc40d297..6ea4a1d3a 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_graph.py @@ -23,7 +23,9 @@ def _parse_node(row) -> Node: return Node( name=row.name, type=row.type, - properties=_deserialize_md_dict(row.properties_json) if row.properties_json else dict(), + properties=_deserialize_md_dict(row.properties_json) + if row.properties_json + else dict(), ) @@ -40,15 +42,16 @@ def __init__( """ Create a Cassandra Knowledge Graph. - Parameters: - - node_table: Name of the table containing nodes. Defaults to `"entities"`. - - edge_table: Name of the table containing edges. Defaults to `"relationships"`. - _ text_embeddings: Name of the embeddings to use, if any. - - session: The Cassandra `Session` to use. If not specified, uses the default `cassio` - session, which requires `cassio.init` has been called. - - keyspace: The Cassandra keyspace to use. If not specified, uses the default `cassio` - keyspace, which requires `cassio.init` has been called. - - apply_schema: If true, the node table and edge table are created. + Args: + node_table: Name of the table containing nodes. Defaults to `"entities"`. + edge_table: Name of the table containing edges. Defaults to ` + "relationships"`. + text_embeddings: Name of the embeddings to use, if any. + session: The Cassandra `Session` to use. If not specified, uses the default + `cassio` session, which requires `cassio.init` has been called. + keyspace: The Cassandra keyspace to use. If not specified, uses the default + `cassio` keyspace, which requires `cassio.init` has been called. + apply_schema: If true, the node table and edge table are created. """ session = check_resolve_session(session) @@ -60,7 +63,9 @@ def __init__( # > 0 to be created at all. # > 1 to support cosine distance. # So we default to 2. - len(text_embeddings.embed_query("test string")) if text_embeddings else 2 + len(text_embeddings.embed_query("test string")) + if text_embeddings + else 2 ) self._session = session @@ -107,7 +112,8 @@ def __init__( def _apply_schema(self): # Partition by `name` and cluster by `type`. # Each `(name, type)` pair is a unique node. - # We can enumerate all `type` values for a given `name` to identify ambiguous terms. + # We can enumerate all `type` values for a given `name` to identify ambiguous + # terms. self._session.execute( f""" CREATE TABLE IF NOT EXISTS {self._keyspace}.{self._node_table} ( @@ -163,9 +169,9 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node """ For each node, return the nearest nodes in the table. - Parameters: - - nodes: The strings to search for in the list of nodes. - - k: The number of similar nodes to retrieve for each string. + Args: + nodes: The strings to search for in the list of nodes. + k: The number of similar nodes to retrieve for each string. """ if self._text_embeddings is None: raise ValueError("Unable to query for nearest nodes without embeddings") @@ -174,7 +180,9 @@ def query_nearest_nodes(self, nodes: Iterable[str], k: int = 1) -> Iterable[Node self._send_query_nearest_node(n, k) for n in nodes ] - nodes = {_parse_node(n) for node_future in node_futures for n in node_future.result()} + nodes = { + _parse_node(n) for node_future in node_futures for n in node_future.result() + } return list(nodes) # TODO: Introduce `ainsert` for async insertions. @@ -184,6 +192,7 @@ def insert( ) -> None: for batch in batched(elements, n=4): from yaml import dump + text_embeddings = ( iter( self._text_embeddings.embed_documents( @@ -200,7 +209,12 @@ def insert( properties_json = _serialize_md_dict(element.properties) batch_statement.add( self._insert_node, - (element.name, element.type, next(text_embeddings), properties_json), + ( + element.name, + element.type, + next(text_embeddings), + properties_json, + ), ) elif isinstance(element, Relation): batch_statement.add( @@ -240,12 +254,13 @@ def subgraph( # etc. node_futures: Iterable[ResponseFuture] = [ - self._session.execute_async(self._query_relationship, (n.name, n.type)) for n in nodes + self._session.execute_async(self._query_relationship, (n.name, n.type)) + for n in nodes ] nodes = [_parse_node(n) for future in node_futures for n in future.result()] - return (nodes, edges) + return nodes, edges def traverse( self, @@ -254,15 +269,16 @@ def traverse( steps: int = 3, ) -> Iterable[Relation]: """ - Traverse the graph from the given starting nodes and return the resulting sub-graph. + Traverse the graph from the given starting nodes and return the resulting + sub-graph. - Parameters: - - start: The starting node or nodes. - - edge_filters: Filters to apply to the edges being traversed. - - steps: The number of steps of edges to follow from a start node. + Args: + start: The starting node or nodes. + edge_filters: Filters to apply to the edges being traversed. + steps: The number of steps of edges to follow from a start node. Returns: - An iterable over relations in the traversed sub-graph. + An iterable over relations in the traversed sub-graph. """ return traverse( start=start, @@ -285,15 +301,16 @@ async def atraverse( steps: int = 3, ) -> Iterable[Relation]: """ - Traverse the graph from the given starting nodes and return the resulting sub-graph. + Traverse the graph from the given starting nodes and return the resulting + sub-graph. - Parameters: - - start: The starting node or nodes. - - edge_filters: Filters to apply to the edges being traversed. - - steps: The number of steps of edges to follow from a start node. + Args: + start: The starting node or nodes. + edge_filters: Filters to apply to the edges being traversed. + steps: The number of steps of edges to follow from a start node. Returns: - An iterable over relations in the traversed sub-graph. + An iterable over relations in the traversed sub-graph. """ return await atraverse( start=start, diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py index 9e01641f5..353477457 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/knowledge_schema.py @@ -105,7 +105,8 @@ def validate_graph_document(self, document: GraphDocument): ) if relationship is None: e.add_note( - f"No relationship allows ({r.source_id} -> {r.type} -> {r.target.type})" + "No relationship allows " + f"({r.source_id} -> {r.type} -> {r.target.type})" ) if e.__notes__: diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/render.py b/libs/knowledge-graph/ragstack_knowledge_graph/render.py index 32fdcd7a3..365febe1e 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/render.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/render.py @@ -10,7 +10,9 @@ def _node_label(node: Node) -> str: return f"{node.id} [{node.type}]" -def print_graph_documents(graph_documents: Union[GraphDocument, Iterable[GraphDocument]]): +def print_graph_documents( + graph_documents: Union[GraphDocument, Iterable[GraphDocument]] +): if isinstance(graph_documents, GraphDocument): graph_documents = [graph_documents] diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py b/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py index 885329ede..875106463 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/runnables.py @@ -31,12 +31,12 @@ def extract_entities( This will expect a dictionary containing the `"question"` to extract keywords from. - Parameters: - - llm: The LLM to use for extracting entities. - - node_types: List of node types to extract. - - keyword_extraction_prompt: The prompt to use for requesting entities. - This should include the `{question}` being asked as well as the `{format_instructions}` - which describe how to produce the output. + Args: + llm: The LLM to use for extracting entities. + node_types: List of node types to extract. + keyword_extraction_prompt: The prompt to use for requesting entities. + This should include the `{question}` being asked as well as the + `{format_instructions}` which describe how to produce the output. """ prompt = ChatPromptTemplate.from_messages([keyword_extraction_prompt]) assert "question" in prompt.input_variables @@ -46,7 +46,9 @@ class SimpleNode(BaseModel): """Represents a node in a graph with associated properties.""" id: str = Field(description="Name or human-readable unique identifier.") - type: str = optional_enum_field(node_types, description="The type or label of the node.") + type: str = optional_enum_field( + node_types, description="The type or label of the node." + ) class SimpleNodeList(BaseModel): """Represents a list of simple nodes.""" @@ -61,5 +63,7 @@ class SimpleNodeList(BaseModel): | ChatPromptTemplate.from_messages([keyword_extraction_prompt]) | llm | output_parser - | RunnableLambda(lambda node_list: [Node(n["id"], n["type"]) for n in node_list["nodes"]]) + | RunnableLambda( + lambda node_list: [Node(n["id"], n["type"]) for n in node_list["nodes"]] + ) ) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py b/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py index 9f1154a90..6ab9c1908 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/schema_inference.py @@ -16,7 +16,9 @@ class KnowledgeSchemaInferer: def __init__(self, llm: BaseChatModel) -> None: prompt = ChatPromptTemplate.from_messages( [ - SystemMessagePromptTemplate(prompt=load_template("schema_inference.md")), + SystemMessagePromptTemplate( + prompt=load_template("schema_inference.md") + ), HumanMessagePromptTemplate.from_template("Input: {input}"), ] ) @@ -24,6 +26,10 @@ def __init__(self, llm: BaseChatModel) -> None: structured_llm = llm.with_structured_output(KnowledgeSchema) self._chain = prompt | structured_llm - def infer_schemas_from(self, documents: Sequence[Document]) -> Sequence[KnowledgeSchema]: - responses = self._chain.batch([{"input": doc.page_content} for doc in documents]) + def infer_schemas_from( + self, documents: Sequence[Document] + ) -> Sequence[KnowledgeSchema]: + responses = self._chain.batch( + [{"input": doc.page_content} for doc in documents] + ) return cast(Sequence[KnowledgeSchema], responses) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/templates.py b/libs/knowledge-graph/ragstack_knowledge_graph/templates.py index db310b636..70cf8ca67 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/templates.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/templates.py @@ -6,7 +6,9 @@ TEMPLATE_PATH = path.join(path.dirname(__file__), "prompt_templates") -def load_template(filename: str, **kwargs: Union[str, Callable[[], str]]) -> PromptTemplate: +def load_template( + filename: str, **kwargs: Union[str, Callable[[], str]] +) -> PromptTemplate: template = PromptTemplate.from_file(path.join(TEMPLATE_PATH, filename)) if kwargs: template = template.partial(**kwargs) diff --git a/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py b/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py index 24d7dd27e..83f4282bb 100644 --- a/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py +++ b/libs/knowledge-graph/ragstack_knowledge_graph/traverse.py @@ -84,23 +84,23 @@ def traverse( """ Traverse the graph from the given starting nodes and return the resulting sub-graph. - Parameters: - - start: The starting node or nodes. - - edge_table: The table containing the edges. - - edge_source_name: The name of the column containing edge source names. - - edge_source_type: The name of the column containing edge source types. - - edge_target_name: The name of the column containing edge target names. - - edge_target_type: The name of the column containing edge target types. - - edge_type: The name of the column containing edge types. - - edge_filters: Filters to apply to the edges being traversed. - - steps: The number of steps of edges to follow from a start node. - - session: The session to use for executing the query. If not specified, - it will use th default cassio session. - - keyspace: The keyspace to use for the query. If not specified, it will - use the default cassio keyspace. + Args: + start: The starting node or nodes. + edge_table: The table containing the edges. + edge_source_name: The name of the column containing edge source names. + edge_source_type: The name of the column containing edge source types. + edge_target_name: The name of the column containing edge target names. + edge_target_type: The name of the column containing edge target types. + edge_type: The name of the column containing edge types. + edge_filters: Filters to apply to the edges being traversed. + steps: The number of steps of edges to follow from a start node. + session: The session to use for executing the query. If not specified, + it will use th default cassio session. + keyspace: The keyspace to use for the query. If not specified, it will + use the default cassio keyspace. Returns: - An iterable over relations in the traversed sub-graph. + An iterable over relations in the traversed sub-graph. """ if len(start) == 0: return [] @@ -166,7 +166,9 @@ def fetch_relationships(distance: int, source: Node) -> None: distances[source] = distance - request: ResponseFuture = session.execute_async(query, (source.name, source.type)) + request: ResponseFuture = session.execute_async( + query, (source.name, source.type) + ) pending.add(request._req_id) request.add_callbacks( handle_result, @@ -208,9 +210,9 @@ async def next(self): if self.response_future.has_more_pages: self.current_page_future = asyncio.Future() self.response_future.start_fetching_next_page() - return (self.depth, page, self) + return self.depth, page, self else: - return (self.depth, page, None) + return self.depth, page, None async def atraverse( @@ -221,34 +223,35 @@ async def atraverse( edge_target_name: str = "target_name", edge_target_type: str = "target_type", edge_type: str = "edge_type", - edge_filters: Sequence[str] = [], + edge_filters: Sequence[str] = (), steps: int = 3, session: Optional[Session] = None, keyspace: Optional[str] = None, ) -> Iterable[Relation]: """ - Async traversal of the graph from the given starting nodes and return the resulting sub-graph. + Async traversal of the graph from the given starting nodes and return the resulting + sub-graph. Parameters: - - start: The starting node or nodes. - - edge_table: The table containing the edges. - - edge_source_name: The name of the column containing edge source names. - - edge_source_type: The name of the column containing edge source types. - - edge_target_name: The name of the column containing edge target names. - - edge_target_type: The name of the column containing edge target types. - - edge_type: The name of the column containing edge types. - - edge_filters: Filters to apply to the edges being traversed. - Currently, this is specified as a dictionary containing the name - of the edge field to filter on and the CQL predicate to apply. - For example `{"foo": "IN ['a', 'b', 'c']"}`. - - steps: The number of steps of edges to follow from a start node. - - session: The session to use for executing the query. If not specified, - it will use th default cassio session. - - keyspace: The keyspace to use for the query. If not specified, it will - use the default cassio keyspace. + start: The starting node or nodes. + edge_table: The table containing the edges. + edge_source_name: The name of the column containing edge source names. + edge_source_type: The name of the column containing edge source types. + edge_target_name: The name of the column containing edge target names. + edge_target_type: The name of the column containing edge target types. + edge_type: The name of the column containing edge types. + edge_filters: Filters to apply to the edges being traversed. + Currently, this is specified as a dictionary containing the name of the + edge field to filter on and the CQL predicate to apply. + For example `{"foo": "IN ['a', 'b', 'c']"}`. + steps: The number of steps of edges to follow from a start node. + session: The session to use for executing the query. If not specified, + it will use th default cassio session. + keyspace: The keyspace to use for the query. If not specified, it will + use the default cassio keyspace. Returns: - An iterable over relations in the traversed sub-graph. + An iterable over relations in the traversed sub-graph. """ session = check_resolve_session(session) @@ -272,7 +275,9 @@ async def atraverse( keyspace=keyspace, ) - def fetch_relation(tg: asyncio.TaskGroup, depth: int, source: Node) -> AsyncPagedQuery: + def fetch_relation( + tg: asyncio.TaskGroup, depth: int, source: Node + ) -> AsyncPagedQuery: paged_query = AsyncPagedQuery( depth, session.execute_async(query, (source.name, source.type)) ) @@ -287,7 +292,9 @@ def fetch_relation(tg: asyncio.TaskGroup, depth: int, source: Node) -> AsyncPage pending = {fetch_relation(tg, 1, source) for source in start} while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + done, pending = await asyncio.wait( + pending, return_when=asyncio.FIRST_COMPLETED + ) for future in done: depth, relations, more = future.result() for relation in relations: diff --git a/libs/knowledge-graph/tests/test_extraction.py b/libs/knowledge-graph/tests/test_extraction.py index fbf4cec14..3944c5725 100644 --- a/libs/knowledge-graph/tests/test_extraction.py +++ b/libs/knowledge-graph/tests/test_extraction.py @@ -46,8 +46,8 @@ def test_extraction(extractor: KnowledgeSchemaExtractor): nobel_prize = Node(id="Nobel Prize", type="Award") pierre_curie = Node(id="Pierre Curie", type="Person") - # Annoyingly, the LLM seems to upper-case `of`. We probably need some instructions around - # putting things into standard title case, etc. + # Annoyingly, the LLM seems to upper-case `of`. We probably need some instructions + # around putting things into standard title case, etc. university_of_paris = Node(id="University Of Paris", type="Institution") assert sorted(results[0].nodes, key=lambda x: x.id) == sorted( diff --git a/libs/knowledge-graph/tests/test_schema_inference.py b/libs/knowledge-graph/tests/test_schema_inference.py index 0844ae816..4847e1ff2 100644 --- a/libs/knowledge-graph/tests/test_schema_inference.py +++ b/libs/knowledge-graph/tests/test_schema_inference.py @@ -22,7 +22,9 @@ def test_schema_inference(llm: BaseChatModel): schema_inferer = KnowledgeSchemaInferer(llm) - results = schema_inferer.infer_schemas_from([Document(page_content=MARIE_CURIE_SOURCE)])[0] + results = schema_inferer.infer_schemas_from( + [Document(page_content=MARIE_CURIE_SOURCE)] + )[0] print(results.to_yaml_str()) nodes = [n.type for n in results.nodes] @@ -38,7 +40,9 @@ def test_schema_inference(llm: BaseChatModel): print(rels) any_of_in_list(rels, "won", "won_award") any_of_in_list(rels, "is_nationality_of", "has_nationality") - any_of_in_list(rels, "first_professor_at", "professor_at", "works_at", "has_position_at") + any_of_in_list( + rels, "first_professor_at", "professor_at", "works_at", "has_position_at" + ) any_of_in_list(rels, "conducted_research_in") # We don't do more testing here since this is meant to attempt to infer things.