Skip to content

Commit

Permalink
Changes following review
Browse files Browse the repository at this point in the history
  • Loading branch information
cbornet committed Jun 20, 2024
1 parent 5775e7f commit 84ac4a4
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 47 deletions.
42 changes: 22 additions & 20 deletions libs/knowledge-store/ragstack_knowledge_store/graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from .concurrency import ConcurrentQueries
from .content import Kind
from .embedding_model import EmbeddingModel
from .link_tag import LinkTag
from .links import Link
from .math import cosine_similarity

CONTENT_ID = "content_id"
Expand All @@ -33,7 +33,7 @@ class Node:
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = field(default_factory=dict)
"""Metadata for the node."""
links: Set[LinkTag] = field(default_factory=set)
links: Set[Link] = field(default_factory=set)
"""Links for the node."""


Expand Down Expand Up @@ -330,13 +330,13 @@ def add_nodes(
) -> Iterable[str]:
texts = []
metadatas = []
links: List[Set[LinkTag]] = []
nodes_links: List[Set[Link]] = []
for node in nodes:
if not isinstance(node, TextNode):
raise ValueError("Only adding TextNode is supported at the moment")
texts.append(node.text)
metadatas.append(node.metadata)
links.append(node.links)
nodes_links.append(node.links)

text_embeddings = self._embedding.embed_texts(texts)

