First draft of custom event handler support (#42)
etburke authored Sep 23, 2024
1 parent d33a5d8 commit 2b20f09
Showing 7 changed files with 161 additions and 19 deletions.
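At a glance: this commit threads an optional `event_adapter` callable from `bind_openai_chat_completion` through `create_openai_chat_completion_router`, `create_chat_completion_router`, and `ChatCompletionCompatibleAPI.from_agent` into `LangchainStreamAdapter.ato_chat_completion_chunk_stream`, where it can rewrite raw LangChain stream events before they are converted to OpenAI chat-completion chunks. It also reworks route registration from an APIRouter prefixed `/chat/completions` with `POST /` to a `/chat` prefix with `POST /completions`, keeping the effective path `/chat/completions`. The adapter contract, as inferred from the hunks below (a sketch, not documented API):

```python
from typing import Optional

# A pass-through adapter equivalent to the examples in this commit: return a
# replacement event to emit as a chunk, or None to keep the default behavior
# of forwarding only "on_chat_model_stream" events.
def passthrough_event_adapter(event: dict) -> Optional[dict]:
    if event["event"] == "on_chat_model_stream":
        return event
    return None
```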
@@ -15,23 +15,29 @@ class ChatCompletionCompatibleAPI:

    @staticmethod
    def from_agent(
-        agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = ""
+        agent: Runnable,
+        llm_model: str,
+        system_fingerprint: Optional[str] = "",
+        event_adapter: callable = lambda event: None,
    ):
        return ChatCompletionCompatibleAPI(
            LangchainStreamAdapter(llm_model, system_fingerprint),
            LangchainInvokeAdapter(llm_model, system_fingerprint),
            agent,
+            event_adapter,
        )

    def __init__(
        self,
        stream_adapter: LangchainStreamAdapter,
        invoke_adapter: LangchainInvokeAdapter,
        agent: Runnable,
+        event_adapter: callable = lambda event: None,
    ) -> None:
        self.stream_adapter = stream_adapter
        self.invoke_adapter = invoke_adapter
        self.agent = agent
+        self.event_adapter = event_adapter

    def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
        input = self.__to_input(messages)
@@ -40,7 +46,7 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
            version="v2",
        )
        return ato_dict(
-            self.stream_adapter.ato_chat_completion_chunk_stream(astream_event)
+            self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, event_adapter=self.event_adapter)
        )

    def invoke(self, messages: List[OpenAIChatMessage]) -> dict:
@@ -22,20 +22,21 @@ async def ato_chat_completion_chunk_stream(
        self,
        astream_event: AsyncIterator[StreamEvent],
        id: str = "",
+        event_adapter=lambda event: None,
    ) -> AsyncIterator[OpenAIChatCompletionChunkObject]:
        if id == "":
            id = str(uuid.uuid4())
        async for event in astream_event:
-            kind = event["event"]
-            match kind:
-                case "on_chat_model_stream":
-                    chunk = to_openai_chat_completion_chunk_object(
-                        event=event,
-                        id=id,
-                        model=self.llm_model,
-                        system_fingerprint=self.system_fingerprint,
-                    )
-                    yield chunk
+            custom_event = event_adapter(event)
+            event_to_process = custom_event if custom_event is not None else event
+            kind = event_to_process["event"]
+            if kind == "on_chat_model_stream" or custom_event is not None:
+                yield to_openai_chat_completion_chunk_object(
+                    event=event_to_process,
+                    id=id,
+                    model=self.llm_model,
+                    system_fingerprint=self.system_fingerprint,
+                )

        stop_chunk = create_final_chat_completion_chunk_object(
            id=id, model=self.llm_model
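In the rewritten loop above, each raw event is first offered to `event_adapter`; a non-None return value replaces the event and is always emitted, while None falls back to emitting only `on_chat_model_stream` events. An adapter can therefore promote other event kinds into the chunk stream. A minimal sketch of such an adapter, assuming the v2 `astream_events` payload shapes (`on_tool_end` carries its result in `data["output"]`) and that the chunk converter reads only `data["chunk"].content` from the event:

```python
from langchain_core.messages import AIMessageChunk

def tool_output_event_adapter(event):
    kind = event["event"]
    if kind == "on_chat_model_stream":
        return event  # pass model tokens through unchanged
    if kind == "on_tool_end":
        # Re-shape the tool result so the bridge streams it as assistant text.
        content = str(event["data"]["output"])
        return {
            "event": "on_chat_model_stream",
            "data": {"chunk": AIMessageChunk(content=content)},
        }
    return None  # other events fall back to the default filter
```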
13 changes: 8 additions & 5 deletions langchain_openai_api_bridge/fastapi/chat_completion_router.py
@@ -16,10 +16,11 @@

def create_chat_completion_router(
    tiny_di_container: TinyDIContainer,
+    event_adapter: callable = lambda event: None,
):
-    chat_completion_router = APIRouter(prefix="/chat/completions")
+    chat_completion_router = APIRouter(prefix="/chat")

-    @chat_completion_router.post("/")
+    @chat_completion_router.post("/completions")
    async def assistant_retreive_thread_messages(
        request: OpenAIChatCompletionRequest, authorization: str = Header(None)
    ):
@@ -33,7 +34,7 @@ async def assistant_retreive_thread_messages(

        agent = agent_factory.create_agent(dto=create_agent_dto)

-        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model)
+        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, event_adapter=event_adapter)

        response_factory = HttpStreamResponseAdapter()
        if request.stream is True:
@@ -46,9 +47,11 @@


def create_openai_chat_completion_router(
-    tiny_di_container: TinyDIContainer, prefix: str = ""
+    tiny_di_container: TinyDIContainer,
+    prefix: str = "",
+    event_adapter: callable = lambda event: None,
):
-    router = create_chat_completion_router(tiny_di_container=tiny_di_container)
+    router = create_chat_completion_router(tiny_di_container=tiny_di_container, event_adapter=event_adapter)
    open_ai_router = APIRouter(prefix=f"{prefix}/openai/v1")
    open_ai_router.include_router(router)

@@ -97,9 +97,9 @@ def bind_openai_assistant_api(

        self.app.include_router(assistant_router)

-    def bind_openai_chat_completion(self, prefix: str = "") -> None:
+    def bind_openai_chat_completion(self, prefix: str = "", event_adapter: callable = lambda event: None) -> None:
        chat_completion_router = create_openai_chat_completion_router(
-            self.tiny_di_container, prefix=prefix
+            self.tiny_di_container, prefix=prefix, event_adapter=event_adapter
        )

        self.app.include_router(chat_completion_router)
@@ -0,0 +1,55 @@
+from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+from dotenv import load_dotenv, find_dotenv
+import uvicorn
+
+from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
+from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
+    LangchainOpenaiApiBridgeFastAPI,
+)
+from langchain_openai import ChatOpenAI
+
+_ = load_dotenv(find_dotenv())
+
+
+app = FastAPI(
+    title="Langchain Agent OpenAI API Bridge",
+    version="1.0",
+    description="OpenAI API exposing langchain agent",
+)
+
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=["*"],
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+    expose_headers=["*"],
+)
+
+
+def create_agent(dto: CreateAgentDto):
+    return ChatOpenAI(
+        temperature=dto.temperature or 0.7,
+        model=dto.model,
+        max_tokens=dto.max_tokens,
+        api_key=dto.api_key,
+    )
+
+
+bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent)
+
+
+def event_adapter(event):
+    kind = event["event"]
+    match kind:
+        case "on_chat_model_stream":
+            return event
+
+
+bridge.bind_openai_chat_completion(
+    prefix="/my-custom-events-path", event_adapter=event_adapter
+)
+
+if __name__ == "__main__":
+    uvicorn.run(app, host="localhost")
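With this example server running, any OpenAI-compatible client can call the bridged route. A quick sketch, assuming uvicorn's default port 8000; the bearer token sent by the client is what reaches `create_agent` as `dto.api_key`:

```python
from openai import OpenAI

# Base URL is {prefix}/openai/v1, matching the prefix bound above.
client = OpenAI(
    base_url="http://localhost:8000/my-custom-events-path/openai/v1",
    api_key="sk-...",  # consumed by the underlying ChatOpenAI agent
)

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,
)
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="")
```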
@@ -0,0 +1,44 @@
+import pytest
+from openai import OpenAI
+from fastapi.testclient import TestClient
+from server_openai_event_adapter import app
+
+
+test_api = TestClient(app)
+
+
+@pytest.fixture
+def openai_client_custom_events():
+    return OpenAI(
+        base_url="http://testserver/my-custom-events-path/openai/v1",
+        http_client=test_api,
+    )
+
+
+def test_chat_completion_invoke_custom_events(openai_client_custom_events):
+    chat_completion = openai_client_custom_events.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[
+            {
+                "role": "user",
+                "content": 'Say "This is a test"',
+            }
+        ],
+    )
+    assert "This is a test" in chat_completion.choices[0].message.content
+
+
+def test_chat_completion_stream_custom_events(openai_client_custom_events):
+    chunks = openai_client_custom_events.chat.completions.create(
+        model="gpt-4o-mini",
+        messages=[{"role": "user", "content": 'Say "This is a test"'}],
+        stream=True,
+    )
+    every_content = []
+    for chunk in chunks:
+        if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
+            every_content.append(chunk.choices[0].delta.content)
+
+    stream_output = "".join(every_content)
+
+    assert "This is a test" in stream_output
33 changes: 33 additions & 0 deletions tests/test_unit/chat_completion/test_langchain_stream_adapter.py
@@ -43,3 +43,36 @@ async def test_stream_contains_every_on_chat_model_stream(
        items = await assemble_stream(response_stream)
        assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
        assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
+
+    @pytest.mark.asyncio
+    @patch(
+        "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
+        side_effect=lambda event, id, model, system_fingerprint: (
+            ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
+        ),
+    )
+    async def test_stream_contains_every_custom_handled_stream(
+        self, to_openai_chat_completion_chunk_object
+    ):
+        on_chat_model_stream_event1 = create_on_chat_model_stream_event(content="hello")
+        on_chat_model_stream_event2 = create_on_chat_model_stream_event(content="moto")
+        input_stream = generate_stream(
+            [
+                on_chat_model_stream_event1,
+                on_chat_model_stream_event2,
+            ]
+        )
+
+        def event_adapter(event):
+            kind = event["event"]
+            match kind:
+                case "on_chat_model_stream":
+                    return event
+
+        response_stream = self.instance.ato_chat_completion_chunk_stream(
+            input_stream, event_adapter=event_adapter
+        )
+
+        items = await assemble_stream(response_stream)
+        assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
+        assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
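The added test covers only the pass-through case, where the adapter returns `on_chat_model_stream` events unchanged. A companion method sketch for the same test class (not part of this diff; the `on_my_custom_event` kind and `SimpleNamespace` payload are illustrative) would exercise the promotion path, in which a rewritten event is yielded even though its raw kind is not `on_chat_model_stream`:

```python
    # Add at class scope, next to the tests above.
    # Requires: from types import SimpleNamespace
    @pytest.mark.asyncio
    @patch(
        "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
        side_effect=lambda event, id, model, system_fingerprint: (
            ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
        ),
    )
    async def test_stream_yields_events_promoted_by_the_adapter(
        self, to_openai_chat_completion_chunk_object
    ):
        # Hypothetical custom event; with the converter mocked, only
        # "event" and data["chunk"].content are read.
        custom = {
            "event": "on_my_custom_event",
            "data": {"chunk": SimpleNamespace(content="custom")},
        }
        input_stream = generate_stream([custom])

        def event_adapter(event):
            if event["event"] == "on_my_custom_event":
                return {**event, "event": "on_chat_model_stream"}

        response_stream = self.instance.ato_chat_completion_chunk_stream(
            input_stream, event_adapter=event_adapter
        )

        items = await assemble_stream(response_stream)
        assert items[0].dict() == ChatCompletionChunkStub({"key": "custom"}).dict()
```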
