Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat!: allow any DI implementation to work with the FastAPI implementation #9

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 14 additions & 10 deletions langchain_openai_api_bridge/assistant/adapter/container.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,21 @@
from langchain_openai_api_bridge.assistant.adapter.thread_to_langchain_input_messages_service import (
ThreadToLangchainInputMessagesService,
)
from langchain_openai_api_bridge.core.utils.di_container import DIContainer
from langchain_openai_api_bridge.assistant.assistant_lib_injector import (
BaseAssistantLibInjector,
)


def register_assistant_adapter(container: DIContainer) -> DIContainer:
container.register(OnChatModelStreamHandler)
container.register(OnChatModelEndHandler)
container.register(ThreadRunEventHandler)
container.register(OnToolStartHandler)
container.register(OnToolEndHandler)
def register_assistant_adapter(
    injector: BaseAssistantLibInjector,
) -> BaseAssistantLibInjector:
    """Register every assistant adapter component on *injector* and return it.

    Components are registered by type only (no explicit binding or scope),
    in the same order as before: event handlers first, then the message
    translation service and the event-stream adapter.
    """
    adapter_components = (
        OnChatModelStreamHandler,
        OnChatModelEndHandler,
        ThreadRunEventHandler,
        OnToolStartHandler,
        OnToolEndHandler,
        ThreadToLangchainInputMessagesService,
        LanggraphEventToOpenAIAssistantEventStream,
    )
    for component in adapter_components:
        injector.register(component)

    return injector
46 changes: 28 additions & 18 deletions langchain_openai_api_bridge/assistant/assistant_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
from langchain_openai_api_bridge.assistant.adapter.container import (
register_assistant_adapter,
)
from langchain_openai_api_bridge.assistant.assistant_lib_injector import (
BaseAssistantLibInjector,
)
from langchain_openai_api_bridge.assistant.assistant_message_service import (
AssistantMessageService,
)
Expand All @@ -21,32 +24,39 @@
ThreadRepository,
)
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langchain_openai_api_bridge.core.utils.di_container import DIContainer


class AssistantApp:
def __init__(
self,
thread_repository_type: Type[ThreadRepository],
message_repository_type: Type[MessageRepository],
run_repository: Type[RunRepository],
agent_factory: Type[AgentFactory],
injector: BaseAssistantLibInjector,
thread_repository_type: Optional[Type[ThreadRepository]] = None,
message_repository_type: Optional[Type[MessageRepository]] = None,
run_repository: Optional[Type[RunRepository]] = None,
agent_factory: Optional[Type[AgentFactory]] = None,
system_fingerprint: Optional[str] = "",
):
self.container = DIContainer()
self.injector = injector
self.system_fingerprint = system_fingerprint

register_assistant_adapter(self.container)
register_assistant_adapter(self.injector)

self.injector.register(AssistantThreadService)
self.injector.register(AssistantMessageService)
self.injector.register(AssistantRunService)

if thread_repository_type is not None:
self.injector.register(
ThreadRepository, to=thread_repository_type, scope="singleton"
)

if message_repository_type is not None:
self.injector.register(
MessageRepository, to=message_repository_type, scope="singleton"
)

self.container.register(AssistantThreadService)
self.container.register(AssistantMessageService)
self.container.register(AssistantRunService)
if run_repository is not None:
self.injector.register(RunRepository, to=run_repository, scope="singleton")

self.container.register(
ThreadRepository, to=thread_repository_type, singleton=True
)
self.container.register(
MessageRepository, to=message_repository_type, singleton=True
)
self.container.register(RunRepository, to=run_repository, singleton=True)
self.container.register(AgentFactory, to=agent_factory)
if agent_factory is not None:
self.injector.register(AgentFactory, to=agent_factory)
19 changes: 19 additions & 0 deletions langchain_openai_api_bridge/assistant/assistant_lib_injector.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
from abc import ABC, abstractmethod
from typing import Literal, Optional, Type, TypeVar

