Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/fix dev tests #966

Merged
merged 6 commits into from
Aug 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions py/core/base/abstractions/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,21 @@ def __init__(
self.status_code = status_code
super().__init__(self.message)

def to_dict(self):
    """Serialize this exception into a plain dictionary.

    Returns:
        dict: A mapping with the keys ``message``, ``status_code`` and
        ``detail`` (copied from the instance attributes of the same
        names) plus ``error_type``, which holds the name of the
        concrete exception class.
    """
    return dict(
        message=self.message,
        status_code=self.status_code,
        detail=self.detail,
        # Use the runtime class so subclasses report their own name.
        error_type=self.__class__.__name__,
    )


class R2RDocumentProcessingError(R2RException):
    """Error raised for a failure tied to one specific document.

    The instance keeps the offending ``document_id`` and forwards it to
    the base exception inside the detail mapping, together with a fixed
    status code of 400.
    """

    def __init__(self, error_message, document_id):
        """Record the document id and initialise the base exception.

        Args:
            error_message: Text passed to the base class as the message.
            document_id: Identifier of the document the error refers to.
        """
        self.document_id = document_id
        detail = {"document_id": document_id}
        super().__init__(error_message, 400, detail)

    def to_dict(self):
        """Return the base serialization extended with ``document_id``."""
        return {**super().to_dict(), "document_id": self.document_id}
4 changes: 3 additions & 1 deletion py/core/main/api/routes/base_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from fastapi import APIRouter, HTTPException
from fastapi.responses import StreamingResponse

from core.base import R2RException, manage_run
from core.base import R2RDocumentProcessingError, R2RException, manage_run
from core.base.logging.base import RunType

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -48,6 +48,8 @@ async def wrapper(*args, **kwargs):
},
)
except Exception as e:
print("cc")

