Skip to content

Commit

Permalink
Further clean up
Browse files Browse the repository at this point in the history
  • Loading branch information
NolanTrem committed Jan 1, 2025
1 parent 89244f5 commit f14a65d
Show file tree
Hide file tree
Showing 27 changed files with 145 additions and 117 deletions.
5 changes: 3 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,12 @@ repos:
- repo: local
hooks:
- id: check-typing-imports
name: Check for Dict/List usage
entry: bash -c 'echo "Checking for typing imports..." && find . -name "*.py" | xargs grep -n "from typing.*import.*[^d]Dict\\|from typing.*import.*List" || exit 0 && echo "⚠️ Please import dict/list instead of Dict/List from typing" && exit 1'
name: Check for Dict, List, or Union usage
entry: bash -c 'echo "Checking for typing imports..." && find . -name "*.py" | grep -v "/migrations/" | xargs grep -n "from typing.*import.*[^d]Dict\\|from typing.*import.*List\\|from typing.*import.*Union" || exit 0 && echo "⚠️ Please import dict instead of Dict, list instead of List, and the logical OR operator" && exit 1'
language: system
types: [python]
pass_filenames: false
exclude: ^py/migrations/

- repo: local
hooks:
Expand Down
10 changes: 2 additions & 8 deletions py/core/agent/rag.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Union