Expand All @@ -347,8 +347,8 @@ def add_nodes(

# Step 1: Add the nodes, collecting the tags and new sources / targets.
with self._concurrent_queries() as cq:
tuples = zip(texts, text_embeddings, metadatas, links)
for text, text_embedding, metadata, _links in tuples:
tuples = zip(texts, text_embeddings, metadatas, nodes_links)
for text, text_embedding, metadata, links in tuples:
if CONTENT_ID not in metadata:
metadata[CONTENT_ID] = secrets.token_hex(8)
id = metadata[CONTENT_ID]
Expand All @@ -357,20 +357,22 @@ def add_nodes(
link_to_tags = set() # link to these tags
link_from_tags = set() # link from these tags

for tag in _links:
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incom`ing link should be linked *from* nodes with the given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)
for tag in links:
if hasattr(tag, "tag"):
tag_str = f"{tag.kind}:{tag.tag}"
if tag.direction == "incoming" or tag.direction == "bidir":
# An incoming link should be linked *from* nodes with the
# given tag.
link_from_tags.add(tag_str)
tag_to_new_targets.setdefault(tag_str, dict())[id] = (
tag.kind,
text_embedding,
)
if tag.direction == "outgoing" or tag.direction == "bidir":
link_to_tags.add(tag_str)
tag_to_new_sources.setdefault(tag_str, list()).append(
(tag.kind, id)
)

cq.execute(
self._insert_passage,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,20 +1,22 @@
from dataclasses import dataclass
from typing import Literal, Dict, Any, Set
from typing import Literal


@dataclass(frozen=True)
class _LinkTag:
class Link:
kind: str
tag: str
direction: Literal["incoming", "outgoing", "bidir"]

def __post_init__(self):
if self.__class__ in [Link, LinkTag]:
raise TypeError(
f"Abstract class {self.__class__.__name__} cannot be instantiated"
)


@dataclass(frozen=True)
class LinkTag(_LinkTag):
def __init__(self, kind: str, tag: str, direction: str) -> None:
if self.__class__ == LinkTag:
raise TypeError("Abstract class LinkTag cannot be instantiated")
super().__init__(kind, tag, direction)
class LinkTag(Link):
tag: str


@dataclass(frozen=True)
Expand Down
8 changes: 4 additions & 4 deletions libs/langchain/ragstack_langchain/graph_store/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
from langchain_core.vectorstores import VectorStore, VectorStoreRetriever
from langchain_core.pydantic_v1 import Field

from ragstack_langchain.graph_store.links import LinkTag, LINKS
from ragstack_langchain.graph_store.links import METADATA_LINKS_KEY, Link


def _has_next(iterator: Iterator) -> None:
Expand All @@ -40,7 +40,7 @@ class Node(Serializable):
"""Unique ID for the node. Will be generated by the GraphStore if not set."""
metadata: dict = Field(default_factory=dict)
"""Metadata for the node."""
links: Set[LinkTag] = Field(default_factory=set)
links: Set[Link] = Field(default_factory=set)
"""Links associated with the node."""


Expand All @@ -66,7 +66,7 @@ def _texts_to_nodes(
except StopIteration:
raise ValueError("texts iterable longer than ids")

links = _metadata.pop(LINKS, set())
links = _metadata.pop(METADATA_LINKS_KEY, set())
yield TextNode(
id=_id,
metadata=_metadata,
Expand All @@ -89,7 +89,7 @@ def _documents_to_nodes(
except StopIteration:
raise ValueError("documents iterable longer than ids")
metadata = doc.metadata.copy()
links = metadata.pop(LINKS, set())
links = metadata.pop(METADATA_LINKS_KEY, set())
yield TextNode(
id=_id,
metadata=metadata,
Expand Down
6 changes: 3 additions & 3 deletions libs/langchain/ragstack_langchain/graph_store/links.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def __init__(self, kind: str, tag: str) -> None:
super().__init__(kind=kind, tag=tag, direction="bidir")


LINKS = "links"
METADATA_LINKS_KEY = "links"


def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]:
Expand All @@ -53,10 +53,10 @@ def get_links(doc_or_md: Union[Document, Dict[str, Any]]) -> Set[Link]:
if isinstance(doc_or_md, Document):
doc_or_md = doc_or_md.metadata

links = doc_or_md.setdefault(LINKS, set())
links = doc_or_md.setdefault(METADATA_LINKS_KEY, set())
if not isinstance(links, Set):
links = set(links)
doc_or_md[LINKS] = links
doc_or_md[METADATA_LINKS_KEY] = links
return links


Expand Down
24 changes: 12 additions & 12 deletions libs/langchain/tests/integration_tests/test_graph_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
TextNode,
)
from ragstack_langchain.graph_store.links import (
LINKS,
METADATA_LINKS_KEY,
BidirLinkTag,
IncomingLinkTag,
OutgoingLinkTag,
Expand All @@ -37,7 +37,7 @@ def __init__(self, session: Session, keyspace: str, embedding: Embeddings) -> No

def store(
self,
initial_documents: Iterable[Document] = [],
initial_documents: Iterable[Document] = (),
ids: Optional[Iterable[str]] = None,
embedding: Optional[Embeddings] = None,
) -> CassandraGraphStore:
Expand Down Expand Up @@ -125,7 +125,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
page_content="A",
metadata={
"content_id": "a",
LINKS: {
METADATA_LINKS_KEY: {
IncomingLinkTag(kind="hyperlink", tag="http://a"),
},
},
Expand All @@ -134,7 +134,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
page_content="B",
metadata={
"content_id": "b",
LINKS: {
METADATA_LINKS_KEY: {
IncomingLinkTag(kind="hyperlink", tag="http://b"),
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
},
Expand All @@ -144,7 +144,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
page_content="C",
metadata={
"content_id": "c",
LINKS: {
METADATA_LINKS_KEY: {
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
},
},
Expand All @@ -153,7 +153,7 @@ def test_link_directed(cassandra: GraphStoreFactory) -> None:
page_content="D",
metadata={
"content_id": "d",
LINKS: {
METADATA_LINKS_KEY: {
OutgoingLinkTag(kind="hyperlink", tag="http://a"),
OutgoingLinkTag(kind="hyperlink", tag="http://b"),
},
Expand Down Expand Up @@ -198,7 +198,7 @@ def test_mmr_traversal(request, gs_factory: str):
page_content="-0.124",
metadata={
"content_id": "v0",
LINKS: {
METADATA_LINKS_KEY: {
OutgoingLinkTag(kind="explicit", tag="link"),
},
},
Expand All @@ -213,7 +213,7 @@ def test_mmr_traversal(request, gs_factory: str):
page_content="+0.25",
metadata={
"content_id": "v2",
LINKS: {
METADATA_LINKS_KEY: {
IncomingLinkTag(kind="explicit", tag="link"),
},
},
Expand All @@ -222,7 +222,7 @@ def test_mmr_traversal(request, gs_factory: str):
page_content="+1.0",
metadata={
"content_id": "v3",
LINKS: {
METADATA_LINKS_KEY: {
IncomingLinkTag(kind="explicit", tag="link"),
},
},
Expand Down Expand Up @@ -257,7 +257,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
page_content="Typical Greetings",
metadata={
"content_id": "greetings",
LINKS: {
METADATA_LINKS_KEY: {
IncomingLinkTag(kind="parent", tag="parent"),
},
},
Expand All @@ -266,7 +266,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
page_content="Hello World",
metadata={
"content_id": "doc1",
LINKS: {
METADATA_LINKS_KEY: {
OutgoingLinkTag(kind="parent", tag="parent"),
BidirLinkTag(kind="kw", tag="greeting"),
BidirLinkTag(kind="kw", tag="world"),
Expand All @@ -277,7 +277,7 @@ def test_write_retrieve_keywords(request, gs_factory: str):
page_content="Hello Earth",
metadata={
"content_id": "doc2",
LINKS: {
METADATA_LINKS_KEY: {
OutgoingLinkTag(kind="parent", tag="parent"),
BidirLinkTag(kind="kw", tag="greeting"),
BidirLinkTag(kind="kw", tag="earth"),
Expand Down

0 comments on commit 84ac4a4

Please sign in to comment.