T = TypeVar("T")


class BaseAssistantLibInjector(ABC):
    """Abstract dependency-injection facade for the assistant library.

    Any DI container can back the library by implementing this interface,
    decoupling the FastAPI integration from a specific DI implementation.
    """

    @abstractmethod
    def get(self, cls: Type[T]) -> T:
        """Resolve and return an instance satisfying *cls*."""

    @abstractmethod
    def register(
        self,
        cls: Type[T],
        # Fix: visible call sites bind classes (e.g. to=thread_repository_type),
        # not instances, so the target is a type, not a T instance.
        to: Optional[Type[T]] = None,
        # Idiomatic spelling of "either the literal 'singleton' or None".
        scope: Optional[Literal["singleton"]] = None,
    ) -> None:
        """Bind *cls*, optionally to the concrete type *to*.

        When ``scope == "singleton"`` the implementation must reuse a single
        instance for every subsequent :meth:`get`.
        """
20 changes: 10 additions & 10 deletions langchain_openai_api_bridge/fastapi/add_assistant_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,22 +30,22 @@ def create_open_ai_compatible_assistant_router(
assistant_app: AssistantApp,
):

container = assistant_app.container
container = assistant_app.injector
thread_router = APIRouter(prefix="/threads")

@thread_router.post("/")
def assistant_create_thread(create_request: CreateThreadDto):
service = container.resolve(AssistantThreadService)
service = container.get(AssistantThreadService)
return service.create(create_request)

@thread_router.get("/{thread_id}")
def assistant_retreive_thread(thread_id: str):
service = container.resolve(AssistantThreadService)
service = container.get(AssistantThreadService)
return service.retreive(thread_id=thread_id)

@thread_router.delete("/{thread_id}")
def assistant_delete_thread(thread_id: str):
service = container.resolve(AssistantThreadService)
service = container.get(AssistantThreadService)
return service.delete(thread_id=thread_id)

@thread_router.get("/{thread_id}/messages")
Expand All @@ -56,7 +56,7 @@ async def assistant_list_thread_messages(
limit: int = 100,
order: Literal["asc", "desc"] = None,
):
service = container.resolve(AssistantMessageService)
service = container.get(AssistantMessageService)
messages = service.list(
thread_id=thread_id, after=after, before=before, limit=limit, order=order
)
Expand All @@ -68,7 +68,7 @@ async def assistant_retreive_thread_messages(
thread_id: str,
message_id: str,
):
service = container.resolve(AssistantMessageService)
service = container.get(AssistantMessageService)
message = service.retreive(thread_id=thread_id, message_id=message_id)

return message
Expand All @@ -78,15 +78,15 @@ def assistant_delete_thread_messages(
thread_id: str,
message_id: str,
):
service = container.resolve(AssistantMessageService)
service = container.get(AssistantMessageService)
return service.delete(thread_id=thread_id, message_id=message_id)

@thread_router.post("/{thread_id}/messages")
def assistant_create_thread_messages(
thread_id: str,
request: CreateThreadMessageDto,
):
service = container.resolve(AssistantMessageService)
service = container.get(AssistantMessageService)
message = service.create(thread_id=thread_id, dto=request)

