diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index c4fc133f8..f43dcb9a5 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -19,7 +19,7 @@ from .concurrency import ConcurrentQueries from .content import Kind from .embedding_model import EmbeddingModel -from .link_tag import LinkTag +from .links import Link from .math import cosine_similarity CONTENT_ID = "content_id" @@ -33,7 +33,7 @@ class Node: """Unique ID for the node. Will be generated by the GraphStore if not set.""" metadata: dict = field(default_factory=dict) """Metadata for the node.""" - links: Set[LinkTag] = field(default_factory=set) + links: Set[Link] = field(default_factory=set) """Links for the node.""" @@ -330,13 +330,13 @@ def add_nodes( ) -> Iterable[str]: texts = [] metadatas = [] - links: List[Set[LinkTag]] = [] + nodes_links: List[Set[Link]] = [] for node in nodes: if not isinstance(node, TextNode): raise ValueError("Only adding TextNode is supported at the moment") texts.append(node.text) metadatas.append(node.metadata) - links.append(node.links) + nodes_links.append(node.links) text_embeddings = self._embedding.embed_texts(texts) @@ -347,8 +347,8 @@ def add_nodes( # Step 1: Add the nodes, collecting the tags and new sources / targets. with self._concurrent_queries() as cq: - tuples = zip(texts, text_embeddings, metadatas, links) - for text, text_embedding, metadata, _links in tuples: + tuples = zip(texts, text_embeddings, metadatas, nodes_links) + for text, text_embedding, metadata, links in tuples: if CONTENT_ID not in metadata: metadata[CONTENT_ID] = secrets.token_hex(8) id = metadata[CONTENT_ID] @@ -357,20 +357,22 @@ def add_nodes( link_to_tags = set() # link to these tags link_from_tags = set() # link from these tags - for tag in _links: - tag_str = f"{tag.kind}:{tag.tag}" - if tag.direction == "incoming" or tag.direction == "bidir": - # An incom`ing link should be linked *from* nodes with the given tag. - link_from_tags.add(tag_str) - tag_to_new_targets.setdefault(tag_str, dict())[id] = ( - tag.kind, - text_embedding, - ) - if tag.direction == "outgoing" or tag.direction == "bidir": - link_to_tags.add(tag_str) - tag_to_new_sources.setdefault(tag_str, list()).append( - (tag.kind, id) - ) + for tag in links: + if hasattr(tag, "tag"): + tag_str = f"{tag.kind}:{tag.tag}" + if tag.direction == "incoming" or tag.direction == "bidir": + # An incoming link should be linked *from* nodes with the + # given tag. + link_from_tags.add(tag_str) + tag_to_new_targets.setdefault(tag_str, dict())[id] = ( + tag.kind, + text_embedding, + ) + if tag.direction == "outgoing" or tag.direction == "bidir": + link_to_tags.add(tag_str) + tag_to_new_sources.setdefault(tag_str, list()).append( + (tag.kind, id) + ) cq.execute( self._insert_passage, diff --git a/libs/knowledge-store/ragstack_knowledge_store/link_tag.py b/libs/knowledge-store/ragstack_knowledge_store/links.py similarity index 68% rename from libs/knowledge-store/ragstack_knowledge_store/link_tag.py rename to libs/knowledge-store/ragstack_knowledge_store/links.py index e9899d21d..05a9a90dd 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/link_tag.py +++ b/libs/knowledge-store/ragstack_knowledge_store/links.py @@ -1,20 +1,22 @@ from dataclasses import dataclass -from typing import Literal, Dict, Any, Set +from typing import Literal @dataclass(frozen=True) -class _LinkTag: +class Link: kind: str - tag: str direction: Literal["incoming", "outgoing", "bidir"] + def __post_init__(self): + if self.__class__ in [Link, LinkTag]: + raise TypeError( + f"Abstract class {self.__class__.__name__} cannot be instantiated" + ) + @dataclass(frozen=True) -class LinkTag(_LinkTag): - def __init__(self, kind: str, tag: str, direction: str) -> None: - if self.__class__ == LinkTag: - raise TypeError("Abstract class LinkTag cannot be instantiated") - super().__init__(kind, tag, direction) +class LinkTag(Link): + tag: str @dataclass(frozen=True) diff --git a/libs/langchain/ragstack_langchain/graph_store/base.py b/libs/langchain/ragstack_langchain/graph_store/base.py index 4341ece17..9e7249707 100644 --- a/libs/langchain/ragstack_langchain/graph_store/base.py +++ b/libs/langchain/ragstack_langchain/graph_store/base.py @@ -23,7 +23,7 @@ from langchain_core.vectorstores import VectorStore, VectorStoreRetriever from langchain_core.pydantic_v1 import Field -from ragstack_langchain.graph_store.links import LinkTag, LINKS +from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link def _has_next(iterator: Iterator) -> None: @@ -40,7 +40,7 @@ class Node(Serializable): """Unique ID for the node. Will be generated by the GraphStore if not set.""" metadata: dict = Field(default_factory=dict) """Metadata for the node.""" - links: Set[LinkTag] = Field(default_factory=set) + links: Set[Link] = Field(default_factory=set) """Links associated with the node.""" @@ -66,7 +66,7 @@ def _texts_to_nodes( except StopIteration: raise ValueError("texts iterable longer than ids") - links = _metadata.pop(LINKS, set()) + links = _metadata.pop(METADATA_LINKS_KEY, set()) yield TextNode( id=_id, metadata=_metadata, @@ -89,7 +89,7 @@ def _documents_to_nodes( except StopIteration: raise ValueError("documents iterable longer than ids") metadata = doc.metadata.copy() - links = metadata.pop(LINKS, set()) + links = metadata.pop(METADATA_LINKS_KEY, set()) yield TextNode( id=_id, metadata=metadata, diff --git a/libs/langchain/ragstack_langchain/graph_store/links.py b/libs/langchain/ragstack_langchain/graph_store/links.py index 76f310b65..32c9d3e06 100644 --- a/libs/langchain/ragstack_langchain/graph_store/links.py +++ b/libs/langchain/ragstack_langchain/graph_store/links.py @@ -39,7 +39,7 @@ def __init__(self, kind: str, tag: str) -> None: super().__init__(kind=kind, tag=tag, direction="bidir") -LINKS = "links" +METADATA_LINKS_KEY = "links" def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]: @@ -53,10 +53,10 @@ def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]: if isinstance(doc_or_md, Document): doc_or_md = doc_or_md.metadata - links = doc_or_md.setdefault(LINKS, set()) + links = doc_or_md.setdefault(METADATA_LINKS_KEY, set()) if not isinstance(links, Set): links = set(links) - doc_or_md[LINKS] = links + doc_or_md[METADATA_LINKS_KEY] = links return links diff --git a/libs/langchain/tests/integration_tests/test_graph_store.py b/libs/langchain/tests/integration_tests/test_graph_store.py index 9a055dfce..ce867522d 100644 --- a/libs/langchain/tests/integration_tests/test_graph_store.py +++ b/libs/langchain/tests/integration_tests/test_graph_store.py @@ -14,7 +14,7 @@ TextNode, ) from ragstack_langchain.graph_store.links import ( - LINKS, + METADATA_LINKS_KEY, BidirLinkTag, IncomingLinkTag, OutgoingLinkTag, @@ -37,7 +37,7 @@ def __init__(self, session: Session, keyspace: str, embedding: Embeddings) -> No def store( self, - initial_documents: Iterable[Document] = [], + initial_documents: Iterable[Document] = (), ids: Optional[Iterable[str]] = None, embedding: Optional[Embeddings] = None, ) -> CassandraGraphStore: @@ -125,7 +125,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="A", metadata={ "content_id": "a", - LINKS: { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="hyperlink", tag="http://a"), }, }, @@ -134,7 +134,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="B", metadata={ "content_id": "b", - LINKS: { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="hyperlink", tag="http://b"), OutgoingLinkTag(kind="hyperlink", tag="http://a"), }, @@ -144,7 +144,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="C", metadata={ "content_id": "c", - LINKS: { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="hyperlink", tag="http://a"), }, }, @@ -153,7 +153,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="D", metadata={ "content_id": "d", - LINKS: { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="hyperlink", tag="http://a"), OutgoingLinkTag(kind="hyperlink", tag="http://b"), }, @@ -198,7 +198,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="-0.124", metadata={ "content_id": "v0", - LINKS: { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="explicit", tag="link"), }, }, @@ -213,7 +213,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="+0.25", metadata={ "content_id": "v2", - LINKS: { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="explicit", tag="link"), }, }, @@ -222,7 +222,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="+1.0", metadata={ "content_id": "v3", - LINKS: { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="explicit", tag="link"), }, }, @@ -257,7 +257,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Typical Greetings", metadata={ "content_id": "greetings", - LINKS: { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="parent", tag="parent"), }, }, @@ -266,7 +266,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Hello World", metadata={ "content_id": "doc1", - LINKS: { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="parent", tag="parent"), BidirLinkTag(kind="kw", tag="greeting"), BidirLinkTag(kind="kw", tag="world"), @@ -277,7 +277,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Hello Earth", metadata={ "content_id": "doc2", - LINKS: { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="parent", tag="parent"), BidirLinkTag(kind="kw", tag="greeting"), BidirLinkTag(kind="kw", tag="earth"),