Add Links to graph store Node
cbornet committed Jun 19, 2024
1 parent d9d878e commit 5775e7f
Showing 9 changed files with 140 additions and 56 deletions.
16 changes: 10 additions & 6 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
@@ -8,6 +8,7 @@
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)

@@ -18,7 +19,7 @@
from .concurrency import ConcurrentQueries
from .content import Kind
from .embedding_model import EmbeddingModel
from .link_tag import get_link_tags
from .link_tag import LinkTag
from .math import cosine_similarity

CONTENT_ID = "content_id"
@@ -31,8 +32,9 @@ class Node:
id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = field(default_factory=dict)
"""Metadata for the node. May contain information used to link this node
with other nodes."""
"""Metadata for the node."""
links: Set[LinkTag] = field(default_factory=set)
"""Links for the node."""


@dataclass
@@ -328,11 +330,13 @@ def add_nodes(
) -> Iterable[str]:
texts = []
metadatas = []
links: List[Set[LinkTag]] = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
texts.append(node.text)
metadatas.append(node.metadata)
links.append(node.links)

text_embeddings = self._embedding.embed_texts(texts)

@@ -343,8 +347,8 @@

# Step 1: Add the nodes, collecting the tags and new sources / targets.
with self._concurrent_queries() as cq:
tuples = zip(texts, text_embeddings, metadatas)
for text, text_embedding, metadata in tuples:
tuples = zip(texts, text_embeddings, metadatas, links)
for text, text_embedding, metadata, _links in tuples:
if CONTENT_ID not in metadata:
metadata[CONTENT_ID] = secrets.token_hex(8)
id = metadata[CONTENT_ID]
@@ -353,7 +357,7 @@
link_to_tags = set() # link to these tags
link_from_tags = set() # link from these tags

for tag in get_link_tags(metadata):
for tag in _links:
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incoming link should be linked *from* nodes with the given tag.
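For orientation, a minimal sketch of how a caller might attach links when adding nodes after this change. The field names follow the code above; `store` stands in for an already-initialized graph store with an embedding model, and the URLs are illustrative, not part of this commit.

from ragstack_knowledge_store.graph_store import TextNode
from ragstack_knowledge_store.link_tag import IncomingLinkTag, OutgoingLinkTag

node = TextNode(
    text="RAGStack knowledge store node with graph links.",
    metadata={"source": "https://example.com/doc-1"},
    links={
        # Nodes that link out to this URL will get an edge into this node ...
        IncomingLinkTag(kind="hyperlink", tag="https://example.com/doc-1"),
        # ... and this node links out toward nodes declaring this URL as incoming.
        OutgoingLinkTag(kind="hyperlink", tag="https://example.com/doc-2"),
    },
)

# Links now travel on the node itself; add_nodes reads node.links instead of
# calling get_link_tags(metadata).
store.add_nodes([node])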
19 changes: 0 additions & 19 deletions libs/knowledge-store/ragstack_knowledge_store/link_tag.py
@@ -33,22 +33,3 @@ def __init__(self, kind: str, tag: str) -> None:
class BidirLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="bidir")


LINK_TAGS = "link_tags"


def get_link_tags(doc_or_md: Dict[str, Any]) -> Set[LinkTag]:
"""Get the link-tag set from a document or metadata.
Args:
doc_or_md: The document or metadata to get the link tags from.
Returns:
The set of link tags from the document or metadata.
"""
link_tags = doc_or_md.setdefault(LINK_TAGS, set())
if not isinstance(link_tags, Set):
link_tags = set(link_tags)
doc_or_md[LINK_TAGS] = link_tags
return link_tags
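In short, the removed get_link_tags helper read tags out of a reserved metadata key; after this commit the same information rides on the node's links field. A small before/after sketch (the tag value is illustrative, and TextNode is assumed to accept the text/metadata/links fields shown above):

from ragstack_knowledge_store.graph_store import TextNode
from ragstack_knowledge_store.link_tag import BidirLinkTag

tag = BidirLinkTag(kind="keyword", tag="graph-rag")

# Before: tags were stashed under metadata["link_tags"] and read back with
# get_link_tags(metadata).
old_metadata = {"content_id": "abc123", "link_tags": {tag}}

# After: the node carries its links directly, and metadata stays free of
# reserved keys.
node = TextNode(text="...", metadata={"content_id": "abc123"}, links={tag})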
20 changes: 15 additions & 5 deletions libs/langchain/ragstack_langchain/graph_store/base.py
@@ -10,6 +10,7 @@
Iterator,
List,
Optional,
Set,
)

from langchain_core.callbacks import (
@@ -20,7 +21,9 @@
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from pydantic import Field
from langchain_core.pydantic_v1 import Field

from ragstack_langchain.graph_store.links import LinkTag, LINKS


def _has_next(iterator: Iterator) -> None:
@@ -36,8 +39,9 @@ class Node(Serializable):
id: Optional[str]
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node. May contain information used to link this node
with other nodes."""
"""Metadata for the node."""
links: Set[LinkTag] = Field(default_factory=set)
"""Links associated with the node."""


class TextNode(Node):
@@ -54,17 +58,20 @@ def _texts_to_nodes(
ids_it = iter(ids) if ids else None
for text in texts:
try:
_metadata = next(metadatas_it) if metadatas_it else {}
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration:
raise ValueError("texts iterable longer than metadatas")
try:
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")

links = _metadata.pop(LINKS, set())
yield TextNode(
id=_id,
metadata=_metadata,
text=text,
links=links,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
@@ -81,10 +88,13 @@ def _documents_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("documents iterable longer than ids")
metadata = doc.metadata.copy()
links = metadata.pop(LINKS, set())
yield TextNode(
id=_id,
metadata=doc.metadata,
metadata=metadata,
text=doc.page_content,
links=links,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than documents")
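A sketch of the conversion behavior added above: the metadata is copied, any "links" entry is popped off the copy, and the popped set becomes the node's links field, leaving the caller's original document untouched. The document content and URL here are illustrative.

from langchain_core.documents import Document
from ragstack_langchain.graph_store.links import OutgoingLinkTag

tag = OutgoingLinkTag(kind="hyperlink", tag="https://example.com/next")
doc = Document(
    page_content="Page body",
    metadata={"source": "https://example.com", "links": {tag}},
)

# _documents_to_nodes would yield roughly:
#   TextNode(id=None, text="Page body",
#            metadata={"source": "https://example.com"}, links={tag})
# while doc.metadata still contains its "links" entry, because only the copied
# metadata is mutated.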
4 changes: 3 additions & 1 deletion libs/langchain/ragstack_langchain/graph_store/cassandra.py
@@ -101,7 +101,9 @@ def add_nodes(
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
_nodes.append(
graph_store.TextNode(id=node.id, text=node.text, metadata=node.metadata)
graph_store.TextNode(
id=node.id, text=node.text, metadata=node.metadata, links=node.links
)
)
return self.store.add_nodes(_nodes)

@@ -1,28 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Set, TypeVar
from typing import Generic, Iterable, TypeVar

from langchain_core.documents import Document
from ragstack_knowledge_store._utils import strict_zip
from ragstack_knowledge_store.link_tag import LinkTag

InputT = TypeVar("InputT")


class EdgeExtractor(ABC, Generic[InputT]):
@abstractmethod
def extract_one(self, document: Document, input: InputT):
def extract_one(self, document: Document, input: InputT) -> None:
"""Add edges from each `input` to the corresponding documents.
Args:
document: Document to add the link tags to.
inputs: The input content to extract edges from.
input: The input content to extract edges from.
"""

def extract(
self, documents: Iterable[Document], inputs: Iterable[InputT]
) -> Iterator[Set[LinkTag]]:
def extract(self, documents: Iterable[Document], inputs: Iterable[InputT]) -> None:
"""Add edges from each `input` to the corresponding documents.
Args:
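As an illustration of the updated contract, where extract_one mutates the document in place and returns None, here is a hypothetical extractor. KeywordEdgeExtractor is invented for this sketch, the import path for EdgeExtractor is an assumption (the file name is not visible in this diff view), and add_links/BidirLinkTag come from the links module added elsewhere in this commit.

from typing import Iterable

from langchain_core.documents import Document

# Assumed import path for the abstract base class shown above.
from ragstack_langchain.graph_store.edge_extractor import EdgeExtractor
from ragstack_langchain.graph_store.links import BidirLinkTag, add_links


class KeywordEdgeExtractor(EdgeExtractor[Iterable[str]]):
    """Hypothetical extractor linking documents that share a keyword."""

    def extract_one(self, document: Document, input: Iterable[str]) -> None:
        # Each keyword becomes a bidirectional tag, so any two documents
        # sharing a keyword end up linked in both directions.
        add_links(document, *(BidirLinkTag(kind="keyword", tag=kw) for kw in input))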
@@ -1,6 +1,6 @@
from langchain_core.documents import Document
from ragstack_knowledge_store.link_tag import (
get_link_tags,
from ragstack_langchain.graph_store.links import (
get_links,
IncomingLinkTag,
OutgoingLinkTag,
)
@@ -68,7 +68,7 @@ def __init__(
url_field: Name of the metadata field containing the URL
of the content. Defaults to "source".
kind: The kind of edge to extract. Defaults to "hyperlink".
drop_fragmets: Whether fragments in URLs and links shoud be
drop_fragments: Whether fragments in URLs and links shoud be
dropped. Defaults to `True`.
"""
try:
@@ -99,7 +99,7 @@ def extract_one(

hrefs = _parse_hrefs(input, url, self.drop_fragments)

link_tags = get_link_tags(document.metadata)
link_tags = get_links(document)
link_tags.add(IncomingLinkTag(kind=self._kind, tag=url))
for url in hrefs:
link_tags.add(OutgoingLinkTag(kind=self._kind, tag=url))
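The same pattern as extract_one above, written out as a standalone sketch; the URL and hrefs are made up, and in the real extractor the hrefs come from parsing the HTML input.

from langchain_core.documents import Document
from ragstack_langchain.graph_store.links import (
    IncomingLinkTag,
    OutgoingLinkTag,
    get_links,
)

url = "https://example.com/index.html"
hrefs = ["https://example.com/a.html", "https://example.com/b.html"]
document = Document(page_content="<html>...</html>", metadata={"source": url})

# The page itself can be reached through its own URL (incoming), and it points
# out to every href found in its body (outgoing).
link_tags = get_links(document)
link_tags.add(IncomingLinkTag(kind="hyperlink", tag=url))
for href in hrefs:
    link_tags.add(OutgoingLinkTag(kind="hyperlink", tag=href))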
69 changes: 69 additions & 0 deletions libs/langchain/ragstack_langchain/graph_store/links.py
@@ -0,0 +1,69 @@
from dataclasses import dataclass
from typing import Literal, Dict, Any, Set, Union

from langchain_core.documents import Document


@dataclass(frozen=True)
class Link:
kind: str
direction: Literal["incoming", "outgoing", "bidir"]

def __post_init__(self):
if self.__class__ in [Link, LinkTag]:
raise TypeError(
f"Abstract class {self.__class__.__name__} cannot be instantiated"
)


@dataclass(frozen=True)
class LinkTag(Link):
tag: str


@dataclass(frozen=True)
class OutgoingLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="outgoing")


@dataclass(frozen=True)
class IncomingLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="incoming")


@dataclass(frozen=True)
class BidirLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="bidir")


LINKS = "links"


def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]:
"""Get the links from a document or metadata.
Args:
doc_or_md: The metadata to get the link tags from.
Returns:
The set of link tags from the document or metadata.
"""

if isinstance(doc_or_md, Document):
doc_or_md = doc_or_md.metadata

links = doc_or_md.setdefault(LINKS, set())
if not isinstance(links, Set):
links = set(links)
doc_or_md[LINKS] = links
return links


def add_links(doc_or_md: Union[Document, Dict[str, Any]], *links: Link) -> None:
"""Add links to the given metadata.
Args:
doc_or_md: The document or metadata to add the links to.
*links: The links to add to the metadata.
"""
get_links(doc_or_md).update(links)
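A short usage sketch of the new helpers, with two illustrative documents. When such documents later reach the graph store, a matching kind/tag pair on an outgoing and an incoming link is what yields an edge between the corresponding nodes, per the link_to_tags / link_from_tags handling in graph_store.py above.

from langchain_core.documents import Document
from ragstack_langchain.graph_store.links import (
    IncomingLinkTag,
    OutgoingLinkTag,
    add_links,
    get_links,
)

page_a = Document(page_content="Page A", metadata={"source": "https://example.com/a"})
page_b = Document(page_content="Page B", metadata={"source": "https://example.com/b"})

# Page A links out to B's URL; Page B accepts incoming links on its own URL.
add_links(page_a, OutgoingLinkTag(kind="hyperlink", tag="https://example.com/b"))
add_links(page_b, IncomingLinkTag(kind="hyperlink", tag="https://example.com/b"))

# Both helpers operate on the same set stored under metadata["links"].
assert get_links(page_a) == {OutgoingLinkTag(kind="hyperlink", tag="https://example.com/b")}
assert get_links(page_b) == {IncomingLinkTag(kind="hyperlink", tag="https://example.com/b")}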
4 changes: 4 additions & 0 deletions libs/langchain/tests/integration_tests/conftest.py
@@ -1,6 +1,10 @@
import pytest
from dotenv import load_dotenv
from ragstack_tests_utils import AstraDBTestStore, LocalCassandraTestStore

load_dotenv()


status = {
"local_cassandra_test_store": None,
"astradb_test_store": None,
