From 3b96a5e481c0a735999c4cb0147f36ca61055359 Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Thu, 29 Aug 2024 12:51:12 -0700 Subject: [PATCH 1/7] first draft of custom event handler support --- .../chat_completion_compatible_api.py | 7 ++-- .../langchain_stream_adapter.py | 30 ++++++++++------- .../core/base_agent_factory.py | 6 +++- .../core/utils/pydantic_async_iterator.py | 2 +- .../fastapi/chat_completion_router.py | 6 ++-- .../server_openai.py | 28 +++++++++++----- .../test_langchain_stream_adapter.py | 32 +++++++++++++++++++ 7 files changed, 85 insertions(+), 26 deletions(-) diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py index ecaddcf..b20fac0 100644 --- a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py +++ b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py @@ -15,12 +15,13 @@ class ChatCompletionCompatibleAPI: @staticmethod def from_agent( - agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = "" + agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = "", custom_event_handler = lambda event: None ): return ChatCompletionCompatibleAPI( LangchainStreamAdapter(llm_model, system_fingerprint), LangchainInvokeAdapter(llm_model, system_fingerprint), agent, + custom_event_handler, ) def __init__( @@ -28,10 +29,12 @@ def __init__( stream_adapter: LangchainStreamAdapter, invoke_adapter: LangchainInvokeAdapter, agent: Runnable, + custom_event_handler=None, ) -> None: self.stream_adapter = stream_adapter self.invoke_adapter = invoke_adapter self.agent = agent + self.custom_event_handler = custom_event_handler def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]: input = self.__to_input(messages) @@ -40,7 +43,7 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]: version="v2", ) return ato_dict( - self.stream_adapter.ato_chat_completion_chunk_stream(astream_event) + self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, custom_event_handler=self.custom_event_handler) ) def invoke(self, messages: List[OpenAIChatMessage]) -> dict: diff --git a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py index 9c778e3..b114bb7 100644 --- a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py +++ b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py @@ -12,7 +12,6 @@ OpenAIChatCompletionChunkObject, ) - class LangchainStreamAdapter: def __init__(self, llm_model: str, system_fingerprint: str = ""): self.llm_model = llm_model @@ -22,20 +21,29 @@ async def ato_chat_completion_chunk_stream( self, astream_event: AsyncIterator[StreamEvent], id: str = "", + custom_event_handler = lambda event: None, ) -> AsyncIterator[OpenAIChatCompletionChunkObject]: if id == "": id = str(uuid.uuid4()) async for event in astream_event: - kind = event["event"] - match kind: - case "on_chat_model_stream": - chunk = to_openai_chat_completion_chunk_object( - event=event, - id=id, - model=self.llm_model, - system_fingerprint=self.system_fingerprint, - ) - yield chunk + custom_event = custom_event_handler(event) + if custom_event is not None: + yield to_openai_chat_completion_chunk_object( + event=custom_event, + id=id, + model=self.llm_model, + system_fingerprint=self.system_fingerprint, + ) + else: + kind = event["event"] + match 
kind: + case "on_chat_model_stream": + yield to_openai_chat_completion_chunk_object( + event=event, + id=id, + model=self.llm_model, + system_fingerprint=self.system_fingerprint, + ) stop_chunk = create_final_chat_completion_chunk_object( id=id, model=self.llm_model diff --git a/langchain_openai_api_bridge/core/base_agent_factory.py b/langchain_openai_api_bridge/core/base_agent_factory.py index c16ad86..b8284f8 100644 --- a/langchain_openai_api_bridge/core/base_agent_factory.py +++ b/langchain_openai_api_bridge/core/base_agent_factory.py @@ -4,7 +4,11 @@ class BaseAgentFactory(ABC): - + + @classmethod + def custom_event_handler(self, event): + pass + @abstractmethod def create_agent(self, dto: CreateAgentDto) -> Runnable: pass diff --git a/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py b/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py index db6bc6a..52b10a7 100644 --- a/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py +++ b/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py @@ -5,4 +5,4 @@ async def ato_dict(async_iter: AsyncIterator[BaseModel]) -> AsyncIterator[dict]: async for obj in async_iter: - yield obj.dict() + yield obj.model_dump() diff --git a/langchain_openai_api_bridge/fastapi/chat_completion_router.py b/langchain_openai_api_bridge/fastapi/chat_completion_router.py index 6130cbe..c88c324 100644 --- a/langchain_openai_api_bridge/fastapi/chat_completion_router.py +++ b/langchain_openai_api_bridge/fastapi/chat_completion_router.py @@ -17,9 +17,9 @@ def create_chat_completion_router( tiny_di_container: TinyDIContainer, ): - chat_completion_router = APIRouter(prefix="/chat/completions") + chat_completion_router = APIRouter(prefix="/chat") - @chat_completion_router.post("/") + @chat_completion_router.post("/completions") async def assistant_retreive_thread_messages( request: OpenAIChatCompletionRequest, authorization: str = Header(None) ): @@ -33,7 +33,7 @@ async def assistant_retreive_thread_messages( agent = agent_factory.create_agent(dto=create_agent_dto) - adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model) + adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, custom_event_handler=agent_factory.custom_event_handler) response_factory = HttpStreamResponseAdapter() if request.stream is True: diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index a247175..7cd022d 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -4,6 +4,7 @@ import uvicorn from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto +from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import ( LangchainOpenaiApiBridgeFastAPI, ) @@ -27,17 +28,28 @@ expose_headers=["*"], ) +class AgentFactory(BaseAgentFactory): + def custom_event_handler(self, event): + if "chunk" in event["data"]: + return event + return None -def create_agent(dto: CreateAgentDto): - return ChatOpenAI( - temperature=dto.temperature or 0.7, - model=dto.model, - max_tokens=dto.max_tokens, - api_key=dto.api_key, - ) + def create_agent(self, dto: CreateAgentDto): + return ChatOpenAI( + temperature=dto.temperature or 0.7, + model=dto.model, + max_tokens=dto.max_tokens, + api_key=dto.api_key, + ) 
-bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) +def custom_event_handler(event): + kind = event["event"] + match kind: + case "on_chat_model_stream": + return event + +bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=AgentFactory()) bridge.bind_openai_chat_completion(prefix="/my-custom-path") if __name__ == "__main__": diff --git a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py index 450edb3..36beeae 100644 --- a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py +++ b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py @@ -43,3 +43,35 @@ async def test_stream_contains_every_on_chat_model_stream( items = await assemble_stream(response_stream) assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict() assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict() + + + @pytest.mark.asyncio + @patch( + "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object", + side_effect=lambda event, id, model, system_fingerprint: ( + ChatCompletionChunkStub({"key": event["data"]["chunk"].content}) + ), + ) + async def test_stream_contains_every_custom_handled_stream( + self, to_openai_chat_completion_chunk_object + ): + on_chat_model_stream_event1 = create_on_chat_model_stream_event(content="hello") + on_chat_model_stream_event2 = create_on_chat_model_stream_event(content="moto") + input_stream = generate_stream( + [ + on_chat_model_stream_event1, + on_chat_model_stream_event2, + ] + ) + + def custom_event_handler(event): + kind = event["event"] + match kind: + case "on_chat_model_stream": + return event + + response_stream = self.instance.ato_chat_completion_chunk_stream(input_stream, custom_event_handler=custom_event_handler) + + items = await assemble_stream(response_stream) + assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict() + assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict() From 0d5cdaff46b5d0a9212441923e589dcc9114c71c Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Thu, 5 Sep 2024 11:38:52 -0700 Subject: [PATCH 2/7] address initial feedback --- .../chat_completion_compatible_api.py | 7 ++-- .../core/base_agent_factory.py | 4 --- .../core/utils/pydantic_async_iterator.py | 2 +- .../fastapi/chat_completion_router.py | 9 +++-- .../langchain_openai_api_bridge_fastapi.py | 4 +-- .../server_openai.py | 10 ++---- .../test_server_openai.py | 35 +++++++++++++++++++ 7 files changed, 52 insertions(+), 19 deletions(-) diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py index b20fac0..9e801fb 100644 --- a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py +++ b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py @@ -15,7 +15,10 @@ class ChatCompletionCompatibleAPI: @staticmethod def from_agent( - agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = "", custom_event_handler = lambda event: None + agent: Runnable, + llm_model: str, + system_fingerprint: Optional[str] = "", + custom_event_handler: callable = lambda event: None, ): return ChatCompletionCompatibleAPI( LangchainStreamAdapter(llm_model, system_fingerprint), @@ -29,7 +32,7 @@ def __init__( stream_adapter: LangchainStreamAdapter, invoke_adapter: 
LangchainInvokeAdapter, agent: Runnable, - custom_event_handler=None, + custom_event_handler: callable = lambda event: None, ) -> None: self.stream_adapter = stream_adapter self.invoke_adapter = invoke_adapter diff --git a/langchain_openai_api_bridge/core/base_agent_factory.py b/langchain_openai_api_bridge/core/base_agent_factory.py index b8284f8..8044187 100644 --- a/langchain_openai_api_bridge/core/base_agent_factory.py +++ b/langchain_openai_api_bridge/core/base_agent_factory.py @@ -5,10 +5,6 @@ class BaseAgentFactory(ABC): - @classmethod - def custom_event_handler(self, event): - pass - @abstractmethod def create_agent(self, dto: CreateAgentDto) -> Runnable: pass diff --git a/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py b/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py index 52b10a7..db6bc6a 100644 --- a/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py +++ b/langchain_openai_api_bridge/core/utils/pydantic_async_iterator.py @@ -5,4 +5,4 @@ async def ato_dict(async_iter: AsyncIterator[BaseModel]) -> AsyncIterator[dict]: async for obj in async_iter: - yield obj.model_dump() + yield obj.dict() diff --git a/langchain_openai_api_bridge/fastapi/chat_completion_router.py b/langchain_openai_api_bridge/fastapi/chat_completion_router.py index c88c324..a22ce07 100644 --- a/langchain_openai_api_bridge/fastapi/chat_completion_router.py +++ b/langchain_openai_api_bridge/fastapi/chat_completion_router.py @@ -16,6 +16,7 @@ def create_chat_completion_router( tiny_di_container: TinyDIContainer, + custom_event_handler: callable = lambda event: None, ): chat_completion_router = APIRouter(prefix="/chat") @@ -33,7 +34,7 @@ async def assistant_retreive_thread_messages( agent = agent_factory.create_agent(dto=create_agent_dto) - adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, custom_event_handler=agent_factory.custom_event_handler) + adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, custom_event_handler=custom_event_handler) response_factory = HttpStreamResponseAdapter() if request.stream is True: @@ -46,9 +47,11 @@ async def assistant_retreive_thread_messages( def create_openai_chat_completion_router( - tiny_di_container: TinyDIContainer, prefix: str = "" + tiny_di_container: TinyDIContainer, + prefix: str = "", + custom_event_handler: callable = lambda event: None, ): - router = create_chat_completion_router(tiny_di_container=tiny_di_container) + router = create_chat_completion_router(tiny_di_container=tiny_di_container, custom_event_handler=custom_event_handler) open_ai_router = APIRouter(prefix=f"{prefix}/openai/v1") open_ai_router.include_router(router) diff --git a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py index 2eec58e..1b59118 100644 --- a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py +++ b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py @@ -97,9 +97,9 @@ def bind_openai_assistant_api( self.app.include_router(assistant_router) - def bind_openai_chat_completion(self, prefix: str = "") -> None: + def bind_openai_chat_completion(self, prefix: str = "", custom_event_handler: callable = lambda event: None) -> None: chat_completion_router = create_openai_chat_completion_router( - self.tiny_di_container, prefix=prefix + self.tiny_di_container, prefix=prefix, custom_event_handler=custom_event_handler ) 
self.app.include_router(chat_completion_router) diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index 7cd022d..247a528 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -29,11 +29,6 @@ ) class AgentFactory(BaseAgentFactory): - def custom_event_handler(self, event): - if "chunk" in event["data"]: - return event - return None - def create_agent(self, dto: CreateAgentDto): return ChatOpenAI( temperature=dto.temperature or 0.7, @@ -42,6 +37,8 @@ def create_agent(self, dto: CreateAgentDto): api_key=dto.api_key, ) +bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=AgentFactory()) +bridge.bind_openai_chat_completion(prefix="/my-custom-path") def custom_event_handler(event): kind = event["event"] @@ -49,8 +46,7 @@ def custom_event_handler(event): case "on_chat_model_stream": return event -bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=AgentFactory()) -bridge.bind_openai_chat_completion(prefix="/my-custom-path") +bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", custom_event_handler=custom_event_handler) if __name__ == "__main__": uvicorn.run(app, host="localhost") diff --git a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py index 736f464..94e314d 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py @@ -42,3 +42,38 @@ def test_chat_completion_stream(openai_client): stream_output = "".join(every_content) assert "This is a test" in stream_output + +@pytest.fixture +def openai_client_custom_events(): + return OpenAI( + base_url="http://testserver/my-custom-events-path/openai/v1", + http_client=test_api, + ) + +def test_chat_completion_invoke_custom_events(openai_client_custom_events): + chat_completion = openai_client_custom_events.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": 'Say "This is a test"', + } + ], + ) + assert "This is a test" in chat_completion.choices[0].message.content + + +def test_chat_completion_stream_custom_events(openai_client_custom_events): + chunks = openai_client_custom_events.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": 'Say "This is a test"'}], + stream=True, + ) + every_content = [] + for chunk in chunks: + if chunk.choices and isinstance(chunk.choices[0].delta.content, str): + every_content.append(chunk.choices[0].delta.content) + + stream_output = "".join(every_content) + + assert "This is a test" in stream_output \ No newline at end of file From 00913a2870493a030dd45b66cb35e6037bd44ef7 Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Thu, 5 Sep 2024 14:17:26 -0700 Subject: [PATCH 3/7] clean up --- .../core/base_agent_factory.py | 2 +- .../server_openai.py | 22 +++++++++---------- 2 files changed, 11 insertions(+), 13 deletions(-) diff --git a/langchain_openai_api_bridge/core/base_agent_factory.py b/langchain_openai_api_bridge/core/base_agent_factory.py index 8044187..c16ad86 100644 --- a/langchain_openai_api_bridge/core/base_agent_factory.py +++ b/langchain_openai_api_bridge/core/base_agent_factory.py @@ -4,7 +4,7 @@ class BaseAgentFactory(ABC): - + @abstractmethod def 
create_agent(self, dto: CreateAgentDto) -> Runnable: pass diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index 247a528..4a5bbec 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -4,7 +4,6 @@ import uvicorn from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto -from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import ( LangchainOpenaiApiBridgeFastAPI, ) @@ -27,17 +26,16 @@ allow_headers=["*"], expose_headers=["*"], ) - -class AgentFactory(BaseAgentFactory): - def create_agent(self, dto: CreateAgentDto): - return ChatOpenAI( - temperature=dto.temperature or 0.7, - model=dto.model, - max_tokens=dto.max_tokens, - api_key=dto.api_key, - ) - -bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=AgentFactory()) + +def create_agent(dto: CreateAgentDto): + return ChatOpenAI( + temperature=dto.temperature or 0.7, + model=dto.model, + max_tokens=dto.max_tokens, + api_key=dto.api_key, + ) + +bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) bridge.bind_openai_chat_completion(prefix="/my-custom-path") def custom_event_handler(event): From 617f0fa58f655843052b0d687c8fc3e8a792653c Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Fri, 20 Sep 2024 10:56:30 -0700 Subject: [PATCH 4/7] change func name, DRY up adapter --- .../chat_completion_compatible_api.py | 10 ++-- .../langchain_stream_adapter.py | 20 +++----- .../fastapi/chat_completion_router.py | 8 +-- .../langchain_openai_api_bridge_fastapi.py | 4 +- .../server_openai.py | 4 +- .../server_openai_event_adapter.py | 49 +++++++++++++++++++ .../test_server_openai.py | 35 ------------- .../test_server_openai_event_adapter.py | 42 ++++++++++++++++ .../test_langchain_stream_adapter.py | 4 +- 9 files changed, 112 insertions(+), 64 deletions(-) create mode 100644 tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py create mode 100644 tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py diff --git a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py index 9e801fb..03b6512 100644 --- a/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py +++ b/langchain_openai_api_bridge/chat_completion/chat_completion_compatible_api.py @@ -18,13 +18,13 @@ def from_agent( agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = "", - custom_event_handler: callable = lambda event: None, + event_adapter: callable = lambda event: None, ): return ChatCompletionCompatibleAPI( LangchainStreamAdapter(llm_model, system_fingerprint), LangchainInvokeAdapter(llm_model, system_fingerprint), agent, - custom_event_handler, + event_adapter, ) def __init__( @@ -32,12 +32,12 @@ def __init__( stream_adapter: LangchainStreamAdapter, invoke_adapter: LangchainInvokeAdapter, agent: Runnable, - custom_event_handler: callable = lambda event: None, + event_adapter: callable = lambda event: None, ) -> None: self.stream_adapter = stream_adapter self.invoke_adapter = invoke_adapter self.agent = agent - self.custom_event_handler = custom_event_handler + self.event_adapter = event_adapter def 
astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]: input = self.__to_input(messages) @@ -46,7 +46,7 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]: version="v2", ) return ato_dict( - self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, custom_event_handler=self.custom_event_handler) + self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, event_adapter=self.event_adapter) ) def invoke(self, messages: List[OpenAIChatMessage]) -> dict: diff --git a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py index b114bb7..78dd66b 100644 --- a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py +++ b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py @@ -21,29 +21,21 @@ async def ato_chat_completion_chunk_stream( self, astream_event: AsyncIterator[StreamEvent], id: str = "", - custom_event_handler = lambda event: None, + event_adapter = lambda event: None, ) -> AsyncIterator[OpenAIChatCompletionChunkObject]: if id == "": id = str(uuid.uuid4()) async for event in astream_event: - custom_event = custom_event_handler(event) - if custom_event is not None: + custom_event = event_adapter(event) + event_to_process = custom_event if custom_event is not None else event + kind = event_to_process["event"] + if kind == "on_chat_model_stream" or custom_event is not None: yield to_openai_chat_completion_chunk_object( - event=custom_event, + event=event_to_process, id=id, model=self.llm_model, system_fingerprint=self.system_fingerprint, ) - else: - kind = event["event"] - match kind: - case "on_chat_model_stream": - yield to_openai_chat_completion_chunk_object( - event=event, - id=id, - model=self.llm_model, - system_fingerprint=self.system_fingerprint, - ) stop_chunk = create_final_chat_completion_chunk_object( id=id, model=self.llm_model diff --git a/langchain_openai_api_bridge/fastapi/chat_completion_router.py b/langchain_openai_api_bridge/fastapi/chat_completion_router.py index a22ce07..b09e756 100644 --- a/langchain_openai_api_bridge/fastapi/chat_completion_router.py +++ b/langchain_openai_api_bridge/fastapi/chat_completion_router.py @@ -16,7 +16,7 @@ def create_chat_completion_router( tiny_di_container: TinyDIContainer, - custom_event_handler: callable = lambda event: None, + event_adapter: callable = lambda event: None, ): chat_completion_router = APIRouter(prefix="/chat") @@ -34,7 +34,7 @@ async def assistant_retreive_thread_messages( agent = agent_factory.create_agent(dto=create_agent_dto) - adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, custom_event_handler=custom_event_handler) + adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, event_adapter=event_adapter) response_factory = HttpStreamResponseAdapter() if request.stream is True: @@ -49,9 +49,9 @@ async def assistant_retreive_thread_messages( def create_openai_chat_completion_router( tiny_di_container: TinyDIContainer, prefix: str = "", - custom_event_handler: callable = lambda event: None, + event_adapter: callable = lambda event: None, ): - router = create_chat_completion_router(tiny_di_container=tiny_di_container, custom_event_handler=custom_event_handler) + router = create_chat_completion_router(tiny_di_container=tiny_di_container, event_adapter=event_adapter) open_ai_router = APIRouter(prefix=f"{prefix}/openai/v1") open_ai_router.include_router(router) diff --git 
a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py index 1b59118..0a2b66c 100644 --- a/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py +++ b/langchain_openai_api_bridge/fastapi/langchain_openai_api_bridge_fastapi.py @@ -97,9 +97,9 @@ def bind_openai_assistant_api( self.app.include_router(assistant_router) - def bind_openai_chat_completion(self, prefix: str = "", custom_event_handler: callable = lambda event: None) -> None: + def bind_openai_chat_completion(self, prefix: str = "", event_adapter: callable = lambda event: None) -> None: chat_completion_router = create_openai_chat_completion_router( - self.tiny_di_container, prefix=prefix, custom_event_handler=custom_event_handler + self.tiny_di_container, prefix=prefix, event_adapter=event_adapter ) self.app.include_router(chat_completion_router) diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index 4a5bbec..f0eef8b 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -38,13 +38,13 @@ def create_agent(dto: CreateAgentDto): bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) bridge.bind_openai_chat_completion(prefix="/my-custom-path") -def custom_event_handler(event): +def event_adapter(event): kind = event["event"] match kind: case "on_chat_model_stream": return event -bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", custom_event_handler=custom_event_handler) +bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", event_adapter=event_adapter) if __name__ == "__main__": uvicorn.run(app, host="localhost") diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py new file mode 100644 index 0000000..75df4ef --- /dev/null +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py @@ -0,0 +1,49 @@ +from fastapi import FastAPI +from fastapi.middleware.cors import CORSMiddleware +from dotenv import load_dotenv, find_dotenv +import uvicorn + +from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto +from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import ( + LangchainOpenaiApiBridgeFastAPI, +) +from langchain_openai import ChatOpenAI + +_ = load_dotenv(find_dotenv()) + + +app = FastAPI( + title="Langchain Agent OpenAI API Bridge", + version="1.0", + description="OpenAI API exposing langchain agent", +) + +app.add_middleware( + CORSMiddleware, + allow_origins=["*"], + allow_credentials=True, + allow_methods=["*"], + allow_headers=["*"], + expose_headers=["*"], +) + +def create_agent(dto: CreateAgentDto): + return ChatOpenAI( + temperature=dto.temperature or 0.7, + model=dto.model, + max_tokens=dto.max_tokens, + api_key=dto.api_key, + ) + +bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) + +def event_adapter(event): + kind = event["event"] + match kind: + case "on_chat_model_stream": + return event + +bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", event_adapter=event_adapter) + +if __name__ == "__main__": + uvicorn.run(app, host="localhost") diff --git 
a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py index 94e314d..736f464 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai.py @@ -42,38 +42,3 @@ def test_chat_completion_stream(openai_client): stream_output = "".join(every_content) assert "This is a test" in stream_output - -@pytest.fixture -def openai_client_custom_events(): - return OpenAI( - base_url="http://testserver/my-custom-events-path/openai/v1", - http_client=test_api, - ) - -def test_chat_completion_invoke_custom_events(openai_client_custom_events): - chat_completion = openai_client_custom_events.chat.completions.create( - model="gpt-4o-mini", - messages=[ - { - "role": "user", - "content": 'Say "This is a test"', - } - ], - ) - assert "This is a test" in chat_completion.choices[0].message.content - - -def test_chat_completion_stream_custom_events(openai_client_custom_events): - chunks = openai_client_custom_events.chat.completions.create( - model="gpt-4o-mini", - messages=[{"role": "user", "content": 'Say "This is a test"'}], - stream=True, - ) - every_content = [] - for chunk in chunks: - if chunk.choices and isinstance(chunk.choices[0].delta.content, str): - every_content.append(chunk.choices[0].delta.content) - - stream_output = "".join(every_content) - - assert "This is a test" in stream_output \ No newline at end of file diff --git a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py new file mode 100644 index 0000000..a97ca05 --- /dev/null +++ b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py @@ -0,0 +1,42 @@ +import pytest +from openai import OpenAI +from fastapi.testclient import TestClient +from server_openai_event_adapter import app + + +test_api = TestClient(app) + +@pytest.fixture +def openai_client_custom_events(): + return OpenAI( + base_url="http://testserver/my-custom-events-path/openai/v1", + http_client=test_api, + ) + +def test_chat_completion_invoke_custom_events(openai_client_custom_events): + chat_completion = openai_client_custom_events.chat.completions.create( + model="gpt-4o-mini", + messages=[ + { + "role": "user", + "content": 'Say "This is a test"', + } + ], + ) + assert "This is a test" in chat_completion.choices[0].message.content + + +def test_chat_completion_stream_custom_events(openai_client_custom_events): + chunks = openai_client_custom_events.chat.completions.create( + model="gpt-4o-mini", + messages=[{"role": "user", "content": 'Say "This is a test"'}], + stream=True, + ) + every_content = [] + for chunk in chunks: + if chunk.choices and isinstance(chunk.choices[0].delta.content, str): + every_content.append(chunk.choices[0].delta.content) + + stream_output = "".join(every_content) + + assert "This is a test" in stream_output \ No newline at end of file diff --git a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py index 36beeae..1e49c46 100644 --- a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py +++ b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py @@ -64,13 +64,13 @@ async def test_stream_contains_every_custom_handled_stream( ] ) - def custom_event_handler(event): + def event_adapter(event): 
kind = event["event"] match kind: case "on_chat_model_stream": return event - response_stream = self.instance.ato_chat_completion_chunk_stream(input_stream, custom_event_handler=custom_event_handler) + response_stream = self.instance.ato_chat_completion_chunk_stream(input_stream, event_adapter=event_adapter) items = await assemble_stream(response_stream) assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict() From 845e444c7e2d67717946acf31d0b774f09adea8f Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Mon, 23 Sep 2024 09:59:26 -0700 Subject: [PATCH 5/7] remove dead code --- .../fastapi_chat_completion_openai/server_openai.py | 11 +++-------- 1 file changed, 3 insertions(+), 8 deletions(-) diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index f0eef8b..689d475 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -26,7 +26,8 @@ allow_headers=["*"], expose_headers=["*"], ) - + + def create_agent(dto: CreateAgentDto): return ChatOpenAI( temperature=dto.temperature or 0.7, @@ -35,16 +36,10 @@ def create_agent(dto: CreateAgentDto): api_key=dto.api_key, ) + bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) bridge.bind_openai_chat_completion(prefix="/my-custom-path") -def event_adapter(event): - kind = event["event"] - match kind: - case "on_chat_model_stream": - return event - -bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", event_adapter=event_adapter) if __name__ == "__main__": uvicorn.run(app, host="localhost") From d3293d158cbca88a1b39d83d99aa1157b1fd9263 Mon Sep 17 00:00:00 2001 From: Thomas Burke Date: Mon, 23 Sep 2024 10:07:29 -0700 Subject: [PATCH 6/7] remove dead code --- .../fastapi_chat_completion_openai/server_openai.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py index 689d475..a247175 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai.py @@ -40,6 +40,5 @@ def create_agent(dto: CreateAgentDto): bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) bridge.bind_openai_chat_completion(prefix="/my-custom-path") - if __name__ == "__main__": uvicorn.run(app, host="localhost") From ecff89355da793f76d25903786b0c482057f3ff8 Mon Sep 17 00:00:00 2001 From: Samuel Date: Mon, 23 Sep 2024 14:02:28 -0400 Subject: [PATCH 7/7] chore: fix lint --- .../chat_completion/langchain_stream_adapter.py | 3 ++- .../server_openai_event_adapter.py | 10 ++++++++-- .../test_server_openai_event_adapter.py | 4 +++- .../chat_completion/test_langchain_stream_adapter.py | 5 +++-- 4 files changed, 16 insertions(+), 6 deletions(-) diff --git a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py index 78dd66b..f56577b 100644 --- a/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py +++ b/langchain_openai_api_bridge/chat_completion/langchain_stream_adapter.py @@ -12,6 +12,7 @@ OpenAIChatCompletionChunkObject, ) + class LangchainStreamAdapter: def __init__(self, llm_model: str, system_fingerprint: str = ""): self.llm_model = llm_model @@ -21,7 
+22,7 @@ async def ato_chat_completion_chunk_stream( self, astream_event: AsyncIterator[StreamEvent], id: str = "", - event_adapter = lambda event: None, + event_adapter=lambda event: None, ) -> AsyncIterator[OpenAIChatCompletionChunkObject]: if id == "": id = str(uuid.uuid4()) diff --git a/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py index 75df4ef..fd2e9c8 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py +++ b/tests/test_functional/fastapi_chat_completion_openai/server_openai_event_adapter.py @@ -26,7 +26,8 @@ allow_headers=["*"], expose_headers=["*"], ) - + + def create_agent(dto: CreateAgentDto): return ChatOpenAI( temperature=dto.temperature or 0.7, @@ -35,15 +36,20 @@ def create_agent(dto: CreateAgentDto): api_key=dto.api_key, ) + bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent) + def event_adapter(event): kind = event["event"] match kind: case "on_chat_model_stream": return event -bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", event_adapter=event_adapter) + +bridge.bind_openai_chat_completion( + prefix="/my-custom-events-path", event_adapter=event_adapter +) if __name__ == "__main__": uvicorn.run(app, host="localhost") diff --git a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py index a97ca05..3a9f60d 100644 --- a/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py +++ b/tests/test_functional/fastapi_chat_completion_openai/test_server_openai_event_adapter.py @@ -6,6 +6,7 @@ test_api = TestClient(app) + @pytest.fixture def openai_client_custom_events(): return OpenAI( @@ -13,6 +14,7 @@ def openai_client_custom_events(): http_client=test_api, ) + def test_chat_completion_invoke_custom_events(openai_client_custom_events): chat_completion = openai_client_custom_events.chat.completions.create( model="gpt-4o-mini", @@ -39,4 +41,4 @@ def test_chat_completion_stream_custom_events(openai_client_custom_events): stream_output = "".join(every_content) - assert "This is a test" in stream_output \ No newline at end of file + assert "This is a test" in stream_output diff --git a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py index 1e49c46..1638256 100644 --- a/tests/test_unit/chat_completion/test_langchain_stream_adapter.py +++ b/tests/test_unit/chat_completion/test_langchain_stream_adapter.py @@ -44,7 +44,6 @@ async def test_stream_contains_every_on_chat_model_stream( assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict() assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict() - @pytest.mark.asyncio @patch( "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object", @@ -70,7 +69,9 @@ def event_adapter(event): case "on_chat_model_stream": return event - response_stream = self.instance.ato_chat_completion_chunk_stream(input_stream, event_adapter=event_adapter) + response_stream = self.instance.ato_chat_completion_chunk_stream( + input_stream, event_adapter=event_adapter + ) items = await assemble_stream(response_stream) assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
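
Usage sketch (not part of the patch series): with the API as it stands after patch 7, any astream_events entry that an event_adapter returns is converted into a chat completion chunk, while returning None falls through to the default handling, which forwards only on_chat_model_stream events. The following is a minimal consumer that surfaces tool completions inline with the token stream. It assumes to_openai_chat_completion_chunk_object only reads event["data"]["chunk"].content (as the unit-test stub's side_effect suggests), that the bound agent actually runs tools, and that the "/with-tool-notes" prefix is illustrative.

from fastapi import FastAPI
from langchain_core.messages import AIMessageChunk
from langchain_openai import ChatOpenAI

from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)

app = FastAPI()


def create_agent(dto: CreateAgentDto):
    # Swap in any Runnable that binds tools; a bare chat model never emits
    # on_tool_end events, so the adapter below would simply never fire.
    return ChatOpenAI(
        temperature=dto.temperature or 0.7,
        model=dto.model,
        max_tokens=dto.max_tokens,
        api_key=dto.api_key,
    )


def event_adapter(event):
    # Anything returned here is yielded as a chat-completion chunk; returning
    # None falls back to the default on_chat_model_stream handling.
    if event["event"] == "on_tool_end":
        # Hypothetical mapping: reshape the tool event into the structure the
        # chunk converter expects, so clients see tool completions as deltas.
        return {
            "event": "on_chat_model_stream",
            "data": {"chunk": AIMessageChunk(content=f"[{event['name']} finished] ")},
        }
    return None


bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent)
bridge.bind_openai_chat_completion(prefix="/with-tool-notes", event_adapter=event_adapter)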