Skip to content

Commit

Permalink
feat: third party injector lib example
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelint committed Jul 2, 2024
1 parent d248af4 commit 04c593d
Show file tree
Hide file tree
Showing 5 changed files with 175 additions and 0 deletions.
5 changes: 5 additions & 0 deletions langchain_openai_api_bridge/core/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
from .agent_factory import AgentFactory

__all__ = [
"AgentFactory",
]
22 changes: 22 additions & 0 deletions tests/test_functional/injector/app_module.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
from injector import Binder, Module, singleton

from langchain_openai_api_bridge.assistant import (
ThreadRepository,
MessageRepository,
RunRepository,
InMemoryThreadRepository,
InMemoryMessageRepository,
InMemoryRunRepository,
)
from langchain_openai_api_bridge.core import AgentFactory
from tests.test_functional.injector.with_injector_my_agent_factory import (
WithInjectorMyAgentFactory,
)


class MyAppModule(Module):
def configure(self, binder: Binder):
binder.bind(ThreadRepository, to=InMemoryThreadRepository, scope=singleton)
binder.bind(MessageRepository, to=InMemoryMessageRepository, scope=singleton)
binder.bind(RunRepository, to=InMemoryRunRepository, scope=singleton)
binder.bind(AgentFactory, to=WithInjectorMyAgentFactory)
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
from typing import List
import pytest
from openai import OpenAI
from openai.types.beta import AssistantStreamEvent, Thread

from fastapi.testclient import TestClient
from with_injector_assistant_server_openai import app
from tests.test_functional.assistant_stream_utils import (
assistant_stream_events_to_str_response,
)


test_api = TestClient(app)


@pytest.fixture(scope="session")
def openai_client():
return OpenAI(
base_url="http://testserver/my-assistant/openai/v1",
http_client=test_api,
)


class TestFollowupMessage:

@pytest.fixture(scope="session")
def thread(self, openai_client: OpenAI) -> Thread:
return openai_client.beta.threads.create(
messages=[
{
"role": "user",
"content": "Remember that my favorite fruit is banana. I Like bananas.",
},
]
)

def test_run_stream_starts_with_thread_run_created(
self, openai_client: OpenAI, thread: Thread
):
openai_client.beta.threads.runs.create(
thread_id=thread.id,
model="gpt-3.5-turbo",
assistant_id="any",
temperature=0,
stream=True,
)

openai_client.beta.threads.messages.create(
thread_id=thread.id, role="user", content="What is my favority fruit?"
)

stream_2 = openai_client.beta.threads.runs.create(
thread_id=thread.id,
temperature=0,
model="gpt-3.5-turbo",
assistant_id="any",
stream=True,
)

events_2: List[AssistantStreamEvent] = []
for event in stream_2:
events_2.append(event)

followup_response = assistant_stream_events_to_str_response(events_2)

assert "banana" in followup_response
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from dotenv import load_dotenv, find_dotenv
import uvicorn
from injector import Injector

from langchain_openai_api_bridge.assistant import (
ThreadRepository,
MessageRepository,
RunRepository,
)
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langchain_openai_api_bridge.fastapi import (
LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.injector.app_module import MyAppModule


_ = load_dotenv(find_dotenv())


app = FastAPI(
title="Langchain Agent OpenAI API Bridge",
version="1.0",
description="OpenAI API exposing langchain agent",
)

app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
expose_headers=["*"],
)

injector = Injector([MyAppModule()])

bridge = LangchainOpenaiApiBridgeFastAPI(
app=app, agent_factory_provider=lambda: injector.get(AgentFactory)
)
bridge.bind_openai_assistant_api(
thread_repository_provider=lambda: injector.get(ThreadRepository),
message_repository_provider=lambda: injector.get(MessageRepository),
run_repository_provider=lambda: injector.get(RunRepository),
prefix="/my-assistant",
)

if __name__ == "__main__":
uvicorn.run(app, host="localhost")
32 changes: 32 additions & 0 deletions tests/test_functional/injector/with_injector_my_agent_factory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langgraph.graph.graph import CompiledGraph
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent
from langchain_openai import ChatOpenAI

from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


@tool
def magic_number_tool(input: int) -> int:
"""Applies a magic function to an input."""
return input + 2


class WithInjectorMyAgentFactory(AgentFactory):

def create_agent(self, llm: BaseChatModel, dto: CreateAgentDto) -> CompiledGraph:
return create_react_agent(
llm,
[magic_number_tool],
messages_modifier="""You are a helpful assistant.""",
)

def create_llm(self, dto: CreateAgentDto) -> CompiledGraph:
return ChatOpenAI(
model=dto.model,
api_key=dto.api_key,
streaming=True,
temperature=dto.temperature,
)

0 comments on commit 04c593d

Please sign in to comment.