Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Links to graph store Node #507

Merged
merged 2 commits into from
Jun 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 9 additions & 3 deletions .github/workflows/ci-unit-tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,9 +139,15 @@ jobs:
rm -rf $dir/.tox
}

run_itests libs/colbert
run_itests libs/langchain
run_itests libs/llamaindex
if [[ "true" == "${{ needs.preconditions.outputs.libs_colbert }}" ]]; then
run_itests libs/colbert
fi
if [[ "true" == "${{ needs.preconditions.outputs.libs_langchain }}" ]]; then
run_itests libs/langchain
fi
if [[ "true" == "${{ needs.preconditions.outputs.libs_llamaindex }}" ]]; then
run_itests libs/llamaindex
fi

- name: Cleanup AstraDB
uses: nicoloboschi/cleanup-astradb@v1
Expand Down
44 changes: 25 additions & 19 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
NamedTuple,
Optional,
Sequence,
Set,
Tuple,
)

Expand All @@ -18,7 +19,7 @@
from .concurrency import ConcurrentQueries
from .content import Kind
from .embedding_model import EmbeddingModel
from .link_tag import get_link_tags
from .links import Link
from .math import cosine_similarity

CONTENT_ID = "content_id"
Expand All @@ -31,8 +32,9 @@ class Node:
id: Optional[str] = None
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = field(default_factory=dict)
"""Metadata for the node. May contain information used to link this node
with other nodes."""
"""Metadata for the node."""
links: Set[Link] = field(default_factory=set)
"""Links for the node."""


@dataclass
Expand Down Expand Up @@ -328,11 +330,13 @@ def add_nodes(
) -> Iterable[str]:
texts = []
metadatas = []
nodes_links: List[Set[Link]] = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
texts.append(node.text)
metadatas.append(node.metadata)
nodes_links.append(node.links)

text_embeddings = self._embedding.embed_texts(texts)

