First draft of custom event handler support #42

Merged · 7 commits · Sep 23, 2024
Changes from 2 commits
@@ -15,23 +15,29 @@ class ChatCompletionCompatibleAPI:

     @staticmethod
     def from_agent(
-        agent: Runnable, llm_model: str, system_fingerprint: Optional[str] = ""
+        agent: Runnable,
+        llm_model: str,
+        system_fingerprint: Optional[str] = "",
+        custom_event_handler: callable = lambda event: None,
     ):
         return ChatCompletionCompatibleAPI(
             LangchainStreamAdapter(llm_model, system_fingerprint),
             LangchainInvokeAdapter(llm_model, system_fingerprint),
             agent,
+            custom_event_handler,
         )

     def __init__(
         self,
         stream_adapter: LangchainStreamAdapter,
         invoke_adapter: LangchainInvokeAdapter,
         agent: Runnable,
+        custom_event_handler: callable = lambda event: None,
     ) -> None:
         self.stream_adapter = stream_adapter
         self.invoke_adapter = invoke_adapter
         self.agent = agent
+        self.custom_event_handler = custom_event_handler
 
     def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
         input = self.__to_input(messages)
@@ -40,7 +46,7 @@ def astream(self, messages: List[OpenAIChatMessage]) -> AsyncIterator[dict]:
version="v2",
)
return ato_dict(
self.stream_adapter.ato_chat_completion_chunk_stream(astream_event)
self.stream_adapter.ato_chat_completion_chunk_stream(astream_event, custom_event_handler=self.custom_event_handler)
)

def invoke(self, messages: List[OpenAIChatMessage]) -> dict:
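Note: after this change, `from_agent` threads the handler through to the stream adapter. A minimal wiring sketch — `my_agent` and `my_handler` are illustrative assumptions, not code from this diff:

```python
# Sketch only: assumes `my_agent` is any LangChain Runnable and that
# ChatCompletionCompatibleAPI is imported as in this PR.
def my_handler(event):
    return None  # returning None defers every event to the default stream handling

api = ChatCompletionCompatibleAPI.from_agent(
    my_agent,
    llm_model="gpt-4o-mini",
    custom_event_handler=my_handler,
)
```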
@@ -12,7 +12,6 @@
     OpenAIChatCompletionChunkObject,
 )
-
 
 class LangchainStreamAdapter:
     def __init__(self, llm_model: str, system_fingerprint: str = ""):
         self.llm_model = llm_model
@@ -22,20 +21,29 @@ async def ato_chat_completion_chunk_stream(
         self,
         astream_event: AsyncIterator[StreamEvent],
         id: str = "",
+        custom_event_handler = lambda event: None,
     ) -> AsyncIterator[OpenAIChatCompletionChunkObject]:
         if id == "":
             id = str(uuid.uuid4())
         async for event in astream_event:
-            kind = event["event"]
-            match kind:
-                case "on_chat_model_stream":
-                    chunk = to_openai_chat_completion_chunk_object(
-                        event=event,
-                        id=id,
-                        model=self.llm_model,
-                        system_fingerprint=self.system_fingerprint,
-                    )
-                    yield chunk
+            custom_event = custom_event_handler(event)
+            if custom_event is not None:
+                yield to_openai_chat_completion_chunk_object(
+                    event=custom_event,
+                    id=id,
+                    model=self.llm_model,
+                    system_fingerprint=self.system_fingerprint,
+                )
+            else:
+                kind = event["event"]
+                match kind:
+                    case "on_chat_model_stream":
+                        yield to_openai_chat_completion_chunk_object(
+                            event=event,
+                            id=id,
+                            model=self.llm_model,
+                            system_fingerprint=self.system_fingerprint,
+                        )

         stop_chunk = create_final_chat_completion_chunk_object(
             id=id, model=self.llm_model
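Note on the contract introduced above: if `custom_event_handler` returns an event, that event is converted into an OpenAI chunk; if it returns None, the default `on_chat_model_stream` branch runs. One sketch of a handler that forwards LangChain custom events (emitted via `adispatch_custom_event`) while leaving token streaming to the default branch — the `AIMessageChunk` rewrapping is an assumption about the shape the chunk converter expects, not part of this PR:

```python
from langchain_core.messages import AIMessageChunk

def forward_custom_events(event):
    # Forward custom events as visible chunks; everything else falls
    # through to the adapter's default on_chat_model_stream handling.
    if event["event"] == "on_custom_event":
        return {**event, "data": {"chunk": AIMessageChunk(content=str(event["data"]))}}
    return None
```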
2 changes: 1 addition & 1 deletion langchain_openai_api_bridge/core/base_agent_factory.py
@@ -4,7 +4,7 @@


 class BaseAgentFactory(ABC):
-
+    @abstractmethod
     def create_agent(self, dto: CreateAgentDto) -> Runnable:
         pass
13 changes: 8 additions & 5 deletions langchain_openai_api_bridge/fastapi/chat_completion_router.py
@@ -16,10 +16,11 @@

 def create_chat_completion_router(
     tiny_di_container: TinyDIContainer,
+    custom_event_handler: callable = lambda event: None,
 ):
-    chat_completion_router = APIRouter(prefix="/chat/completions")
+    chat_completion_router = APIRouter(prefix="/chat")
 
-    @chat_completion_router.post("/")
+    @chat_completion_router.post("/completions")
     async def assistant_retreive_thread_messages(
         request: OpenAIChatCompletionRequest, authorization: str = Header(None)
     ):
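Note: the two-line route change above is path-preserving — `APIRouter(prefix="/chat")` combined with `@post("/completions")` still serves POST `{prefix}/openai/v1/chat/completions`; the completions segment simply moves from the router prefix onto the route decorator.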
@@ -33,7 +34,7 @@ async def assistant_retreive_thread_messages(

         agent = agent_factory.create_agent(dto=create_agent_dto)
 
-        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model)
+        adapter = ChatCompletionCompatibleAPI.from_agent(agent, create_agent_dto.model, custom_event_handler=custom_event_handler)
 
         response_factory = HttpStreamResponseAdapter()
         if request.stream is True:
@@ -46,9 +47,11 @@


 def create_openai_chat_completion_router(
-    tiny_di_container: TinyDIContainer, prefix: str = ""
+    tiny_di_container: TinyDIContainer,
+    prefix: str = "",
+    custom_event_handler: callable = lambda event: None,
 ):
-    router = create_chat_completion_router(tiny_di_container=tiny_di_container)
+    router = create_chat_completion_router(tiny_di_container=tiny_di_container, custom_event_handler=custom_event_handler)
     open_ai_router = APIRouter(prefix=f"{prefix}/openai/v1")
     open_ai_router.include_router(router)
@@ -97,9 +97,9 @@ def bind_openai_assistant_api(

         self.app.include_router(assistant_router)
 
-    def bind_openai_chat_completion(self, prefix: str = "") -> None:
+    def bind_openai_chat_completion(self, prefix: str = "", custom_event_handler: callable = lambda event: None) -> None:
         chat_completion_router = create_openai_chat_completion_router(
-            self.tiny_di_container, prefix=prefix
+            self.tiny_di_container, prefix=prefix, custom_event_handler=custom_event_handler
         )
 
         self.app.include_router(chat_completion_router)
@@ -4,6 +4,7 @@
 import uvicorn
 
 from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto
+from langchain_openai_api_bridge.core.base_agent_factory import BaseAgentFactory
 from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
     LangchainOpenaiApiBridgeFastAPI,
 )
@@ -27,18 +28,25 @@
expose_headers=["*"],
)

class AgentFactory(BaseAgentFactory):
def create_agent(self, dto: CreateAgentDto):
return ChatOpenAI(
temperature=dto.temperature or 0.7,
model=dto.model,
max_tokens=dto.max_tokens,
api_key=dto.api_key,
)

def create_agent(dto: CreateAgentDto):
return ChatOpenAI(
temperature=dto.temperature or 0.7,
model=dto.model,
max_tokens=dto.max_tokens,
api_key=dto.api_key,
)
bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=AgentFactory())
bridge.bind_openai_chat_completion(prefix="/my-custom-path")

def custom_event_handler(event):
etburke marked this conversation as resolved.
Show resolved Hide resolved
kind = event["event"]
match kind:
case "on_chat_model_stream":
return event

bridge = LangchainOpenaiApiBridgeFastAPI(app=app, agent_factory_provider=create_agent)
bridge.bind_openai_chat_completion(prefix="/my-custom-path")
bridge.bind_openai_chat_completion(prefix="/my-custom-events-path", custom_event_handler=custom_event_handler)

if __name__ == "__main__":
uvicorn.run(app, host="localhost")
@@ -42,3 +42,38 @@ def test_chat_completion_stream(openai_client):
stream_output = "".join(every_content)

assert "This is a test" in stream_output

@pytest.fixture
def openai_client_custom_events():
return OpenAI(
base_url="http://testserver/my-custom-events-path/openai/v1",
http_client=test_api,
)

def test_chat_completion_invoke_custom_events(openai_client_custom_events):
chat_completion = openai_client_custom_events.chat.completions.create(
model="gpt-4o-mini",
messages=[
{
"role": "user",
"content": 'Say "This is a test"',
}
],
)
assert "This is a test" in chat_completion.choices[0].message.content


def test_chat_completion_stream_custom_events(openai_client_custom_events):
chunks = openai_client_custom_events.chat.completions.create(
model="gpt-4o-mini",
messages=[{"role": "user", "content": 'Say "This is a test"'}],
stream=True,
)
every_content = []
for chunk in chunks:
if chunk.choices and isinstance(chunk.choices[0].delta.content, str):
every_content.append(chunk.choices[0].delta.content)

stream_output = "".join(every_content)

assert "This is a test" in stream_output
32 changes: 32 additions & 0 deletions tests/test_unit/chat_completion/test_langchain_stream_adapter.py
@@ -43,3 +43,35 @@ async def test_stream_contains_every_on_chat_model_stream(
         items = await assemble_stream(response_stream)
         assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
         assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()
+
+
+    @pytest.mark.asyncio
+    @patch(
+        "langchain_openai_api_bridge.chat_completion.langchain_stream_adapter.to_openai_chat_completion_chunk_object",
+        side_effect=lambda event, id, model, system_fingerprint: (
+            ChatCompletionChunkStub({"key": event["data"]["chunk"].content})
+        ),
+    )
+    async def test_stream_contains_every_custom_handled_stream(
+        self, to_openai_chat_completion_chunk_object
+    ):
+        on_chat_model_stream_event1 = create_on_chat_model_stream_event(content="hello")
+        on_chat_model_stream_event2 = create_on_chat_model_stream_event(content="moto")
+        input_stream = generate_stream(
+            [
+                on_chat_model_stream_event1,
+                on_chat_model_stream_event2,
+            ]
+        )
+
+        def custom_event_handler(event):
+            kind = event["event"]
+            match kind:
+                case "on_chat_model_stream":
+                    return event
+
+        response_stream = self.instance.ato_chat_completion_chunk_stream(input_stream, custom_event_handler=custom_event_handler)
+
+        items = await assemble_stream(response_stream)
+        assert items[0].dict() == ChatCompletionChunkStub({"key": "hello"}).dict()
+        assert items[1].dict() == ChatCompletionChunkStub({"key": "moto"}).dict()