from core.agent import R2RAgent, R2RStreamingAgent
from core.base import (
format_search_results_for_llm,
Expand Down Expand Up @@ -126,9 +124,7 @@ class R2RRAGAgent(RAGAgentMixin, R2RAgent):
def __init__(
self,
database_provider: DatabaseProvider,
llm_provider: Union[
LiteLLMCompletionProvider, OpenAICompletionProvider
],
llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
search_pipeline: SearchPipeline,
config: AgentConfig,
):
Expand All @@ -144,9 +140,7 @@ class R2RStreamingRAGAgent(RAGAgentMixin, R2RStreamingAgent):
def __init__(
self,
database_provider: DatabaseProvider,
llm_provider: Union[
LiteLLMCompletionProvider, OpenAICompletionProvider
],
llm_provider: LiteLLMCompletionProvider | OpenAICompletionProvider,
search_pipeline: SearchPipeline,
config: AgentConfig,
):
Expand Down
3 changes: 0 additions & 3 deletions py/core/base/logger/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
from .base import RunInfoLog
from .run_manager import RunManager, manage_run

__all__ = [
# Basic types
"RunInfoLog",
# Run Manager
"RunManager",
"manage_run",
Expand Down
20 changes: 0 additions & 20 deletions py/core/base/logger/base.py

This file was deleted.

3 changes: 2 additions & 1 deletion py/core/base/providers/ingestion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from pydantic import Field

from core.base import ChunkingStrategy
from core.base.abstractions import ChunkEnrichmentSettings

from .base import AppConfig, Provider, ProviderConfig
Expand Down Expand Up @@ -44,7 +45,7 @@ class IngestionConfig(ProviderConfig):
excluded_parsers: list[str] = Field(
default_factory=lambda: IngestionConfig._defaults["excluded_parsers"]
)
chunking_strategy: str = Field(
chunking_strategy: str | ChunkingStrategy = Field(
default_factory=lambda: IngestionConfig._defaults["chunking_strategy"]
)
chunk_enrichment_settings: ChunkEnrichmentSettings = Field(
Expand Down
12 changes: 4 additions & 8 deletions py/core/database/graphs.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,9 +54,7 @@ def _get_table_name(self, table: str) -> str:

def _get_entity_table_for_store(self, store_type: StoreType) -> str:
"""Get the appropriate table name for the store type."""
if isinstance(store_type, StoreType):
store_type = store_type.value
return f"{store_type}_entities"
return f"{store_type.value}_entities"

def _get_parent_constraint(self, store_type: StoreType) -> str:
"""Get the appropriate foreign key constraint for the store type."""
Expand Down Expand Up @@ -494,9 +492,7 @@ def _get_table_name(self, table: str) -> str:

def _get_relationship_table_for_store(self, store_type: StoreType) -> str:
"""Get the appropriate table name for the store type."""
if isinstance(store_type, StoreType):
store_type = store_type.value
return f"{store_type}_relationships"
return f"{store_type.value}_relationships"

def _get_parent_constraint(self, store_type: StoreType) -> str:
"""Get the appropriate foreign key constraint for the store type."""
Expand Down Expand Up @@ -2468,13 +2464,13 @@ async def _get_relationship_ids_cache(
relationship_ids_cache.setdefault(relationship.subject, [])
if relationship.id is not None:
relationship_ids_cache[relationship.subject].append(
relationship.id
int(relationship.id)
)
if relationship.object is not None:
relationship_ids_cache.setdefault(relationship.object, [])
if relationship.id is not None:
relationship_ids_cache[relationship.object].append(
relationship.id
int(relationship.id)
)

return relationship_ids_cache
Expand Down
6 changes: 3 additions & 3 deletions py/core/database/postgres.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .users import PostgresUserHandler

if TYPE_CHECKING:
from ..providers.crypto import NaClCryptoProvider
from ..providers.crypto import BCryptCryptoProvider, NaClCryptoProvider

logger = logging.getLogger()

Expand All @@ -47,7 +47,7 @@ class PostgresDatabaseProvider(DatabaseProvider):
dimension: int
conn: Optional[Any]

crypto_provider: "NaClCryptoProvider"
crypto_provider: "BCryptCryptoProvider" | "NaClCryptoProvider"
postgres_configuration_settings: PostgresConfigurationSettings
default_collection_name: str
default_collection_description: str
Expand All @@ -71,7 +71,7 @@ def __init__(
self,
config: DatabaseConfig,
dimension: int,
crypto_provider: "NaClCryptoProvider",
crypto_provider: "BCryptCryptoProvider | NaClCryptoProvider",
quantization_type: VectorQuantizationType = VectorQuantizationType.FP32,
*args,
**kwargs,
Expand Down
6 changes: 3 additions & 3 deletions py/core/database/vecs/adapter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@

from abc import ABC, abstractmethod
from enum import Enum
from typing import Any, Generator, Iterable, Optional, Tuple, Union
from typing import Any, Generator, Iterable, Optional, Tuple
from uuid import UUID

from vecs.exc import ArgError

MetadataValues = Union[str, int, float, bool, list[str]]
MetadataValues = str | int | float | bool | list[str]
Metadata = dict[str, MetadataValues]
Numeric = Union[int, float, complex]
Numeric = int | float | complex

Record = Tuple[
UUID,
Expand Down
2 changes: 1 addition & 1 deletion py/core/main/api/v3/users_router.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import textwrap
from typing import Optional, Union
from typing import Optional
from uuid import UUID

from fastapi import Body, Depends, Path, Query
Expand Down
44 changes: 31 additions & 13 deletions py/core/main/assembly/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,23 @@
OrchestrationConfig,
)
from core.pipelines import RAGPipeline, SearchPipeline
from core.pipes import GeneratorPipe, MultiSearchPipe, SearchPipe
from core.pipes import (
EmbeddingPipe,
GeneratorPipe,
GraphClusteringPipe,
GraphCommunitySummaryPipe,
GraphDescriptionPipe,
GraphExtractionPipe,
GraphSearchSearchPipe,
GraphStoragePipe,
MultiSearchPipe,
ParsingPipe,
RAGPipe,
SearchPipe,
StreamingRAGPipe,
VectorSearchPipe,
VectorStoragePipe,
)
from core.providers.email.sendgrid import SendGridEmailProvider

from ..abstractions import R2RAgents, R2RPipelines, R2RPipes, R2RProviders
Expand Down Expand Up @@ -366,18 +382,20 @@ def __init__(self, config: R2RConfig, providers: R2RProviders):

def create_pipes(
self,
parsing_pipe_override: Optional[AsyncPipe] = None,
embedding_pipe_override: Optional[AsyncPipe] = None,
graph_extraction_pipe_override: Optional[AsyncPipe] = None,
graph_storage_pipe_override: Optional[AsyncPipe] = None,
graph_search_pipe_override: Optional[AsyncPipe] = None,
vector_storage_pipe_override: Optional[AsyncPipe] = None,
vector_search_pipe_override: Optional[AsyncPipe] = None,
rag_pipe_override: Optional[AsyncPipe] = None,
streaming_rag_pipe_override: Optional[AsyncPipe] = None,
graph_description_pipe: Optional[AsyncPipe] = None,
graph_clustering_pipe: Optional[AsyncPipe] = None,
graph_community_summary_pipe: Optional[AsyncPipe] = None,
parsing_pipe_override: Optional[ParsingPipe] = None,
embedding_pipe_override: Optional[EmbeddingPipe] = None,
graph_extraction_pipe_override: Optional[GraphExtractionPipe] = None,
graph_storage_pipe_override: Optional[GraphStoragePipe] = None,
graph_search_pipe_override: Optional[GraphSearchSearchPipe] = None,
vector_storage_pipe_override: Optional[VectorStoragePipe] = None,
vector_search_pipe_override: Optional[VectorSearchPipe] = None,
rag_pipe_override: Optional[RAGPipe] = None,
streaming_rag_pipe_override: Optional[StreamingRAGPipe] = None,
graph_description_pipe: Optional[GraphDescriptionPipe] = None,
graph_clustering_pipe: Optional[GraphClusteringPipe] = None,
graph_community_summary_pipe: Optional[
GraphCommunitySummaryPipe
] = None,
*args,
**kwargs,
) -> R2RPipes:
Expand Down
4 changes: 2 additions & 2 deletions py/core/main/services/auth_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,7 @@ async def create_user_api_key(self, user_id: UUID) -> dict:
"""
return await self.providers.auth.create_user_api_key(user_id)

async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> bool:
"""
Delete the API key for the user.
Expand All @@ -292,7 +292,7 @@ async def delete_user_api_key(self, user_id: UUID, key_id: UUID) -> dict:
key_id (str): The ID of the API key
Returns:
dict: Contains the message
bool: True if the API key was deleted successfully
"""
return await self.providers.auth.delete_user_api_key(
user_id=user_id, key_id=key_id
Expand Down
10 changes: 6 additions & 4 deletions py/core/main/services/management_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,9 @@ async def get_cached_prompt(
return {
"message": (
await self.providers.database.prompts_handler.get_cached_prompt(
prompt_name, inputs, prompt_override
prompt_name=prompt_name,
inputs=inputs,
prompt_override=prompt_override,
)
)
}
Expand Down Expand Up @@ -838,11 +840,11 @@ async def delete_conversation(
filter_user_ids=user_ids,
)

async def get_user_max_documents(self, user_id: UUID) -> int:
async def get_user_max_documents(self, user_id: UUID) -> int | None:
return self.config.app.default_max_documents_per_user

async def get_user_max_chunks(self, user_id: UUID) -> int:
async def get_user_max_chunks(self, user_id: UUID) -> int | None:
return self.config.app.default_max_chunks_per_user

async def get_user_max_collections(self, user_id: UUID) -> int:
async def get_user_max_collections(self, user_id: UUID) -> int | None:
return self.config.app.default_max_collections_per_user
6 changes: 3 additions & 3 deletions py/core/parsers/structured/csv_parser.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# type: ignore
from typing import IO, AsyncGenerator, Optional, Union
from typing import IO, AsyncGenerator, Optional

from core.base.parsers.base_parser import AsyncParser
from core.base.providers import (
Expand Down Expand Up @@ -29,7 +29,7 @@ def __init__(
self.StringIO = StringIO

async def ingest(
self, data: Union[str, bytes], *args, **kwargs
self, data: str | bytes, *args, **kwargs
) -> AsyncGenerator[str, None]:
"""Ingest CSV data and yield text from each row."""
if isinstance(data, bytes):
Expand Down Expand Up @@ -72,7 +72,7 @@ def get_delimiter(

async def ingest(
self,
data: Union[str, bytes],
data: str | bytes,
num_col_times_num_rows: int = 100,
*args,
**kwargs,
Expand Down
4 changes: 2 additions & 2 deletions py/core/pipes/abstractions/search_pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import logging
from abc import abstractmethod
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, AsyncGenerator
from uuid import UUID

from core.base import AsyncPipe, AsyncState, ChunkSearchResult
Expand All @@ -15,7 +15,7 @@ class SearchConfig(AsyncPipe.PipeConfig):
limit: int = 10

class Input(AsyncPipe.Input):
message: Union[AsyncGenerator[str, None], str]
message: AsyncGenerator[str, None] | str

def __init__(
self,
Expand Down
4 changes: 2 additions & 2 deletions py/core/pipes/ingestion/embedding_pipe.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import asyncio
import logging
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, AsyncGenerator

from core.base import (
AsyncState,
Expand Down Expand Up @@ -113,7 +113,7 @@ async def process_batch(batch):

async def _process_extraction(
self, extraction: DocumentChunk
) -> Union[VectorEntry, R2RDocumentProcessingError]:
) -> VectorEntry | R2RDocumentProcessingError:
try:
if isinstance(extraction.data, bytes):
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion py/core/pipes/kg/community_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ async def _run_logic( # type: ignore
)

# Organize clusters
clusters: dict[Any] = {}
clusters: dict[Any, Any] = {}
for item in community_clusters:
cluster_id = (
item["cluster"]
Expand Down
4 changes: 2 additions & 2 deletions py/core/pipes/kg/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import logging
import re
import time
from typing import Any, AsyncGenerator, Optional, Union
from typing import Any, AsyncGenerator, Optional

from core.base import (
AsyncState,
Expand Down Expand Up @@ -211,7 +211,7 @@ async def _run_logic( # type: ignore
run_id: Any,
*args: Any,
**kwargs: Any,
) -> AsyncGenerator[Union[KGExtraction, R2RDocumentProcessingError], None]:
) -> AsyncGenerator[KGExtraction | R2RDocumentProcessingError, None]:
start_time = time.time()

document_id = input.message["document_id"]
Expand Down
Loading

0 comments on commit f14a65d

Please sign in to comment.