Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DRAFT just checking async supabase #21467 #28893

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 171 additions & 3 deletions libs/community/langchain_community/vectorstores/supabase.py
Original file line number Diff line number Diff line change
@@ -1,28 +1,28 @@
from __future__ import annotations

import uuid
import warnings
from itertools import repeat
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
List,
Optional,
Tuple,
Type,
Union,
)

import numpy as np
from langchain_core.documents import Document
from langchain_core.embeddings import Embeddings
from langchain_core.vectorstores import VectorStore
from langchain_core.vectorstores import VectorStore, VST

from langchain_community.vectorstores.utils import maximal_marginal_relevance

if TYPE_CHECKING:

Check failure on line 25 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (I001)

langchain_community/vectorstores/supabase.py:1:1: I001 Import block is un-sorted or un-formatted

Check failure on line 25 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (I001)

langchain_community/vectorstores/supabase.py:1:1: I001 Import block is un-sorted or un-formatted
import supabase


Expand Down Expand Up @@ -210,8 +210,9 @@
vector, k=k, filter=filter, **kwargs
)

@staticmethod
def match_args(
self, query: List[float], filter: Optional[Dict[str, Any]]
query: List[float], filter: Optional[Dict[str, Any]]
) -> Dict[str, Any]:
ret: Dict[str, Any] = dict(query_embedding=query)
if filter:
Expand All @@ -226,7 +227,7 @@
postgrest_filter: Optional[str] = None,
score_threshold: Optional[float] = None,
) -> List[Tuple[Document, float]]:
match_documents_params = self.match_args(query, filter)
match_documents_params = match_args(query, filter)

Check failure on line 230 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F821)

langchain_community/vectorstores/supabase.py:230:34: F821 Undefined name `match_args`

Check failure on line 230 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F821)

langchain_community/vectorstores/supabase.py:230:34: F821 Undefined name `match_args`
query_builder = self._client.rpc(self.query_name, match_documents_params)

if postgrest_filter:
Expand Down Expand Up @@ -356,6 +357,46 @@

return id_list

# TODO(review): extract the row-building/chunking code shared with the
# synchronous add-vectors helper into a common function.
@staticmethod
async def _aadd_vectors(
    client: supabase.client.AsyncClient,
    table_name: str,
    vectors: List[List[float]],
    documents: List[Document],
    ids: List[str],
    chunk_size: int,
    **kwargs: Any,
) -> List[str]:
    """Upsert embeddings into a Supabase table and return the stored ids.

    Args:
        client: Async Supabase client used for the upsert calls.
        table_name: Name of the target table.
        vectors: One embedding per document, index-aligned with ``documents``.
        documents: Source documents providing ``page_content`` and ``metadata``.
        ids: Row ids, index-aligned with ``documents``.
        chunk_size: Number of rows sent per upsert request.
        **kwargs: Extra columns merged into every row.

    Returns:
        The ids reported back by Supabase, as strings.

    Raises:
        Exception: If an upsert call returns no rows.
    """
    rows: List[Dict[str, Any]] = [
        {
            "id": ids[idx],
            "content": documents[idx].page_content,
            "embedding": embedding,
            "metadata": documents[idx].metadata,  # type: ignore
            **kwargs,
        }
        for idx, embedding in enumerate(vectors)
    ]

    id_list: List[str] = []
    for start in range(0, len(rows), chunk_size):
        chunk = rows[start : start + chunk_size]

        upsert_call = client.from_(table_name).upsert(chunk)
        result = await upsert_call.execute()  # type: ignore

        if len(result.data) == 0:
            raise Exception("Error inserting: No rows added")

        # VectorStore.add_vectors returns ids as strings; rows that came
        # back without an id are skipped.  (Previously this reassigned the
        # ``ids`` parameter, shadowing the input argument.)
        id_list.extend(
            str(row.get("id")) for row in result.data if row.get("id")
        )

    return id_list

