Skip to content

Commit

Permalink
community[minor]: Add SQL storage implementation (#22207)
Browse files Browse the repository at this point in the history
Hello @eyurtsev

- package: langchain-comminity
- **Description**: Add SQL implementation for docstore. A new
implementation, in line with my other PR ([async
PGVector](langchain-ai/langchain-postgres#32),
[SQLChatMessageMemory](#22065))
- Twitter handler: pprados

---------

Signed-off-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>
Co-authored-by: Piotr Mardziel <piotrm@gmail.com>
Co-authored-by: ChengZi <chen.zhang@zilliz.com>
Co-authored-by: Eugene Yurtsev <eyurtsev@gmail.com>
  • Loading branch information
5 people authored Jun 7, 2024
1 parent f2f0e0e commit 9aabb44
Show file tree
Hide file tree
Showing 6 changed files with 548 additions and 1 deletion.
2 changes: 1 addition & 1 deletion docs/docs/integrations/vectorstores/milvus.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -390,4 +390,4 @@
},
"nbformat": 4,
"nbformat_minor": 5
}
}
5 changes: 5 additions & 0 deletions libs/community/langchain_community/storage/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@
from langchain_community.storage.redis import (
RedisStore,
)
from langchain_community.storage.sql import (
SQLStore,
)
from langchain_community.storage.upstash_redis import (
UpstashRedisByteStore,
UpstashRedisStore,
Expand All @@ -42,6 +45,7 @@
"CassandraByteStore",
"MongoDBStore",
"RedisStore",
"SQLStore",
"UpstashRedisByteStore",
"UpstashRedisStore",
]
Expand All @@ -52,6 +56,7 @@
"CassandraByteStore": "langchain_community.storage.cassandra",
"MongoDBStore": "langchain_community.storage.mongodb",
"RedisStore": "langchain_community.storage.redis",
"SQLStore": "langchain_community.storage.sql",
"UpstashRedisByteStore": "langchain_community.storage.upstash_redis",
"UpstashRedisStore": "langchain_community.storage.upstash_redis",
}
Expand Down
266 changes: 266 additions & 0 deletions libs/community/langchain_community/storage/sql.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import contextlib
from pathlib import Path
from typing import (
Any,
AsyncGenerator,
AsyncIterator,
Dict,
Generator,
Iterator,
List,
Optional,
Sequence,
Tuple,
Union,
cast,
)

from langchain_core.stores import BaseStore
from sqlalchemy import (
Engine,
LargeBinary,
and_,
create_engine,
delete,
select,
)
from sqlalchemy.ext.asyncio import (
AsyncEngine,
AsyncSession,
async_sessionmaker,
create_async_engine,
)
from sqlalchemy.orm import (
Mapped,
Session,
declarative_base,
mapped_column,
sessionmaker,
)

Base = declarative_base()


def items_equal(x: Any, y: Any) -> bool:
return x == y


class LangchainKeyValueStores(Base): # type: ignore[valid-type,misc]
"""Table used to save values."""

# ATTENTION:
# Prior to modifying this table, please determine whether
# we should create migrations for this table to make sure
# users do not experience data loss.
__tablename__ = "langchain_key_value_stores"

namespace: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
key: Mapped[str] = mapped_column(primary_key=True, index=True, nullable=False)
value = mapped_column(LargeBinary, index=False, nullable=False)


# This is a fix of original SQLStore.
# This can will be removed when a PR will be merged.
class SQLStore(BaseStore[str, bytes]):
"""BaseStore interface that works on an SQL database.
Examples:
Create a SQLStore instance and perform operations on it:
.. code-block:: python
from langchain_rag.storage import SQLStore
# Instantiate the SQLStore with the root path
sql_store = SQLStore(namespace="test", db_url="sqllite://:memory:")
# Set values for keys
sql_store.mset([("key1", b"value1"), ("key2", b"value2")])
# Get values for keys
values = sql_store.mget(["key1", "key2"]) # Returns [b"value1", b"value2"]
# Delete keys
sql_store.mdelete(["key1"])
# Iterate over keys
for key in sql_store.yield_keys():
print(key)
"""

def __init__(
self,
*,
namespace: str,
db_url: Optional[Union[str, Path]] = None,
engine: Optional[Union[Engine, AsyncEngine]] = None,
engine_kwargs: Optional[Dict[str, Any]] = None,
async_mode: Optional[bool] = None,
):
if db_url is None and engine is None:
raise ValueError("Must specify either db_url or engine")