return message
Expand All @@ -101,7 +101,7 @@ async def assistant_create_thread_runs(

api_key = get_bearer_token(authorization)

agent_factory = container.resolve(AgentFactory)
agent_factory = container.get(AgentFactory)
create_agent_dto = CreateAgentDto(
model=thread_run_dto.model,
api_key=api_key,
Expand All @@ -111,7 +111,7 @@ async def assistant_create_thread_runs(
llm = agent_factory.create_llm(dto=create_agent_dto)
agent = agent_factory.create_agent(llm=llm, dto=create_agent_dto)

service = container.resolve(AssistantRunService)
service = container.get(AssistantRunService)
stream = service.astream(agent=agent, dto=thread_run_dto)

response_factory = AssistantStreamEventAdapter()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,15 +17,15 @@
def create_open_ai_compatible_chat_completion_router(
assistant_app: AssistantApp,
):
container = assistant_app.container
container = assistant_app.injector
chat_completion_router = APIRouter(prefix="/chat/completions")

@chat_completion_router.post("/")
async def assistant_retreive_thread_messages(
request: OpenAIChatCompletionRequest, authorization: str = Header(None)
):
api_key = get_bearer_token(authorization)
agent_factory = container.resolve(AgentFactory)
agent_factory = container.get(AgentFactory)
create_agent_dto = CreateAgentDto(
model=request.model,
api_key=api_key,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
from langchain_openai_api_bridge.fastapi import (
include_assistant,
)
from tests.test_functional.tiny_di_container import (
AssistantLibInjector,
)
from tests.test_functional.fastapi_assistant_agent_openai_advanced.my_agent_factory import (
MyAgentFactory,
)
Expand All @@ -20,6 +23,7 @@


assistant_app = AssistantApp(
injector=AssistantLibInjector(),
thread_repository_type=InMemoryThreadRepository,
message_repository_type=InMemoryMessageRepository,
run_repository=InMemoryRunRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from tests.test_functional.fastapi_chat_completion_anthropic.my_anthropic_agent_factory import (
MyAnthropicAgentFactory,
)
from tests.test_functional.tiny_di_container import AssistantLibInjector

_ = load_dotenv(find_dotenv())

Expand All @@ -32,6 +33,7 @@
)

assistant_app = AssistantApp(
injector=AssistantLibInjector(),
thread_repository_type=InMemoryThreadRepository,
message_repository_type=InMemoryMessageRepository,
run_repository=InMemoryRunRepository,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from tests.test_functional.fastapi_chat_completion_openai.my_openai_agent_factory import (
MyOpenAIAgentFactory,
)
from tests.test_functional.tiny_di_container import AssistantLibInjector


_ = load_dotenv(find_dotenv())
Expand All @@ -34,6 +35,7 @@
)

assistant_app = AssistantApp(
injector=AssistantLibInjector(),
thread_repository_type=InMemoryThreadRepository,
message_repository_type=InMemoryMessageRepository,
run_repository=InMemoryRunRepository,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,10 +1,19 @@
from typing import Type, TypeVar, Dict, Any, Union
from typing import Literal, Optional, Type, TypeVar, Dict, Any, Union
import inspect

from langchain_openai_api_bridge.assistant.assistant_lib_injector import (
BaseAssistantLibInjector,
)

T = TypeVar("T")


class DIContainer:
# !! This DI Implementation is an example and should not be used in production !!
# Prefer using a well implemented DI Container, like
# https://github.com/python-injector/injector
# or
# https://python-dependency-injector.ets-labs.org
class TinyDIContainer:
def __init__(self):
self.services: Dict[Type[Any], Any] = {}
self.singletons: Dict[Type[Any], Any] = {}
Expand Down Expand Up @@ -51,3 +60,19 @@ def _create_instance(self, cls: Type[T]) -> T:
if name != "self" and param.annotation != param.empty
}
return cls(**dependencies)


class AssistantLibInjector(BaseAssistantLibInjector):
    """Example injector backed by the toy ``TinyDIContainer``.

    Adapts the library's injector interface onto the container's
    ``register``/``resolve`` API.
    """

    def __init__(self):
        # All resolution and registration is delegated to this container.
        self.di_container = TinyDIContainer()

    def get(self, cls: Type[T]) -> T:
        """Resolve an instance of *cls* from the underlying container."""
        return self.di_container.resolve(cls=cls)

    def register(
        self,
        cls: Type[T],
        to: Optional[T] = None,
        scope: Literal["singleton", None] = None,
    ) -> None:
        """Bind *cls* (optionally to *to*); "singleton" scope maps to singleton=True."""
        wants_singleton = scope == "singleton"
        self.di_container.register(cls=cls, service=to, singleton=wants_singleton)