Skip to content

Commit

Permalink
Feature/merge dev owen changes (#880)
Browse files Browse the repository at this point in the history
* add group ids to document abstraction, first steps

* extend group permissions

* up

* add tests for new group features

* up

* fixup auth

* onboard extensive regression tests

* adding regression tests

* finish tests

* rm selenium

* test observability

* uncomment tests

* checkin first set of group tests

* modify search, passing vector tests

* checkin work

* full delete logic

* update search to use new filters

* check in

* Clean up

* Check in

* add search

* tests/test_end_to_end.py::test_ingest_txt_document passing

* cleanup logging

* make schemas explicit

* move to run logger abstraction

* cleanup some test workflows

* revive tests

* tweak to pass tests

* tweak rrf

* finish hybrid search cleanup

* fixup on regr tests, regen payloads

* refresh payloads

* refactor api model

* Feature/refactor api model (#868)

* cleanup imports

* flake and cleanup

* coherent global import / export structure

* add ingestion response models

* add management response models

* cleanups

* checkin work on routes

* remove request models

* last fixes

* merge

* add user / group gating

* working test groups

* updating client

* rename service to restructure

* add get documents for group endpoint

* fix client bugs

* return delete format

* merge cleanups

* merge

* finalize

---------

Co-authored-by: NolanTrem <34580718+NolanTrem@users.noreply.github.com>
  • Loading branch information
emrgnt-cmplxty and NolanTrem authored Aug 17, 2024
1 parent de8e1aa commit 91f35c9
Show file tree
Hide file tree
Showing 35 changed files with 555 additions and 511 deletions.
7 changes: 7 additions & 0 deletions r2r/base/abstractions/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,13 @@ class TokenData(BaseModel):

class UserStats(BaseModel):
user_id: UUID
email: str
is_superuser: bool
is_active: bool
is_verified: bool
created_at: datetime
updated_at: datetime
group_ids: list[UUID]
num_files: int
total_size_in_bytes: int
document_ids: list[UUID]
118 changes: 0 additions & 118 deletions r2r/base/providers/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,130 +34,12 @@ class VectorDBProvider(Provider, ABC):
def _initialize_vector_db(self, dimension: int) -> None:
pass

@abstractmethod
def create_index(self, index_type, column_name, index_options):
pass

@abstractmethod
def upsert(self, entry: VectorEntry, commit: bool = True) -> None:
pass

@abstractmethod
def search(
self,
query_vector: list[float],
filters: dict[str, VectorDBFilterValue] = {},
limit: int = 10,
*args,
**kwargs,
) -> list[VectorSearchResult]:
pass

@abstractmethod
def hybrid_search(
self,
query_text: str,
query_vector: list[float],
limit: int = 10,
filters: Optional[dict[str, VectorDBFilterValue]] = None,
full_text_weight: float = 1.0,
semantic_weight: float = 1.0,
rrf_k: int = 20,
*args,
**kwargs,
) -> list[VectorSearchResult]:
pass

@abstractmethod
def delete(self, filters: dict[str, VectorDBFilterValue]) -> list[str]:
pass


class RelationalDBProvider(Provider, ABC):
@abstractmethod
def _initialize_relational_db(self) -> None:
pass

@abstractmethod
def upsert_documents_overview(
self, document_infs: list[DocumentInfo]
) -> None:
pass

@abstractmethod
def get_documents_overview(
self,
filter_user_ids: Optional[str] = None,
filter_group_ids: Optional[list[str]] = None,
filter_document_ids: Optional[list[str]] = None,
) -> list[DocumentInfo]:
pass

@abstractmethod
def delete_from_documents_overview(
self, document_id: str, version: Optional[str] = None
) -> dict:
pass

@abstractmethod
def get_users_overview(self, user_ids: Optional[list[str]] = None) -> dict:
pass

@abstractmethod
def create_user(self, email: str, password: str) -> UserResponse:
pass

@abstractmethod
def get_user_by_email(self, email: str) -> Optional[UserResponse]:
pass

@abstractmethod
def store_verification_code(
self, user_id: UUID, verification_code: str, expiry: datetime
):
pass

@abstractmethod
def get_user_id_by_verification_code(
self, verification_code: str
) -> Optional[UUID]:
pass

@abstractmethod
def mark_user_as_verified(self, user_id: UUID):
pass

@abstractmethod
def mark_user_as_superuser(self, user_id: UUID):
pass

@abstractmethod
def remove_verification_code(self, verification_code: str):
pass

@abstractmethod
def get_user_by_id(self, user_id: UUID) -> Optional[UserResponse]:
pass

@abstractmethod
def update_user(
self,
user_id: UUID,
email: Optional[str],
name: Optional[str],
bio: Optional[str],
profile_picture: Optional[str],
) -> UserResponse:
pass

@abstractmethod
def delete_user(self, user_id: UUID):
pass

@abstractmethod
def get_all_users(self) -> list[UserResponse]:
pass


class DatabaseProvider(Provider):

Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
2 changes: 2 additions & 0 deletions r2r/main/api/client/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
from .client import R2RClient

__all__ = ["R2RClient"]
8 changes: 2 additions & 6 deletions r2r/main/api/client/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ async def login(client, email: str, password: str) -> dict[str, Token]:
response = await client._make_request("POST", "login", data=data)
client.access_token = response["results"]["access_token"]["token"]
client._refresh_token = response["results"]["refresh_token"]["token"]
return response["results"]
return response

@staticmethod
async def user(client) -> UserResponse:
Expand All @@ -39,7 +39,7 @@ async def refresh_access_token(client) -> dict[str, Token]:
)
client.access_token = response["results"]["access_token"]["token"]
client._refresh_token = response["results"]["refresh_token"]["token"]
return response["results"]
return response

@staticmethod
async def change_password(
Expand Down Expand Up @@ -71,10 +71,6 @@ async def logout(client) -> dict:
client._refresh_token = None
return response

@staticmethod
async def get_user_profile(client, user_id: uuid.UUID) -> UserResponse:
return await client._make_request("GET", f"user/{user_id}")

@staticmethod
async def update_user(
client,
Expand Down
9 changes: 5 additions & 4 deletions r2r/main/api/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ async def handle_request_error_async(response):
if response.headers.get("content-type") == "application/json":
error_content = await response.json()
else:
error_content = await response.text()
error_content = await response.text

if isinstance(error_content, dict) and "detail" in error_content:
detail = error_content["detail"]
Expand Down Expand Up @@ -123,19 +123,20 @@ async def _make_request(self, method, endpoint, **kwargs):

if isinstance(self.client, TestClient):
response = getattr(self.client, method.lower())(
url, headers=headers, **kwargs
url, headers=headers, params=params, **kwargs
)
return response.json() if response.content else None
else:
try:
response = await self.client.request(
method, url, headers=headers, params=params, **kwargs
)
await handle_request_error_async(response)
return response.json()
return response.json() if response.content else None
except httpx.RequestError as e:
raise R2RException(
status_code=500, message=f"Request failed: {str(e)}"
) from e
)

def _get_auth_header(self) -> dict:
if not self.access_token:
Expand Down
68 changes: 1 addition & 67 deletions r2r/main/api/client/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,40 +4,10 @@
from contextlib import ExitStack
from typing import List, Optional, Union

from r2r.base import ChunkingConfig, Document, DocumentType
from r2r.base import ChunkingConfig


class IngestionMethods:
@staticmethod
async def ingest_documents(
client,
documents: List[Document],
versions: Optional[List[str]] = None,
chunking_config_override: Optional[Union[dict, ChunkingConfig]] = None,
) -> dict:
"""
Ingest a list of documents into the system.
Args:
documents (List[Document]): List of Document objects to ingest.
versions (Optional[List[str]]): List of version strings for each document.
chunking_config_override (Optional[Union[dict, ChunkingConfig]]): Custom chunking configuration.
Returns:
dict: Ingestion results containing processed, failed, and skipped documents.
"""
data = {
"documents": [doc.dict() for doc in documents],
"versions": versions,
"chunking_config_override": (
chunking_config_override.dict()
if isinstance(chunking_config_override, ChunkingConfig)
else chunking_config_override
),
}
return await client._make_request(
"POST", "ingest_documents", json=data
)

@staticmethod
async def ingest_files(
Expand Down Expand Up @@ -162,39 +132,3 @@ async def update_files(
return await client._make_request(
"POST", "update_files", data=data, files=files
)

@staticmethod
async def list_documents(
client,
user_id: Optional[str] = None,
group_ids: Optional[List[str]] = None,
document_type: Optional[DocumentType] = None,
status: Optional[str] = None,
page: int = 1,
page_size: int = 50,
) -> dict:
"""
List documents based on various filters.
Args:
user_id (Optional[str]): Filter by user ID.
group_ids (Optional[List[str]]): Filter by group IDs.
document_type (Optional[DocumentType]): Filter by document type.
status (Optional[str]): Filter by document status.
page (int): Page number for pagination.
page_size (int): Number of items per page.
Returns:
dict: List of documents matching the specified filters.
"""
params = {
"user_id": user_id,
"group_ids": json.dumps(group_ids) if group_ids else None,
"document_type": document_type.value if document_type else None,
"status": status,
"page": page,
"page_size": page_size,
}
return await client._make_request(
"GET", "list_documents", params=params
)
58 changes: 49 additions & 9 deletions r2r/main/api/client/management.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json
import uuid
from typing import Optional
from typing import Any, Optional, Union

from r2r.base import VectorDBFilterValue

Expand All @@ -14,11 +15,8 @@ async def update_prompt(
client,
name: str,
template: Optional[str] = None,
input_types: Optional[dict[str, str]] = None,
input_types: Optional[dict[str, str]] = {},
) -> dict:
if input_types is None:
input_types = {}

data = {
"name": name,
"template": template,
Expand Down Expand Up @@ -73,9 +71,11 @@ async def delete(
client,
filters: dict[str, VectorDBFilterValue],
) -> dict:
filters_json = json.dumps(filters)

return await client._make_request(
"DELETE", "delete", json={"filters": filters}
)
"DELETE", "delete", params={"filters": filters_json}
) or {"results": {}}

@staticmethod
async def documents_overview(
Expand Down Expand Up @@ -103,7 +103,7 @@ async def document_chunks(
document_id: uuid.UUID,
) -> dict:
return await client._make_request(
"GET", f"document_chunks/{document_id}"
"GET", "document_chunks", params={"document_id": document_id}
)

@staticmethod
Expand Down Expand Up @@ -261,10 +261,50 @@ async def get_groups_for_user(
async def groups_overview(
client,
group_ids: Optional[list[uuid.UUID]] = None,
limit: int = 100,
offset: int = 0,
) -> dict:
params = {
"group_ids": [str(gid) for gid in group_ids] if group_ids else None
"limit": limit,
"offset": offset,
}
if group_ids is not None:
params["group_ids"] = [str(gid) for gid in group_ids]
return await client._make_request(
"GET", "groups_overview", params=params
)

@staticmethod
async def get_documents_in_group(
client,
group_id: uuid.UUID,
offset: int = 0,
limit: int = 100,
) -> dict:
params = {
"offset": offset,
"limit": limit,
}
return await client._make_request(
"GET", f"group/{group_id}/documents", params=params
)

@staticmethod
async def analytics(
client,
filter_criteria: Optional[Union[dict, str]] = None,
analysis_types: Optional[Union[dict, str]] = None,
) -> dict:
params = {}
if filter_criteria:
if isinstance(filter_criteria, dict):
params["filter_criteria"] = json.dumps(filter_criteria)
else:
params["filter_criteria"] = filter_criteria
if analysis_types:
if isinstance(analysis_types, dict):
params["analysis_types"] = json.dumps(analysis_types)
else:
params["analysis_types"] = analysis_types

return await client._make_request("GET", "analytics", params=params)
Loading

0 comments on commit 91f35c9

Please sign in to comment.