Expand All @@ -343,8 +347,8 @@ def add_nodes(

# Step 1: Add the nodes, collecting the tags and new sources / targets.
with self._concurrent_queries() as cq:
tuples = zip(texts, text_embeddings, metadatas)
for text, text_embedding, metadata in tuples:
tuples = zip(texts, text_embeddings, metadatas, nodes_links)
for text, text_embedding, metadata, links in tuples:
if CONTENT_ID not in metadata:
metadata[CONTENT_ID] = secrets.token_hex(8)
id = metadata[CONTENT_ID]
Expand All @@ -353,20 +357,22 @@ def add_nodes(
link_to_tags = set() # link to these tags
link_from_tags = set() # link from these tags

for tag in get_link_tags(metadata):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incom`ing link should be linked *from* nodes with the given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)
for tag in links:
if hasattr(tag, "tag"):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incoming link should be linked *from* nodes with the
# given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)

cq.execute(
self._insert_passage,
Expand Down
54 changes: 0 additions & 54 deletions libs/knowledge-store/ragstack_knowledge_store/link_tag.py

This file was deleted.

37 changes: 37 additions & 0 deletions libs/knowledge-store/ragstack_knowledge_store/links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Literal


@dataclass(frozen=True)
class Link:
kind: str
direction: Literal["incoming", "outgoing", "bidir"]

def __post_init__(self):
if self.__class__ in [Link, LinkTag]:
raise TypeError(
f"Abstract class {self.__class__.__name__} cannot be instantiated"
)


@dataclass(frozen=True)
class LinkTag(Link):
tag: str


@dataclass(frozen=True)
class OutgoingLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="outgoing")


@dataclass(frozen=True)
class IncomingLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="incoming")


@dataclass(frozen=True)
class BidirLinkTag(LinkTag):
def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="bidir")
20 changes: 15 additions & 5 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Iterator,
List,
Optional,
Set,
)

from langchain_core.callbacks import (
Expand All @@ -20,7 +21,9 @@
from langchain_core.load import Serializable
from langchain_core.runnables import run_in_executor
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from pydantic import Field
from langchain_core.pydantic_v1 import Field

from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link


def _has_next(iterator: Iterator) -> None:
Expand All @@ -36,8 +39,9 @@ class Node(Serializable):
id: Optional[str]
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node. May contain information used to link this node
with other nodes."""
"""Metadata for the node."""
links: Set[Link] = Field(default_factory=set)
"""Links associated with the node."""


class TextNode(Node):
Expand All @@ -54,17 +58,20 @@ def _texts_to_nodes(
ids_it = iter(ids) if ids else None
for text in texts:
try:
_metadata = next(metadatas_it) if metadatas_it else {}
_metadata = next(metadatas_it).copy() if metadatas_it else {}
except StopIteration:
raise ValueError("texts iterable longer than metadatas")
try:
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("texts iterable longer than ids")

links = _metadata.pop(METADATA_LINKS_KEY, set())
yield TextNode(
id=_id,
metadata=_metadata,
text=text,
links=links,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than texts")
Expand All @@ -81,10 +88,13 @@ def _documents_to_nodes(
_id = next(ids_it) if ids_it else None
except StopIteration:
raise ValueError("documents iterable longer than ids")
metadata = doc.metadata.copy()
links = metadata.pop(METADATA_LINKS_KEY, set())
yield TextNode(
id=_id,
metadata=doc.metadata,
metadata=metadata,
text=doc.page_content,
links=links,
)
if ids and _has_next(ids_it):
raise ValueError("ids iterable longer than documents")
Expand Down
4 changes: 3 additions & 1 deletion libs/langchain/ragstack_langchain/graph_store/cassandra.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,9 @@ def add_nodes(
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
_nodes.append(
graph_store.TextNode(id=node.id, text=node.text, metadata=node.metadata)
graph_store.TextNode(
id=node.id, text=node.text, metadata=node.metadata, links=node.links
)
)
return self.store.add_nodes(_nodes)

Expand Down
Original file line number Diff line number Diff line change
@@ -1,28 +1,25 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Generic, Iterable, Iterator, Set, TypeVar
from typing import Generic, Iterable, TypeVar

from langchain_core.documents import Document
from ragstack_knowledge_store._utils import strict_zip
from ragstack_knowledge_store.link_tag import LinkTag

InputT = TypeVar("InputT")


class EdgeExtractor(ABC, Generic[InputT]):
@abstractmethod
def extract_one(self, document: Document, input: InputT):
def extract_one(self, document: Document, input: InputT) -> None:
cbornet marked this conversation as resolved.
Show resolved Hide resolved
"""Add edges from each `input` to the corresponding documents.

Args:
document: Document to add the link tags to.
inputs: The input content to extract edges from.
input: The input content to extract edges from.
"""

def extract(
self, documents: Iterable[Document], inputs: Iterable[InputT]
) -> Iterator[Set[LinkTag]]:
def extract(self, documents: Iterable[Document], inputs: Iterable[InputT]) -> None:
"""Add edges from each `input` to the corresponding documents.

Args:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from langchain_core.documents import Document
from ragstack_knowledge_store.link_tag import (
get_link_tags,
from ragstack_langchain.graph_store.links import (
cbornet marked this conversation as resolved.
Show resolved Hide resolved
get_links,
IncomingLinkTag,
OutgoingLinkTag,
)
Expand Down Expand Up @@ -68,7 +68,7 @@ def __init__(
url_field: Name of the metadata field containing the URL
of the content. Defaults to "source".
kind: The kind of edge to extract. Defaults to "hyperlink".
drop_fragmets: Whether fragments in URLs and links shoud be
drop_fragments: Whether fragments in URLs and links shoud be
dropped. Defaults to `True`.
"""
try:
Expand Down Expand Up @@ -99,7 +99,7 @@ def extract_one(

hrefs = _parse_hrefs(input, url, self.drop_fragments)

link_tags = get_link_tags(document.metadata)
link_tags = get_links(document)
link_tags.add(IncomingLinkTag(kind=self._kind, tag=url))
for url in hrefs:
link_tags.add(OutgoingLinkTag(kind=self._kind, tag=url))
Loading