if db_url is not None and engine is not None:
raise ValueError("Must specify either db_url or engine, not both")

_engine: Union[Engine, AsyncEngine]
if db_url:
if async_mode is None:
async_mode = False
if async_mode:
_engine = create_async_engine(
url=str(db_url),
**(engine_kwargs or {}),
)
else:
_engine = create_engine(url=str(db_url), **(engine_kwargs or {}))
elif engine:
_engine = engine

else:
raise AssertionError("Something went wrong with configuration of engine.")

_session_maker: Union[sessionmaker[Session], async_sessionmaker[AsyncSession]]
if isinstance(_engine, AsyncEngine):
self.async_mode = True
_session_maker = async_sessionmaker(bind=_engine)
else:
self.async_mode = False
_session_maker = sessionmaker(bind=_engine)

self.engine = _engine
self.dialect = _engine.dialect.name
self.session_maker = _session_maker
self.namespace = namespace

def create_schema(self) -> None:
Base.metadata.create_all(self.engine)

async def acreate_schema(self) -> None:
assert isinstance(self.engine, AsyncEngine)
async with self.engine.begin() as session:
await session.run_sync(Base.metadata.create_all)

def drop(self) -> None:
Base.metadata.drop_all(bind=self.engine.connect())

async def amget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
assert isinstance(self.engine, AsyncEngine)
result: Dict[str, bytes] = {}
async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
for v in await session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]

def mget(self, keys: Sequence[str]) -> List[Optional[bytes]]:
result = {}

with self._make_sync_session() as session:
stmt = select(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
for v in session.scalars(stmt):
result[v.key] = v.value
return [result.get(key) for key in keys]

async def amset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
async with self._make_async_session() as session:
await self._amdelete([key for key, _ in key_value_pairs], session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
for k, v in key_value_pairs
]
)
await session.commit()

def mset(self, key_value_pairs: Sequence[Tuple[str, bytes]]) -> None:
values: Dict[str, bytes] = dict(key_value_pairs)
with self._make_sync_session() as session:
self._mdelete(list(values.keys()), session)
session.add_all(
[
LangchainKeyValueStores(namespace=self.namespace, key=k, value=v)
for k, v in values.items()
]
)
session.commit()

def _mdelete(self, keys: Sequence[str], session: Session) -> None:
stmt = delete(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
session.execute(stmt)

async def _amdelete(self, keys: Sequence[str], session: AsyncSession) -> None:
stmt = delete(LangchainKeyValueStores).filter(
and_(
LangchainKeyValueStores.key.in_(keys),
LangchainKeyValueStores.namespace == self.namespace,
)
)
await session.execute(stmt)

def mdelete(self, keys: Sequence[str]) -> None:
with self._make_sync_session() as session:
self._mdelete(keys, session)
session.commit()

async def amdelete(self, keys: Sequence[str]) -> None:
async with self._make_async_session() as session:
await self._amdelete(keys, session)
await session.commit()

def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
with self._make_sync_session() as session:
for v in session.query(LangchainKeyValueStores).filter( # type: ignore
LangchainKeyValueStores.namespace == self.namespace
):
if str(v.key).startswith(prefix or ""):
yield str(v.key)
session.close()

async def ayield_keys(self, *, prefix: Optional[str] = None) -> AsyncIterator[str]:
async with self._make_async_session() as session:
stmt = select(LangchainKeyValueStores).filter(
LangchainKeyValueStores.namespace == self.namespace
)
for v in await session.scalars(stmt):
if str(v.key).startswith(prefix or ""):
yield str(v.key)
await session.close()

@contextlib.contextmanager
def _make_sync_session(self) -> Generator[Session, None, None]:
"""Make an async session."""
if self.async_mode:
raise ValueError(
"Attempting to use a sync method in when async mode is turned on. "
"Please use the corresponding async method instead."
)
with cast(Session, self.session_maker()) as session:
yield cast(Session, session)

@contextlib.asynccontextmanager
async def _make_async_session(self) -> AsyncGenerator[AsyncSession, None]:
"""Make an async session."""
if not self.async_mode:
raise ValueError(
"Attempting to use an async method in when sync mode is turned on. "
"Please use the corresponding async method instead."
)
async with cast(AsyncSession, self.session_maker()) as session:
yield cast(AsyncSession, session)
Loading

0 comments on commit 9aabb44

Please sign in to comment.