diff --git a/libs/community/langchain_community/vectorstores/supabase.py b/libs/community/langchain_community/vectorstores/supabase.py index b12f6c22e3c6b..1211c5c9f86b1 100644 --- a/libs/community/langchain_community/vectorstores/supabase.py +++ b/libs/community/langchain_community/vectorstores/supabase.py @@ -18,7 +18,7 @@ import numpy as np from langchain_core.documents import Document from langchain_core.embeddings import Embeddings -from langchain_core.vectorstores import VectorStore +from langchain_core.vectorstores import VectorStore, VST from langchain_community.vectorstores.utils import maximal_marginal_relevance @@ -210,8 +210,9 @@ def similarity_search_with_relevance_scores( vector, k=k, filter=filter, **kwargs ) + @staticmethod def match_args( - self, query: List[float], filter: Optional[Dict[str, Any]] + query: List[float], filter: Optional[Dict[str, Any]] ) -> Dict[str, Any]: ret: Dict[str, Any] = dict(query_embedding=query) if filter: @@ -226,7 +227,7 @@ def similarity_search_by_vector_with_relevance_scores( postgrest_filter: Optional[str] = None, score_threshold: Optional[float] = None, ) -> List[Tuple[Document, float]]: - match_documents_params = self.match_args(query, filter) + match_documents_params = match_args(query, filter) query_builder = self._client.rpc(self.query_name, match_documents_params) if postgrest_filter: @@ -356,6 +357,46 @@ def _add_vectors( return id_list + # TODO extract common code + @staticmethod + async def _aadd_vectors( + client: supabase.client.AsyncClient, + table_name: str, + vectors: List[List[float]], + documents: List[Document], + ids: List[str], + chunk_size: int, + **kwargs: Any, + ) -> List[str]: + """Add vectors to Supabase table.""" + + rows: List[Dict[str, Any]] = [ + { + "id": ids[idx], + "content": documents[idx].page_content, + "embedding": embedding, + "metadata": documents[idx].metadata, # type: ignore + **kwargs, + } + for idx, embedding in enumerate(vectors) + ] + id_list: List[str] = [] + for i in range(0, len(rows), chunk_size): + chunk = rows[i: i + chunk_size] + + upsert_call = client.from_(table_name).upsert(chunk) + result = await upsert_call.execute() # type: ignore + + if len(result.data) == 0: + raise Exception("Error inserting: No rows added") + + # VectorStore.add_vectors returns ids as strings + ids = [str(i.get("id")) for i in result.data if i.get("id")] + + id_list.extend(ids) + + return id_list + def max_marginal_relevance_search_by_vector( self, embedding: List[float], @@ -481,3 +522,130 @@ def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> None: # TODO: Check if this can be done in bulk for row in rows: self._client.from_(self.table_name).delete().eq("id", row["id"]).execute() + + +class AsyncSupabaseVectorStore(VectorStore): + def __init__( + self, + client: supabase.client.AsyncClient, + embedding: Embeddings, + table_name: str, + chunk_size: int = 500, + query_name: Union[str, None] = None, + ) -> None: + """Initialize with supabase client.""" + try: + import supabase # noqa: F401 + except ImportError: + raise ImportError( + "Could not import supabase python package. " + "Please install it with `pip install supabase`." + ) + + self._client = client + self._embedding: Embeddings = embedding + self.table_name = table_name or "documents" + self.query_name = query_name or "match_documents" + self.chunk_size = chunk_size or 500 + + @classmethod + def from_texts(cls: type[VST], texts: list[str], embedding: Embeddings, metadatas: Optional[list[dict]] = None, *, + ids: Optional[list[str]] = None, **kwargs: Any) -> VST: + raise NotImplemented() + + + #TODO figure out, assert kwargs["ids"] precedence + async def aadd_documents(self, documents: list[Document], **kwargs: Any) -> list[str]: + if "ids" not in kwargs: + ids = [doc.id or str(uuid.uuid4()) for doc in documents] + else: + ids = kwargs["ids"] + vectors = await self._embedding.aembed_documents([doc.page_content for doc in documents]) + return await SupabaseVectorStore._aadd_vectors( + self._client, self.table_name, vectors, documents, ids, self.chunk_size) + + async def asimilarity_search( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + vector = await self._embedding.aembed_query(query) + return await self.similarity_search_by_vector(vector, k=k, filter=filter, **kwargs) + + async def asimilarity_search_by_vector( + self, + embedding: List[float], + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Document]: + result = await self.asimilarity_search_by_vector_with_relevance_scores( + embedding, k=k, filter=filter, **kwargs + ) + + documents = [doc for doc, _ in result] + + return documents + + async def asimilarity_search_with_relevance_scores( + self, + query: str, + k: int = 4, + filter: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> List[Tuple[Document, float]]: + vector = await self._embedding.embed_query(query) + return await self.asimilarity_search_by_vector_with_relevance_scores( + vector, k=k, filter=filter, **kwargs + ) + + + async def asimilarity_search_by_vector_with_relevance_scores( + self, + query: List[float], + k: int, + filter: Optional[Dict[str, Any]] = None, + postgrest_filter: Optional[str] = None, + score_threshold: Optional[float] = None, + ) -> List[Tuple[Document, float]]: + match_documents_params = SupabaseVectorStore.match_args(query, filter) + query_builder = await self._client.rpc(self.query_name, match_documents_params) + + if postgrest_filter: + query_builder.params = query_builder.params.set( + "and", f"({postgrest_filter})" + ) + + query_builder.params = query_builder.params.set("limit", k) + + res = query_builder.execute() + + match_result = [ + ( + Document( + metadata=search.get("metadata", {}), # type: ignore + page_content=search.get("content", ""), + ), + search.get("similarity", 0.0), + ) + for search in res.data + if search.get("content") + ] + + if score_threshold is not None: + match_result = [ + (doc, similarity) + for doc, similarity in match_result + if similarity >= score_threshold + ] + if len(match_result) == 0: + warnings.warn( + "No relevant docs were retrieved using the relevance score" + f" threshold {score_threshold}" + ) + + return match_result + + diff --git a/libs/community/tests/unit_tests/vectorstores/test_asyncsupabase.py b/libs/community/tests/unit_tests/vectorstores/test_asyncsupabase.py new file mode 100644 index 0000000000000..1195a416c5c58 --- /dev/null +++ b/libs/community/tests/unit_tests/vectorstores/test_asyncsupabase.py @@ -0,0 +1,65 @@ +from typing import Optional, Dict, Any, List +from unittest.mock import patch, Mock, AsyncMock + +import numpy +import pytest + +from langchain_community.embeddings import FakeEmbeddings +from langchain_community.vectorstores.supabase import AsyncSupabaseVectorStore + + +async def create_vector_store() -> AsyncSupabaseVectorStore: + from supabase.client import acreate_client + import os + return AsyncSupabaseVectorStore( + client=await acreate_client(os.environ["my_supabase_url"], os.environ["my_supabase_key"]), + embedding=FakeEmbeddings(size=3), + table_name="documents", + #query_name="match_documents", + ) + +@pytest.mark.requires("supabase") +async def test_ids_used_correctly() -> None: + """Check whether vector store uses the document ids when provided with them.""" + from langchain_core.documents import Document + + documents = [ + Document(id="id1", + page_content="page zero Lorem Ipsum", + metadata={"source": "document.pdf", "page": 0, "id": "ID-document-1"}, + ), + Document( + id="id2", + page_content="page one Lorem Ipsum Dolor sit ameit", + metadata={"source": "document.pdf", "page": 1, "id": "ID-document-2"}, + ), + ] + ids_provided = [i.id for i in documents] + table_mock = Mock(name="from_()") + mock_upsert = AsyncMock() + table_mock.upsert.return_value = mock_upsert + mock_result = Mock() + mock_upsert.execute.return_value = mock_result + mock_result.data=[{"id":"id1"}, {"id": "id2"}] + + import supabase + with patch.object( + supabase._async.client.AsyncClient, "from_",return_value=table_mock) as from_mock: + # ), patch.object(supabase._async.client.AsyncClient, "get_index", mock_default_index): + #: + vector_store = await create_vector_store() + ids_used_at_upload = await vector_store.aadd_documents(documents, ids=ids_provided) + assert len(ids_provided) == len(ids_used_at_upload) + assert ids_provided == ids_used_at_upload + from_mock.assert_called_once_with("documents") + table_mock.upsert.assert_called_once() + list_submitted = table_mock.upsert.call_args.args[0] + assert len(list_submitted) == 2 + assert [d["id"] for d in list_submitted] == [d.id for d in documents] + assert [d["content"] for d in list_submitted] == [d.page_content for d in documents] + for obj in list_submitted: + assert len(obj["embedding"])==3 + assert all(type(v)==type(numpy.float64(0.2)) for v in obj["embedding"]) + assert [d["metadata"] for d in list_submitted] == [d.metadata for d in documents] + +