feat: third party injector lib example
Showing 5 changed files with 175 additions and 0 deletions.
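
The example's core mechanism is the third-party injector library: a Module declares bindings, and Injector.get() resolves them, so the bridge can be handed plain provider lambdas instead of hard-wired instances. For readers new to that library, here is a minimal standalone sketch of the binding pattern; the Greeter classes are invented purely for illustration, while the commit's files below bind the bridge's real ThreadRepository, MessageRepository, RunRepository, and AgentFactory.

from injector import Binder, Injector, Module, singleton


class Greeter:
    def greet(self) -> str:
        raise NotImplementedError


class EnglishGreeter(Greeter):
    def greet(self) -> str:
        return "hello"


class GreeterModule(Module):
    def configure(self, binder: Binder) -> None:
        # Every Greeter lookup is served by one shared EnglishGreeter instance.
        binder.bind(Greeter, to=EnglishGreeter, scope=singleton)


injector = Injector([GreeterModule()])
assert injector.get(Greeter).greet() == "hello"
# singleton scope: repeated lookups return the same object.
assert injector.get(Greeter) is injector.get(Greeter)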
@@ -0,0 +1,5 @@
from .agent_factory import AgentFactory

__all__ = [
    "AgentFactory",
]
tests/test_functional/injector/app_module.py (22 additions, 0 deletions)
@@ -0,0 +1,22 @@
from injector import Binder, Module, singleton

from langchain_openai_api_bridge.assistant import (
    ThreadRepository,
    MessageRepository,
    RunRepository,
    InMemoryThreadRepository,
    InMemoryMessageRepository,
    InMemoryRunRepository,
)
from langchain_openai_api_bridge.core import AgentFactory
from tests.test_functional.injector.with_injector_my_agent_factory import (
    WithInjectorMyAgentFactory,
)


class MyAppModule(Module):
    def configure(self, binder: Binder):
        binder.bind(ThreadRepository, to=InMemoryThreadRepository, scope=singleton)
        binder.bind(MessageRepository, to=InMemoryMessageRepository, scope=singleton)
        binder.bind(RunRepository, to=InMemoryRunRepository, scope=singleton)
        binder.bind(AgentFactory, to=WithInjectorMyAgentFactory)
tests/test_functional/injector/test_with_injector_assistant_server_openai.py (66 additions, 0 deletions)
@@ -0,0 +1,66 @@
from typing import List
import pytest
from openai import OpenAI
from openai.types.beta import AssistantStreamEvent, Thread

from fastapi.testclient import TestClient
from with_injector_assistant_server_openai import app
from tests.test_functional.assistant_stream_utils import (
    assistant_stream_events_to_str_response,
)


test_api = TestClient(app)


@pytest.fixture(scope="session")
def openai_client():
    return OpenAI(
        base_url="http://testserver/my-assistant/openai/v1",
        http_client=test_api,
    )


class TestFollowupMessage:

    @pytest.fixture(scope="session")
    def thread(self, openai_client: OpenAI) -> Thread:
        return openai_client.beta.threads.create(
            messages=[
                {
                    "role": "user",
                    "content": "Remember that my favorite fruit is banana. I like bananas.",
                },
            ]
        )

    def test_followup_message_remembers_favorite_fruit(
        self, openai_client: OpenAI, thread: Thread
    ):
        openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            model="gpt-3.5-turbo",
            assistant_id="any",
            temperature=0,
            stream=True,
        )

        openai_client.beta.threads.messages.create(
            thread_id=thread.id, role="user", content="What is my favorite fruit?"
        )

        stream_2 = openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            temperature=0,
            model="gpt-3.5-turbo",
            assistant_id="any",
            stream=True,
        )

        events_2: List[AssistantStreamEvent] = []
        for event in stream_2:
            events_2.append(event)

        followup_response = assistant_stream_events_to_str_response(events_2)

        assert "banana" in followup_response
tests/test_functional/injector/with_injector_assistant_server_openai.py (50 additions, 0 deletions)
@@ -0,0 +1,50 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from dotenv import load_dotenv, find_dotenv
import uvicorn
from injector import Injector

from langchain_openai_api_bridge.assistant import (
    ThreadRepository,
    MessageRepository,
    RunRepository,
)
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langchain_openai_api_bridge.fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.injector.app_module import MyAppModule


_ = load_dotenv(find_dotenv())


app = FastAPI(
    title="Langchain Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langchain agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

injector = Injector([MyAppModule()])

bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: injector.get(AgentFactory)
)
bridge.bind_openai_assistant_api(
    thread_repository_provider=lambda: injector.get(ThreadRepository),
    message_repository_provider=lambda: injector.get(MessageRepository),
    run_repository_provider=lambda: injector.get(RunRepository),
    prefix="/my-assistant",
)

if __name__ == "__main__":
    uvicorn.run(app, host="localhost")
tests/test_functional/injector/with_injector_my_agent_factory.py (32 additions, 0 deletions)
@@ -0,0 +1,32 @@
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langgraph.graph.graph import CompiledGraph
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI

from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


@tool
def magic_number_tool(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


class WithInjectorMyAgentFactory(AgentFactory):

    def create_agent(self, llm: BaseChatModel, dto: CreateAgentDto) -> CompiledGraph:
        return create_react_agent(
            llm,
            [magic_number_tool],
            messages_modifier="""You are a helpful assistant.""",
        )

    def create_llm(self, dto: CreateAgentDto) -> BaseChatModel:
        return ChatOpenAI(
            model=dto.model,
            api_key=dto.api_key,
            streaming=True,
            temperature=dto.temperature,
        )