def max_marginal_relevance_search_by_vector(
self,
embedding: List[float],
Expand Down Expand Up @@ -481,3 +522,130 @@
# TODO: Check if this can be done in bulk
for row in rows:
self._client.from_(self.table_name).delete().eq("id", row["id"]).execute()


class AsyncSupabaseVectorStore(VectorStore):
def __init__(
    self,
    client: supabase.client.AsyncClient,
    embedding: Embeddings,
    table_name: str,
    chunk_size: int = 500,
    query_name: Union[str, None] = None,
) -> None:
    """Initialize with an async Supabase client.

    Args:
        client: Pre-configured async Supabase client.
        embedding: Embedding model used for documents and queries.
        table_name: Target table; falls back to ``"documents"`` if falsy.
        chunk_size: Upsert batch size; falls back to ``500`` if falsy.
        query_name: RPC function used for similarity search; falls back to
            ``"match_documents"`` if not given.
    """
    try:
        import supabase  # noqa: F401
    except ImportError as e:
        # Chain the original error so the root cause stays visible.
        raise ImportError(
            "Could not import supabase python package. "
            "Please install it with `pip install supabase`."
        ) from e

    self._client = client
    self._embedding: Embeddings = embedding
    self.table_name = table_name or "documents"
    self.query_name = query_name or "match_documents"
    self.chunk_size = chunk_size or 500

@classmethod
def from_texts(
    cls: type[VST],
    texts: list[str],
    embedding: Embeddings,
    metadatas: Optional[list[dict]] = None,
    *,
    ids: Optional[list[str]] = None,
    **kwargs: Any,
) -> VST:
    """Not supported for the async store; construct it directly and use
    ``aadd_documents`` instead.

    Raises:
        NotImplementedError: Always.
    """
    # ``raise NotImplemented`` is a TypeError at runtime (flagged by Ruff
    # F901); the correct exception type is NotImplementedError.
    raise NotImplementedError(
        "AsyncSupabaseVectorStore does not support the synchronous "
        "`from_texts` constructor."
    )

Check failure on line 554 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (F901)

langchain_community/vectorstores/supabase.py:554:15: F901 `raise NotImplemented` should be `raise NotImplementedError`

Check failure on line 554 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (F901)

langchain_community/vectorstores/supabase.py:554:15: F901 `raise NotImplemented` should be `raise NotImplementedError`


async def aadd_documents(
    self, documents: list[Document], **kwargs: Any
) -> list[str]:
    """Embed *documents* and upsert them into the Supabase table.

    Args:
        documents: Documents to embed and store.
        **kwargs: If ``ids`` is given it takes precedence over the
            documents' own ``id`` attributes.
            TODO(review): confirm this precedence is intended.

    Returns:
        The ids of the stored rows, as strings.
    """
    if "ids" in kwargs:
        ids = kwargs["ids"]
    else:
        # Prefer the document's own id; fall back to a random UUID.
        ids = [doc.id or str(uuid.uuid4()) for doc in documents]

    texts = [doc.page_content for doc in documents]
    vectors = await self._embedding.aembed_documents(texts)
    return await SupabaseVectorStore._aadd_vectors(
        self._client, self.table_name, vectors, documents, ids, self.chunk_size
    )

async def asimilarity_search(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Embed *query* and return the ``k`` most similar documents.

    Args:
        query: Query text to embed.
        k: Number of documents to return.
        filter: Optional metadata filter forwarded to the match RPC.
        **kwargs: Forwarded to ``asimilarity_search_by_vector``.
    """
    vector = await self._embedding.aembed_query(query)
    # Bug fix: the original awaited the *synchronous*
    # ``similarity_search_by_vector`` (awaiting a plain list -> TypeError);
    # this class defines the async variant, which must be used here.
    return await self.asimilarity_search_by_vector(
        vector, k=k, filter=filter, **kwargs
    )

Check failure on line 575 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.13

Ruff (E501)

langchain_community/vectorstores/supabase.py:575:89: E501 Line too long (91 > 88)

Check failure on line 575 in libs/community/langchain_community/vectorstores/supabase.py

View workflow job for this annotation

GitHub Actions / cd libs/community / make lint #3.9

Ruff (E501)

langchain_community/vectorstores/supabase.py:575:89: E501 Line too long (91 > 88)

async def asimilarity_search_by_vector(
    self,
    embedding: List[float],
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Document]:
    """Return the ``k`` documents most similar to *embedding*.

    Delegates to the scored variant and discards the relevance scores.
    """
    scored = await self.asimilarity_search_by_vector_with_relevance_scores(
        embedding, k=k, filter=filter, **kwargs
    )
    # Callers of this method only want the documents, not the scores.
    return [document for document, _score in scored]

async def asimilarity_search_with_relevance_scores(
    self,
    query: str,
    k: int = 4,
    filter: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
) -> List[Tuple[Document, float]]:
    """Embed *query* and return ``(document, score)`` pairs.

    Args:
        query: Query text to embed.
        k: Number of results to return.
        filter: Optional metadata filter forwarded to the match RPC.
        **kwargs: Forwarded to the scored by-vector search.
    """
    # Bug fix: the original awaited the synchronous ``embed_query`` (which
    # returns a list, not an awaitable); the async embedding API is
    # ``aembed_query``, as used by ``asimilarity_search`` above.
    vector = await self._embedding.aembed_query(query)
    return await self.asimilarity_search_by_vector_with_relevance_scores(
        vector, k=k, filter=filter, **kwargs
    )


async def asimilarity_search_by_vector_with_relevance_scores(
    self,
    query: List[float],
    k: int,
    filter: Optional[Dict[str, Any]] = None,
    postgrest_filter: Optional[str] = None,
    score_threshold: Optional[float] = None,
) -> List[Tuple[Document, float]]:
    """Run the match RPC against *query* and return scored documents.

    Args:
        query: Query embedding.
        k: Maximum number of rows to request.
        filter: Optional metadata filter passed to the RPC params.
        postgrest_filter: Raw PostgREST filter ANDed onto the request.
        score_threshold: If set, drop results below this similarity and
            warn when nothing survives.
    """
    match_documents_params = SupabaseVectorStore.match_args(query, filter)
    # Bug fix: in the supabase async client, ``rpc`` builds the request
    # synchronously and ``execute`` is the coroutine that performs the
    # network call — the original awaited ``rpc`` and not ``execute``.
    query_builder = self._client.rpc(self.query_name, match_documents_params)

    if postgrest_filter:
        query_builder.params = query_builder.params.set(
            "and", f"({postgrest_filter})"
        )

    query_builder.params = query_builder.params.set("limit", k)

    res = await query_builder.execute()

    match_result = [
        (
            Document(
                metadata=search.get("metadata", {}),  # type: ignore
                page_content=search.get("content", ""),
            ),
            search.get("similarity", 0.0),
        )
        for search in res.data
        if search.get("content")
    ]

    if score_threshold is not None:
        match_result = [
            (doc, similarity)
            for doc, similarity in match_result
            if similarity >= score_threshold
        ]
        if len(match_result) == 0:
            warnings.warn(
                "No relevant docs were retrieved using the relevance score"
                f" threshold {score_threshold}"
            )

    return match_result


65 changes: 65 additions & 0 deletions libs/community/tests/unit_tests/vectorstores/test_asyncsupabase.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from typing import Optional, Dict, Any, List
from unittest.mock import patch, Mock, AsyncMock

import numpy
import pytest

from langchain_community.embeddings import FakeEmbeddings
from langchain_community.vectorstores.supabase import AsyncSupabaseVectorStore


async def create_vector_store() -> AsyncSupabaseVectorStore:
    """Build an ``AsyncSupabaseVectorStore`` backed by a real async client.

    Reads connection details from the ``my_supabase_url`` and
    ``my_supabase_key`` environment variables; uses the store's default
    ``query_name`` ("match_documents").
    """
    import os

    from supabase.client import acreate_client

    client = await acreate_client(
        os.environ["my_supabase_url"], os.environ["my_supabase_key"]
    )
    return AsyncSupabaseVectorStore(
        client=client,
        embedding=FakeEmbeddings(size=3),
        table_name="documents",
    )

@pytest.mark.requires("supabase")
async def test_ids_used_correctly() -> None:
    """Check that the vector store uses the provided document ids on upsert."""
    # NOTE(review): this is an ``async def`` test with no asyncio marker —
    # confirm the test suite runs pytest-asyncio/anyio in auto mode.
    from langchain_core.documents import Document

    documents = [
        Document(
            id="id1",
            page_content="page zero Lorem Ipsum",
            metadata={"source": "document.pdf", "page": 0, "id": "ID-document-1"},
        ),
        Document(
            id="id2",
            page_content="page one Lorem Ipsum Dolor sit ameit",
            metadata={"source": "document.pdf", "page": 1, "id": "ID-document-2"},
        ),
    ]
    ids_provided = [doc.id for doc in documents]

    # Stub the supabase call chain: from_() -> upsert() -> await execute().
    table_mock = Mock(name="from_()")
    mock_upsert = AsyncMock()
    table_mock.upsert.return_value = mock_upsert
    mock_result = Mock()
    mock_upsert.execute.return_value = mock_result
    mock_result.data = [{"id": "id1"}, {"id": "id2"}]

    import supabase

    with patch.object(
        supabase._async.client.AsyncClient, "from_", return_value=table_mock
    ) as from_mock:
        vector_store = await create_vector_store()
        ids_used_at_upload = await vector_store.aadd_documents(
            documents, ids=ids_provided
        )
        assert len(ids_provided) == len(ids_used_at_upload)
        assert ids_provided == ids_used_at_upload
        from_mock.assert_called_once_with("documents")
        table_mock.upsert.assert_called_once()

        submitted = table_mock.upsert.call_args.args[0]
        assert len(submitted) == 2
        assert [row["id"] for row in submitted] == [doc.id for doc in documents]
        assert [row["content"] for row in submitted] == [
            doc.page_content for doc in documents
        ]
        for row in submitted:
            assert len(row["embedding"]) == 3
            # isinstance instead of exact-type comparison; FakeEmbeddings
            # produces numpy.float64 values here.
            assert all(isinstance(v, numpy.float64) for v in row["embedding"])
        assert [row["metadata"] for row in submitted] == [
            doc.metadata for doc in documents
        ]


Loading