await self.engine.logging_connection.log(
run_id=run_id,
key="error",
Expand Down
5 changes: 4 additions & 1 deletion py/core/main/services/ingestion_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,10 @@ async def _process_ingestion_results(
if document.id in successful_ids
],
"failed_documents": [
{"document_id": document_id, "result": results[document_id]}
{
"document_id": document_id,
"result": str(results[document_id]),
}
for document_id in failed_ids
],
"skipped_documents": skipped_ids,
Expand Down
29 changes: 21 additions & 8 deletions py/core/parsers/media/img_parser.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
import base64
import os
from io import BytesIO
from typing import AsyncGenerator

from PIL import Image

from core.base.abstractions.document import DataType
from core.base.parsers.base_parser import AsyncParser
from core.parsers.media.openai_helpers import process_frame_with_openai
Expand All @@ -13,7 +17,8 @@ def __init__(
self,
model: str = "gpt-4o",
max_tokens: int = 2_048,
api_base: str = "https://api.openai.com/v2/chat/completions",
api_base: str = "https://api.openai.com/v1/chat/completions",
max_image_size: int = 1 * 1024 * 1024, # 1MB limit
):
self.model = model
self.max_tokens = max_tokens
Expand All @@ -23,12 +28,26 @@ def __init__(
"Error, environment variable `OPENAI_API_KEY` is required to run `ImageParser`."
)
self.api_base = api_base
self.max_image_size = max_image_size

def _resize_image(self, image_data: bytes, compression_ratio) -> bytes:
    """Re-encode an image as JPEG at reduced quality to shrink its size.

    Despite the name, the pixel dimensions are left untouched: the
    image is only re-saved as JPEG with a quality setting derived from
    ``compression_ratio``. The size of the returned data is therefore
    not guaranteed to fall under any particular limit.

    Args:
        image_data: Encoded image bytes in any format Pillow can open.
        compression_ratio: Desired size divided by current size. It is
            scaled to a percentage and used as the JPEG quality.

    Returns:
        bytes: The JPEG-encoded image.
    """
    # Pillow's JPEG writer documents quality on a 0-95 scale; clamp so
    # a tiny ratio does not collapse to 0 and a ratio near 1 does not
    # exceed the recommended maximum.
    quality = max(1, min(95, int(100 * compression_ratio)))

    with Image.open(BytesIO(image_data)) as source:
        # JPEG cannot store alpha or palette data; saving e.g. an RGBA
        # or P-mode PNG directly raises OSError, so normalise to RGB
        # first (any transparency is discarded).
        if source.mode in ("RGB", "L", "CMYK"):
            img = source
        else:
            img = source.convert("RGB")

        buffer = BytesIO()
        img.save(buffer, format="JPEG", quality=quality)

    return buffer.getvalue()

async def ingest(self, data: DataType) -> AsyncGenerator[str, None]:
"""Ingest image data and yield a description."""
if isinstance(data, bytes):
import base64
# Resize the image if it's too large
if len(data) > self.max_image_size:
data = self._resize_image(
data, float(self.max_image_size) / len(data)
)

# Encode to base64
data = base64.b64encode(data).decode("utf-8")

yield process_frame_with_openai(
Expand All @@ -38,9 +57,3 @@ async def ingest(self, data: DataType) -> AsyncGenerator[str, None]:
self.max_tokens,
self.api_base,
)


class ImageParserLocal(AsyncParser[DataType]):

def __init__(self):
pass
4 changes: 3 additions & 1 deletion py/core/providers/auth/r2r_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,9 @@ def refresh_access_token(self, refresh_token: str) -> Dict[str, Token]:
new_access_token = self.create_access_token(
data={"sub": token_data.email}
)
new_refresh_token = self.create_refresh_token()
new_refresh_token = self.create_refresh_token(
data={"sub": token_data.email}
)
return {
"access_token": Token(token=new_access_token, token_type="access"),
"refresh_token": Token(
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class GroupMixin(DatabaseMixin):
def create_table(self) -> None:
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('groups')} (
group_id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
group_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
name TEXT NOT NULL,
description TEXT,
created_at TIMESTAMPTZ DEFAULT NOW(),
Expand Down
2 changes: 1 addition & 1 deletion py/core/providers/database/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class BlacklistedTokensMixin(DatabaseMixin):
def create_table(self):
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('blacklisted_tokens')} (
id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
token TEXT NOT NULL,
blacklisted_at TIMESTAMPTZ DEFAULT NOW()
);
Expand Down
14 changes: 1 addition & 13 deletions py/core/providers/database/user.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ class UserMixin(DatabaseMixin):
def create_table(self):
query = f"""
CREATE TABLE IF NOT EXISTS {self._get_table_name('users')} (
user_id UUID PRIMARY KEY DEFAULT public.uuid_generate_v4(),
user_id UUID PRIMARY KEY DEFAULT uuid_generate_v4(),
email TEXT UNIQUE NOT NULL,
hashed_password TEXT NOT NULL,
is_superuser BOOLEAN DEFAULT FALSE,
Expand Down Expand Up @@ -209,18 +209,6 @@ def delete_user(self, user_id: UUID) -> None:

user_groups = group_result[0]

# Remove user from all groups they belong to
if user_groups:
group_update_query = f"""
UPDATE {self._get_table_name('groups')}
SET user_ids = array_remove(user_ids, :user_id)
WHERE group_id = ANY(:group_ids)
"""
self.execute_query(
group_update_query,
{"user_id": user_id, "group_ids": user_groups},
)

# Remove user from documents
doc_update_query = f"""
UPDATE {self._get_table_name('document_info')}
Expand Down
68 changes: 36 additions & 32 deletions py/sdk/management.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ async def create_group(
@staticmethod
async def get_group(
client,
group_id: str,
group_id: Union[str, UUID],
) -> dict:
"""
Get a group by its ID.
Expand All @@ -287,12 +287,12 @@ async def get_group(
Returns:
dict: The group data.
"""
return await client._make_request("GET", f"get_group/{group_id}")
return await client._make_request("GET", f"get_group/{str(group_id)}")

@staticmethod
async def update_group(
client,
group_id: str,
group_id: Union[str, UUID],
name: Optional[str] = None,
description: Optional[str] = None,
) -> dict:
Expand All @@ -307,7 +307,7 @@ async def update_group(
Returns:
dict: The response from the server.
"""
data = {"group_id": group_id}
data = {"group_id": str(group_id)}
if name is not None:
data["name"] = name
if description is not None:
Expand All @@ -318,7 +318,7 @@ async def update_group(
@staticmethod
async def delete_group(
client,
group_id: str,
group_id: Union[str, UUID],
) -> dict:
"""
Delete a group by its ID.
Expand All @@ -329,7 +329,9 @@ async def delete_group(
Returns:
dict: The response from the server.
"""
return await client._make_request("DELETE", f"delete_group/{group_id}")
return await client._make_request(
"DELETE", f"delete_group/{str(group_id)}"
)

@staticmethod
async def delete_user(
Expand Down Expand Up @@ -385,8 +387,8 @@ async def list_groups(
@staticmethod
async def add_user_to_group(
client,
user_id: str,
group_id: str,
user_id: Union[str, UUID],
group_id: Union[str, UUID],
) -> dict:
"""
Add a user to a group.
Expand All @@ -399,8 +401,8 @@ async def add_user_to_group(
dict: The response from the server.
"""
data = {
"user_id": user_id,
"group_id": group_id,
"user_id": str(user_id),
"group_id": str(group_id),
}
return await client._make_request(
"POST", "add_user_to_group", json=data
Expand All @@ -409,8 +411,8 @@ async def add_user_to_group(
@staticmethod
async def remove_user_from_group(
client,
user_id: str,
group_id: str,
user_id: Union[str, UUID],
group_id: Union[str, UUID],
) -> dict:
"""
Remove a user from a group.
Expand All @@ -423,8 +425,8 @@ async def remove_user_from_group(
dict: The response from the server.
"""
data = {
"user_id": user_id,
"group_id": group_id,
"user_id": str(user_id),
"group_id": str(group_id),
}
return await client._make_request(
"POST", "remove_user_from_group", json=data
Expand All @@ -433,7 +435,7 @@ async def remove_user_from_group(
@staticmethod
async def get_users_in_group(
client,
group_id: str,
group_id: Union[str, UUID],
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> dict:
Expand All @@ -454,13 +456,13 @@ async def get_users_in_group(
if limit is not None:
params["limit"] = limit
return await client._make_request(
"GET", f"get_users_in_group/{group_id}", params=params
"GET", f"get_users_in_group/{str(group_id)}", params=params
)

@staticmethod
async def user_groups(
client,
user_id: str,
user_id: Union[str, UUID],
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> dict:
Expand All @@ -479,17 +481,19 @@ async def user_groups(
if limit is not None:
params["limit"] = limit
if params:
return await client._make_request("GET", f"user_groups/{user_id}")
return await client._make_request(
"GET", f"user_groups/{str(user_id)}"
)
else:
return await client._make_request(
"GET", f"user_groups/{user_id}", params=params
"GET", f"user_groups/{str(user_id)}", params=params
)

@staticmethod
async def assign_document_to_group(
client,
document_id: str,
group_id: str,
document_id: Union[str, UUID],
group_id: Union[str, UUID],
) -> dict:
"""
Assign a document to a group.
Expand All @@ -502,8 +506,8 @@ async def assign_document_to_group(
dict: The response from the server.
"""
data = {
"document_id": document_id,
"group_id": group_id,
"document_id": str(document_id),
"group_id": str(group_id),
}
return await client._make_request(
"POST", "assign_document_to_group", json=data
Expand All @@ -513,8 +517,8 @@ async def assign_document_to_group(
@staticmethod
async def remove_document_from_group(
client,
document_id: str,
group_id: str,
document_id: Union[str, UUID],
group_id: Union[str, UUID],
) -> dict:
"""
Remove a document from a group.
Expand All @@ -527,8 +531,8 @@ async def remove_document_from_group(
dict: The response from the server.
"""
data = {
"document_id": document_id,
"group_id": group_id,
"document_id": str(document_id),
"group_id": str(group_id),
}
return await client._make_request(
"POST", "remove_document_from_group", json=data
Expand All @@ -537,7 +541,7 @@ async def remove_document_from_group(
@staticmethod
async def document_groups(
client,
document_id: str,
document_id: Union[str, UUID],
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> dict:
Expand All @@ -557,17 +561,17 @@ async def document_groups(
params["limit"] = limit
if params:
return await client._make_request(
"GET", f"document_groups/{document_id}", params=params
"GET", f"document_groups/{str(document_id)}", params=params
)
else:
return await client._make_request(
"GET", f"document_groups/{document_id}"
"GET", f"document_groups/{str(document_id)}"
)

@staticmethod
async def documents_in_group(
client,
group_id: str,
group_id: Union[str, UUID],
offset: Optional[int] = None,
limit: Optional[int] = None,
) -> dict:
Expand All @@ -588,5 +592,5 @@ async def documents_in_group(
if limit is not None:
params["limit"] = limit
return await client._make_request(
"GET", f"group/{group_id}/documents", params=params
"GET", f"group/{str(group_id)}/documents", params=params
)
Loading
Loading