diff --git a/.github/workflows/ci-unit-tests.yml b/.github/workflows/ci-unit-tests.yml index 47ea37872..a353de3a7 100644 --- a/.github/workflows/ci-unit-tests.yml +++ b/.github/workflows/ci-unit-tests.yml @@ -139,9 +139,15 @@ jobs: rm -rf $dir/.tox } - run_itests libs/colbert - run_itests libs/langchain - run_itests libs/llamaindex + if [[ "true" == "${{ needs.preconditions.outputs.libs_colbert }}" ]]; then + run_itests libs/colbert + fi + if [[ "true" == "${{ needs.preconditions.outputs.libs_langchain }}" ]]; then + run_itests libs/langchain + fi + if [[ "true" == "${{ needs.preconditions.outputs.libs_llamaindex }}" ]]; then + run_itests libs/llamaindex + fi - name: Cleanup AstraDB uses: nicoloboschi/cleanup-astradb@v1 diff --git a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py index 0b5858e22..f43dcb9a5 100644 --- a/libs/knowledge-store/ragstack_knowledge_store/graph_store.py +++ b/libs/knowledge-store/ragstack_knowledge_store/graph_store.py @@ -8,6 +8,7 @@ NamedTuple, Optional, Sequence, + Set, Tuple, ) @@ -18,7 +19,7 @@ from .concurrency import ConcurrentQueries from .content import Kind from .embedding_model import EmbeddingModel -from .link_tag import get_link_tags +from .links import Link from .math import cosine_similarity CONTENT_ID = "content_id" @@ -31,8 +32,9 @@ class Node: id: Optional[str] = None """Unique ID for the node. Will be generated by the GraphStore if not set.""" metadata: dict = field(default_factory=dict) - """Metadata for the node. May contain information used to link this node - with other nodes.""" + """Metadata for the node.""" + links: Set[Link] = field(default_factory=set) + """Links for the node.""" @dataclass @@ -328,11 +330,13 @@ def add_nodes( ) -> Iterable[str]: texts = [] metadatas = [] + nodes_links: List[Set[Link]] = [] for node in nodes: if not isinstance(node, TextNode): raise ValueError("Only adding TextNode is supported at the moment") texts.append(node.text) metadatas.append(node.metadata) + nodes_links.append(node.links) text_embeddings = self._embedding.embed_texts(texts) @@ -343,8 +347,8 @@ def add_nodes( # Step 1: Add the nodes, collecting the tags and new sources / targets. with self._concurrent_queries() as cq: - tuples = zip(texts, text_embeddings, metadatas) - for text, text_embedding, metadata in tuples: + tuples = zip(texts, text_embeddings, metadatas, nodes_links) + for text, text_embedding, metadata, links in tuples: if CONTENT_ID not in metadata: metadata[CONTENT_ID] = secrets.token_hex(8) id = metadata[CONTENT_ID] @@ -353,20 +357,22 @@ def add_nodes( link_to_tags = set() # link to these tags link_from_tags = set() # link from these tags - for tag in get_link_tags(metadata): - tag_str = f"{tag.kind}:{tag.tag}" - if tag.direction == "incoming" or tag.direction == "bidir": - # An incom`ing link should be linked *from* nodes with the given tag. - link_from_tags.add(tag_str) - tag_to_new_targets.setdefault(tag_str, dict())[id] = ( - tag.kind, - text_embedding, - ) - if tag.direction == "outgoing" or tag.direction == "bidir": - link_to_tags.add(tag_str) - tag_to_new_sources.setdefault(tag_str, list()).append( - (tag.kind, id) - ) + for tag in links: + if hasattr(tag, "tag"): + tag_str = f"{tag.kind}:{tag.tag}" + if tag.direction == "incoming" or tag.direction == "bidir": + # An incoming link should be linked *from* nodes with the + # given tag. + link_from_tags.add(tag_str) + tag_to_new_targets.setdefault(tag_str, dict())[id] = ( + tag.kind, + text_embedding, + ) + if tag.direction == "outgoing" or tag.direction == "bidir": + link_to_tags.add(tag_str) + tag_to_new_sources.setdefault(tag_str, list()).append( + (tag.kind, id) + ) cq.execute( self._insert_passage, diff --git a/libs/knowledge-store/ragstack_knowledge_store/link_tag.py b/libs/knowledge-store/ragstack_knowledge_store/link_tag.py deleted file mode 100644 index 28096e422..000000000 --- a/libs/knowledge-store/ragstack_knowledge_store/link_tag.py +++ /dev/null @@ -1,54 +0,0 @@ -from dataclasses import dataclass -from typing import Literal, Dict, Any, Set - - -@dataclass(frozen=True) -class _LinkTag: - kind: str - tag: str - direction: Literal["incoming", "outgoing", "bidir"] - - -@dataclass(frozen=True) -class LinkTag(_LinkTag): - def __init__(self, kind: str, tag: str, direction: str) -> None: - if self.__class__ == LinkTag: - raise TypeError("Abstract class LinkTag cannot be instantiated") - super().__init__(kind, tag, direction) - - -@dataclass(frozen=True) -class OutgoingLinkTag(LinkTag): - def __init__(self, kind: str, tag: str) -> None: - super().__init__(kind=kind, tag=tag, direction="outgoing") - - -@dataclass(frozen=True) -class IncomingLinkTag(LinkTag): - def __init__(self, kind: str, tag: str) -> None: - super().__init__(kind=kind, tag=tag, direction="incoming") - - -@dataclass(frozen=True) -class BidirLinkTag(LinkTag): - def __init__(self, kind: str, tag: str) -> None: - super().__init__(kind=kind, tag=tag, direction="bidir") - - -LINK_TAGS = "link_tags" - - -def get_link_tags(doc_or_md: Dict[str, Any]) -> Set[LinkTag]: - """Get the link-tag set from a document or metadata. - - Args: - doc_or_md: The document or metadata to get the link tags from. - - Returns: - The set of link tags from the document or metadata. - """ - link_tags = doc_or_md.setdefault(LINK_TAGS, set()) - if not isinstance(link_tags, Set): - link_tags = set(link_tags) - doc_or_md[LINK_TAGS] = link_tags - return link_tags diff --git a/libs/knowledge-store/ragstack_knowledge_store/links.py b/libs/knowledge-store/ragstack_knowledge_store/links.py new file mode 100644 index 000000000..05a9a90dd --- /dev/null +++ b/libs/knowledge-store/ragstack_knowledge_store/links.py @@ -0,0 +1,37 @@ +from dataclasses import dataclass +from typing import Literal + + +@dataclass(frozen=True) +class Link: + kind: str + direction: Literal["incoming", "outgoing", "bidir"] + + def __post_init__(self): + if self.__class__ in [Link, LinkTag]: + raise TypeError( + f"Abstract class {self.__class__.__name__} cannot be instantiated" + ) + + +@dataclass(frozen=True) +class LinkTag(Link): + tag: str + + +@dataclass(frozen=True) +class OutgoingLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="outgoing") + + +@dataclass(frozen=True) +class IncomingLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="incoming") + + +@dataclass(frozen=True) +class BidirLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="bidir") diff --git a/libs/langchain/ragstack_langchain/graph_store/base.py b/libs/langchain/ragstack_langchain/graph_store/base.py index c28696cf8..9e7249707 100644 --- a/libs/langchain/ragstack_langchain/graph_store/base.py +++ b/libs/langchain/ragstack_langchain/graph_store/base.py @@ -10,6 +10,7 @@ Iterator, List, Optional, + Set, ) from langchain_core.callbacks import ( @@ -20,7 +21,9 @@ from langchain_core.load import Serializable from langchain_core.runnables import run_in_executor from langchain_core.vectorstores import VectorStore, VectorStoreRetriever -from pydantic import Field +from langchain_core.pydantic_v1 import Field + +from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link def _has_next(iterator: Iterator) -> None: @@ -36,8 +39,9 @@ class Node(Serializable): id: Optional[str] """Unique ID for the node. Will be generated by the GraphStore if not set.""" metadata: dict = Field(default_factory=dict) - """Metadata for the node. May contain information used to link this node - with other nodes.""" + """Metadata for the node.""" + links: Set[Link] = Field(default_factory=set) + """Links associated with the node.""" class TextNode(Node): @@ -54,17 +58,20 @@ def _texts_to_nodes( ids_it = iter(ids) if ids else None for text in texts: try: - _metadata = next(metadatas_it) if metadatas_it else {} + _metadata = next(metadatas_it).copy() if metadatas_it else {} except StopIteration: raise ValueError("texts iterable longer than metadatas") try: _id = next(ids_it) if ids_it else None except StopIteration: raise ValueError("texts iterable longer than ids") + + links = _metadata.pop(METADATA_LINKS_KEY, set()) yield TextNode( id=_id, metadata=_metadata, text=text, + links=links, ) if ids and _has_next(ids_it): raise ValueError("ids iterable longer than texts") @@ -81,10 +88,13 @@ def _documents_to_nodes( _id = next(ids_it) if ids_it else None except StopIteration: raise ValueError("documents iterable longer than ids") + metadata = doc.metadata.copy() + links = metadata.pop(METADATA_LINKS_KEY, set()) yield TextNode( id=_id, - metadata=doc.metadata, + metadata=metadata, text=doc.page_content, + links=links, ) if ids and _has_next(ids_it): raise ValueError("ids iterable longer than documents") diff --git a/libs/langchain/ragstack_langchain/graph_store/cassandra.py b/libs/langchain/ragstack_langchain/graph_store/cassandra.py index fe2fe5e63..93f7a8325 100644 --- a/libs/langchain/ragstack_langchain/graph_store/cassandra.py +++ b/libs/langchain/ragstack_langchain/graph_store/cassandra.py @@ -101,7 +101,9 @@ def add_nodes( if not isinstance(node, TextNode): raise ValueError("Only adding TextNode is supported at the moment") _nodes.append( - graph_store.TextNode(id=node.id, text=node.text, metadata=node.metadata) + graph_store.TextNode( + id=node.id, text=node.text, metadata=node.metadata, links=node.links + ) ) return self.store.add_nodes(_nodes) diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/edge_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/edge_extractor.py index 5042edf95..d13e090f3 100644 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/edge_extractor.py +++ b/libs/langchain/ragstack_langchain/graph_store/extractors/edge_extractor.py @@ -1,28 +1,25 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import Generic, Iterable, Iterator, Set, TypeVar +from typing import Generic, Iterable, TypeVar from langchain_core.documents import Document from ragstack_knowledge_store._utils import strict_zip -from ragstack_knowledge_store.link_tag import LinkTag InputT = TypeVar("InputT") class EdgeExtractor(ABC, Generic[InputT]): @abstractmethod - def extract_one(self, document: Document, input: InputT): + def extract_one(self, document: Document, input: InputT) -> None: """Add edges from each `input` to the corresponding documents. Args: document: Document to add the link tags to. - inputs: The input content to extract edges from. + input: The input content to extract edges from. """ - def extract( - self, documents: Iterable[Document], inputs: Iterable[InputT] - ) -> Iterator[Set[LinkTag]]: + def extract(self, documents: Iterable[Document], inputs: Iterable[InputT]) -> None: """Add edges from each `input` to the corresponding documents. Args: diff --git a/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_edge_extractor.py b/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_edge_extractor.py index e3c257fcc..1e6ce9a21 100644 --- a/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_edge_extractor.py +++ b/libs/langchain/ragstack_langchain/graph_store/extractors/html_link_edge_extractor.py @@ -1,6 +1,6 @@ from langchain_core.documents import Document -from ragstack_knowledge_store.link_tag import ( - get_link_tags, +from ragstack_langchain.graph_store.links import ( + get_links, IncomingLinkTag, OutgoingLinkTag, ) @@ -68,7 +68,7 @@ def __init__( url_field: Name of the metadata field containing the URL of the content. Defaults to "source". kind: The kind of edge to extract. Defaults to "hyperlink". - drop_fragmets: Whether fragments in URLs and links shoud be + drop_fragments: Whether fragments in URLs and links shoud be dropped. Defaults to `True`. """ try: @@ -99,7 +99,7 @@ def extract_one( hrefs = _parse_hrefs(input, url, self.drop_fragments) - link_tags = get_link_tags(document.metadata) + link_tags = get_links(document) link_tags.add(IncomingLinkTag(kind=self._kind, tag=url)) for url in hrefs: link_tags.add(OutgoingLinkTag(kind=self._kind, tag=url)) diff --git a/libs/langchain/ragstack_langchain/graph_store/links.py b/libs/langchain/ragstack_langchain/graph_store/links.py new file mode 100644 index 000000000..32c9d3e06 --- /dev/null +++ b/libs/langchain/ragstack_langchain/graph_store/links.py @@ -0,0 +1,69 @@ +from dataclasses import dataclass +from typing import Literal, Dict, Any, Set, Union + +from langchain_core.documents import Document + + +@dataclass(frozen=True) +class Link: + kind: str + direction: Literal["incoming", "outgoing", "bidir"] + + def __post_init__(self): + if self.__class__ in [Link, LinkTag]: + raise TypeError( + f"Abstract class {self.__class__.__name__} cannot be instantiated" + ) + + +@dataclass(frozen=True) +class LinkTag(Link): + tag: str + + +@dataclass(frozen=True) +class OutgoingLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="outgoing") + + +@dataclass(frozen=True) +class IncomingLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="incoming") + + +@dataclass(frozen=True) +class BidirLinkTag(LinkTag): + def __init__(self, kind: str, tag: str) -> None: + super().__init__(kind=kind, tag=tag, direction="bidir") + + +METADATA_LINKS_KEY = "links" + + +def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]: + """Get the links from a document or metadata. + Args: + doc_or_md: The metadata to get the link tags from. + Returns: + The set of link tags from the document or metadata. + """ + + if isinstance(doc_or_md, Document): + doc_or_md = doc_or_md.metadata + + links = doc_or_md.setdefault(METADATA_LINKS_KEY, set()) + if not isinstance(links, Set): + links = set(links) + doc_or_md[METADATA_LINKS_KEY] = links + return links + + +def add_links(doc_or_md: Union[Document, Dict[str, Any]], *links: Link) -> None: + """Add links to the given metadata. + Args: + doc_or_md: The document or metadata to add the links to. + *links: The links to add to the metadata. + """ + get_links(doc_or_md).update(links) diff --git a/libs/langchain/tests/integration_tests/conftest.py b/libs/langchain/tests/integration_tests/conftest.py index 64de5a383..3ba892efa 100644 --- a/libs/langchain/tests/integration_tests/conftest.py +++ b/libs/langchain/tests/integration_tests/conftest.py @@ -1,6 +1,10 @@ import pytest +from dotenv import load_dotenv from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore +load_dotenv() + + status = { "local_cassandra_test_store": None, "astradb_test_store": None, diff --git a/libs/langchain/tests/integration_tests/test_graph_store.py b/libs/langchain/tests/integration_tests/test_graph_store.py index bf2c2e604..ce867522d 100644 --- a/libs/langchain/tests/integration_tests/test_graph_store.py +++ b/libs/langchain/tests/integration_tests/test_graph_store.py @@ -13,12 +13,14 @@ _texts_to_nodes, TextNode, ) -from ragstack_knowledge_store.link_tag import ( +from ragstack_langchain.graph_store.links import ( + METADATA_LINKS_KEY, BidirLinkTag, IncomingLinkTag, OutgoingLinkTag, ) from ragstack_tests_utils.test_store import KEYSPACE + from .conftest import get_local_cassandra_test_store, get_astradb_test_store from ragstack_langchain.graph_store import CassandraGraphStore @@ -35,7 +37,7 @@ def __init__(self, session: Session, keyspace: str, embedding: Embeddings) -> No def store( self, - initial_documents: Iterable[Document] = [], + initial_documents: Iterable[Document] = (), ids: Optional[Iterable[str]] = None, embedding: Optional[Embeddings] = None, ) -> CassandraGraphStore: @@ -123,7 +125,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="A", metadata={ "content_id": "a", - "link_tags": { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="hyperlink", tag="http://a"), }, }, @@ -132,7 +134,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="B", metadata={ "content_id": "b", - "link_tags": { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="hyperlink", tag="http://b"), OutgoingLinkTag(kind="hyperlink", tag="http://a"), }, @@ -142,7 +144,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="C", metadata={ "content_id": "c", - "link_tags": { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="hyperlink", tag="http://a"), }, }, @@ -151,7 +153,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None: page_content="D", metadata={ "content_id": "d", - "link_tags": { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="hyperlink", tag="http://a"), OutgoingLinkTag(kind="hyperlink", tag="http://b"), }, @@ -196,7 +198,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="-0.124", metadata={ "content_id": "v0", - "link_tags": { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="explicit", tag="link"), }, }, @@ -211,7 +213,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="+0.25", metadata={ "content_id": "v2", - "link_tags": { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="explicit", tag="link"), }, }, @@ -220,7 +222,7 @@ def test_mmr_traversal(request, gs_factory: str): page_content="+1.0", metadata={ "content_id": "v3", - "link_tags": { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="explicit", tag="link"), }, }, @@ -255,7 +257,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Typical Greetings", metadata={ "content_id": "greetings", - "link_tags": { + METADATA_LINKS_KEY: { IncomingLinkTag(kind="parent", tag="parent"), }, }, @@ -264,7 +266,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Hello World", metadata={ "content_id": "doc1", - "link_tags": { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="parent", tag="parent"), BidirLinkTag(kind="kw", tag="greeting"), BidirLinkTag(kind="kw", tag="world"), @@ -275,7 +277,7 @@ def test_write_retrieve_keywords(request, gs_factory: str): page_content="Hello Earth", metadata={ "content_id": "doc2", - "link_tags": { + METADATA_LINKS_KEY: { OutgoingLinkTag(kind="parent", tag="parent"), BidirLinkTag(kind="kw", tag="greeting"), BidirLinkTag(kind="kw", tag="earth"), @@ -320,6 +322,13 @@ def test_texts_to_nodes(): TextNode(metadata={"a": "b"}, text="a"), TextNode(metadata={"c": "d"}, text="b"), ] + assert list( + _texts_to_nodes( + ["a"], + [{"links": {IncomingLinkTag(kind="hyperlink", tag="http://b")}}], + None, + ) + ) == [TextNode(links={IncomingLinkTag(kind="hyperlink", tag="http://b")}, text="a")] with pytest.raises(ValueError): list(_texts_to_nodes(["a", "b"], None, ["a"])) with pytest.raises(ValueError): @@ -332,15 +341,23 @@ def test_texts_to_nodes(): def test_documents_to_nodes(): documents = [ - Document(page_content="a", metadata={"a": "b"}), + Document( + page_content="a", + metadata={"links": {IncomingLinkTag(kind="hyperlink", tag="http://b")}}, + ), Document(page_content="b", metadata={"c": "d"}), ] assert list(_documents_to_nodes(documents, ["a", "b"])) == [ - TextNode(id="a", metadata={"a": "b"}, text="a"), + TextNode( + id="a", + metadata={}, + links={IncomingLinkTag(kind="hyperlink", tag="http://b")}, + text="a", + ), TextNode(id="b", metadata={"c": "d"}, text="b"), ] assert list(_documents_to_nodes(documents, None)) == [ - TextNode(metadata={"a": "b"}, text="a"), + TextNode(links={IncomingLinkTag(kind="hyperlink", tag="http://b")}, text="a"), TextNode(metadata={"c": "d"}, text="b"), ] with pytest.raises(ValueError):