Commit

feat: groq support (#27)

samuelint authored Jul 11, 2024
1 parent 94840b8 commit 2d935fa
Showing 11 changed files with 260 additions and 16 deletions.
3 changes: 2 additions & 1 deletion .env.example
@@ -1,3 +1,4 @@
# For functional testing
OPENAI_API_KEY=""
ANTHROPIC_API_KEY=""
GROQ_API_KEY=""
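
For reference, the functional tests load these keys with `python-dotenv`; a minimal sketch of the same pattern:

```python
import os

from dotenv import load_dotenv, find_dotenv

# Load OPENAI_API_KEY / ANTHROPIC_API_KEY / GROQ_API_KEY from .env
# into the process environment.
load_dotenv(find_dotenv())

groq_api_key = os.environ.get("GROQ_API_KEY")
```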
10 changes: 9 additions & 1 deletion README.md
@@ -6,7 +6,9 @@ A `FastAPI` + `Langchain` / `langgraph` extension to expose agent result as an OpenAI-compatible API

Use any OpenAI-compatible UI or UI framework with your custom `Langchain Agent`.

### Support:

#### OpenAI API features:

- [Chat Completions API](https://platform.openai.com/docs/api-reference/chat)
- ✅ Invoke
@@ -19,6 +21,12 @@ Support:
- ✅ Tools step stream
- 🚧 Human In The Loop

#### Vendors features:

- ✅ OpenAI
- ✅ Anthropic
- ✅ Groq (excluding multimodal)

If you find this project useful, please give it a star ⭐!

## Table of Content
19 changes: 19 additions & 0 deletions langchain_openai_api_bridge/chat_model_adapter/default_openai_compatible_chat_model_adapter.py
@@ -0,0 +1,19 @@
from typing import List, Union
from langchain_core.messages import BaseMessage

from langchain_openai_api_bridge.chat_model_adapter.base_openai_compatible_chat_model_adapter import (
    BaseOpenAICompatibleChatModelAdapter,
)


class DefaultOpenAICompatibleChatModelAdapter(BaseOpenAICompatibleChatModelAdapter):
    """Fallback adapter for chat models whose message format already matches OpenAI's."""

    def is_compatible(self, llm_type: str) -> bool:
        # Accepts any llm_type; used when no vendor-specific adapter matches.
        return True

    def to_openai_format_messages(
        self, messages: Union[List[BaseMessage], List[List[BaseMessage]]]
    ):
        # Messages are already in OpenAI format, so pass them through unchanged.
        return messages
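
A quick sketch of the fallback behavior (hypothetical values, assuming the class above is importable):

```python
from langchain_core.messages import HumanMessage

adapter = DefaultOpenAICompatibleChatModelAdapter()

# The fallback adapter accepts any llm_type string and leaves messages untouched.
assert adapter.is_compatible("any-llm-type") is True
messages = [HumanMessage(content="Hello!")]
assert adapter.to_openai_format_messages(messages) is messages
```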
@@ -3,14 +3,21 @@
from langchain_core.language_models import BaseChatModel
from langchain_core.pydantic_v1 import root_validator

from langchain_openai_api_bridge.chat_model_adapter.default_openai_compatible_chat_model_adapter import (
DefaultOpenAICompatibleChatModelAdapter,
)


from .anthropic_openai_compatible_chat_model_adapter import (
AnthropicOpenAICompatibleChatModelAdapter,
)
from langchain_openai_api_bridge.chat_model_adapter.base_openai_compatible_chat_model_adapter import (
BaseOpenAICompatibleChatModelAdapter,
)

default_adapters = [
    AnthropicOpenAICompatibleChatModelAdapter(),
]


class OpenAICompatibleChatModel(BaseChatModel):
@@ -20,12 +27,16 @@ class OpenAICompatibleChatModel(BaseChatModel):

    @root_validator()
    def set_adapter(cls, values):
        adapter = values.get("adapter")

        if adapter is None:
            chat_model = values.get("chat_model")
            adapters = values.get("adapters", default_adapters)
            adapter = OpenAICompatibleChatModel._find_adatper(chat_model, adapters)

        if adapter is None:
            raise ValueError("Could not find an adapter for the given chat model")

        values["adapter"] = adapter

        return values
@@ -43,7 +54,7 @@ def _stream(self, messages, stop, run_manager, **kwargs):
messages=messages, stop=stop, run_manager=run_manager, **kwargs
)

    async def _astream(self, messages, stop, run_manager, **kwargs):
        # Delegate to the wrapped model's stream as a proper async generator;
        # returning the delegate's generator from an async def would hand
        # callers a coroutine that `async for` cannot iterate.
        async for chunk in self.chat_model._astream(
            messages=messages, stop=stop, run_manager=run_manager, **kwargs
        ):
            yield chunk
@@ -74,4 +85,4 @@ def _find_adatper(
if adapter.is_compatible(llm_type):
return adapter

raise ValueError(f"Could not find an adapter for {llm_type}")
return DefaultOpenAICompatibleChatModelAdapter()
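
The net effect, as a hedged sketch: a chat model with no vendor-specific adapter (for example `ChatGroq`) now resolves to the pass-through default instead of raising, assuming the `adapter` field is readable on the model:

```python
from langchain_groq import ChatGroq

# Assumes GROQ_API_KEY is set. No vendor adapter matches ChatGroq's llm_type,
# so the root validator falls back to DefaultOpenAICompatibleChatModelAdapter
# rather than raising a ValueError.
model = OpenAICompatibleChatModel(chat_model=ChatGroq(model="llama3-8b-8192"))
assert isinstance(model.adapter, DefaultOpenAICompatibleChatModelAdapter)
```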
39 changes: 37 additions & 2 deletions poetry.lock

Some generated files are not rendered by default.

4 changes: 3 additions & 1 deletion pyproject.toml
@@ -17,6 +17,7 @@ fastapi = { version = "^0.111.0", optional = true }
python-dotenv = { version = "^1.0.1", optional = true }
langgraph = { version = "^0.0.62", optional = true }
langchain-anthropic = { version = "^0.1.19", optional = true }
langchain-groq = { version = "^0.1.6", optional = true }

[tool.poetry.group.dev.dependencies]
flake8 = "^7.0.0"
@@ -33,7 +34,8 @@ fastapi-injector = { version = "^0.6.0", python = ">=3.9,<3.13" }

[tool.poetry.extras]
langchain = ["langchain", "langchain-openai", "langgraph"]
anthropic = ["langchain-anthropic"]
groq = ["langchain-groq"]
langchain_serve = ["fastapi", "python-dotenv"]


17 changes: 15 additions & 2 deletions tests/test_functional/fastapi_assistant_agent_anthropic/README.md
@@ -2,13 +2,26 @@

The default `from langchain_anthropic import ChatAnthropic` is not compatible with multimodal prompts as the image format differs between OpenAI and Anthropic.

To use multimodal prompts, use the `OpenAICompatibleChatModel` which transforms OpenAI format to Anthropic format. This enables you to use one or the other seamlessly.
See `my_anthropic_agent_factory.py` for a usage example.

```python
from langchain_anthropic import ChatAnthropic
from langchain_openai_api_bridge.chat_model_adapter import OpenAICompatibleChatModel

chat_model = ChatAnthropic(
    model="claude-3-5-sonnet-20240620",
    max_tokens=1024,
    streaming=True,
)

model = OpenAICompatibleChatModel(chat_model=chat_model)
```

#### Multimodal Format Differences

##### Anthropic

https://docs.anthropic.com/en/docs/build-with-claude/vision#about-the-prompt-examples

```python
{
    "role": "user",
    "content": [
        {
            "type": "image",
            "source": {
                "type": "base64",
                "media_type": "image/jpeg",
                "data": "iVBORw0KGgo=",
            },
        },
        {"type": "text", "text": "What is in this image?"},
    ],
}
```
21 changes: 21 additions & 0 deletions tests/test_functional/fastapi_assistant_agent_groq/README.md
@@ -0,0 +1,21 @@
# Expose as an OpenAI API Groq assistant (Assistant API)

:warning: Groq does not support streaming with tools. Make sure to set `streaming=False`.

```python
chat_model = ChatGroq(
    model="llama3-8b-8192",
    streaming=False,  # <<--- Must be set to False when used with LangGraph / Tools
)
```

:warning: Groq models do not currently support multimodal capabilities. Do not send payloads with image references:

```python
{
    "type": "image_url",
    "image_url": {
        "url": "data:image/jpeg;base64,iVBORw0KGgo="
    },
},
```
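
Plain-text content in the standard OpenAI chat format works as expected; a minimal sketch:

```python
{
    "role": "user",
    "content": "Hello!",
}
```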
52 changes: 52 additions & 0 deletions tests/test_functional/fastapi_assistant_agent_groq/assistant_server_groq.py
@@ -0,0 +1,52 @@
from fastapi.middleware.cors import CORSMiddleware
from fastapi import FastAPI
from dotenv import load_dotenv, find_dotenv
import uvicorn

from langchain_openai_api_bridge.assistant import (
    InMemoryMessageRepository,
    InMemoryRunRepository,
    InMemoryThreadRepository,
)
from langchain_openai_api_bridge.fastapi.langchain_openai_api_bridge_fastapi import (
    LangchainOpenaiApiBridgeFastAPI,
)
from tests.test_functional.fastapi_assistant_agent_groq.my_groq_agent_factory import (
    MyGroqAgentFactory,
)

_ = load_dotenv(find_dotenv())


app = FastAPI(
    title="Langchain Agent OpenAI API Bridge",
    version="1.0",
    description="OpenAI API exposing langchain agent",
)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
    expose_headers=["*"],
)

in_memory_thread_repository = InMemoryThreadRepository()
in_memory_message_repository = InMemoryMessageRepository()
in_memory_run_repository = InMemoryRunRepository()

bridge = LangchainOpenaiApiBridgeFastAPI(
    app=app, agent_factory_provider=lambda: MyGroqAgentFactory()
)
bridge.bind_openai_assistant_api(
    thread_repository_provider=in_memory_thread_repository,
    message_repository_provider=in_memory_message_repository,
    run_repository_provider=in_memory_run_repository,
    prefix="/my-groq-assistant",
)


if __name__ == "__main__":
    uvicorn.run(app, host="localhost")
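
For reference, once this server runs under uvicorn (default port 8000), any OpenAI client can target the bound prefix; a minimal sketch, assuming `OPENAI_API_KEY` is set as in the test client below:

```python
from openai import OpenAI

# Point a standard OpenAI client at the assistant API bound above.
client = OpenAI(base_url="http://localhost:8000/my-groq-assistant/openai/v1")

thread = client.beta.threads.create(
    messages=[{"role": "user", "content": "Hello!"}]
)
```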
35 changes: 35 additions & 0 deletions tests/test_functional/fastapi_assistant_agent_groq/my_groq_agent_factory.py
@@ -0,0 +1,35 @@
from langchain_groq import ChatGroq
from langchain_openai_api_bridge.chat_model_adapter import (
    OpenAICompatibleChatModel,
)
from langchain_openai_api_bridge.core.agent_factory import AgentFactory
from langgraph.graph.graph import CompiledGraph
from langchain_core.language_models import BaseChatModel
from langchain_core.tools import tool
from langgraph.prebuilt import create_react_agent

from langchain_openai_api_bridge.core.create_agent_dto import CreateAgentDto


@tool
def magic_number_tool(input: int) -> int:
    """Applies a magic function to an input."""
    return input + 2


class MyGroqAgentFactory(AgentFactory):

    def create_agent(self, llm: BaseChatModel, dto: CreateAgentDto) -> CompiledGraph:
        return create_react_agent(
            llm,
            [magic_number_tool],
            messages_modifier="""You are a helpful assistant.""",
        )

    def create_llm(self, dto: CreateAgentDto) -> BaseChatModel:
        # Groq does not support streaming with tools, so streaming stays off.
        chat_model = ChatGroq(
            model=dto.model,
            streaming=False,
        )

        return OpenAICompatibleChatModel(chat_model=chat_model)
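
For context, `dto.model` carries the model name from the incoming OpenAI request, so the client's `model` parameter selects the Groq model. A hypothetical illustration, assuming `CreateAgentDto` can be built from just a model name:

```python
# Hypothetical: construct the dto directly; the bridge normally builds it
# from the incoming OpenAI request.
dto = CreateAgentDto(model="llama3-8b-8192")
llm = MyGroqAgentFactory().create_llm(dto)  # OpenAICompatibleChatModel wrapping ChatGroq
```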
@@ -0,0 +1,47 @@
import pytest
from openai import OpenAI

from fastapi.testclient import TestClient
from assistant_server_groq import app
from tests.test_functional.assistant_stream_utils import (
    assistant_stream_events_to_str_response,
)


test_api = TestClient(app)


@pytest.fixture(scope="session")
def openai_client():
    return OpenAI(
        base_url="http://testserver/my-groq-assistant/openai/v1",
        http_client=test_api,
    )


class TestGroqAssistant:

    def test_run_stream_message_deltas(
        self,
        openai_client: OpenAI,
    ):
        thread = openai_client.beta.threads.create(
            messages=[
                {
                    "role": "user",
                    "content": "Hello!",
                },
            ]
        )

        stream = openai_client.beta.threads.runs.create(
            thread_id=thread.id,
            model="llama3-8b-8192",
            assistant_id="any",
            stream=True,
            temperature=0,
        )

        str_response = assistant_stream_events_to_str_response(stream)

        assert len(str_response) > 0
