feat(litellm): litellm integration (#28)
* feat: litellm integration

* fix: test coherence litellm

* fix: test coherence litellm

* fix(litellm): fixing some tests

* fix(litellm): fixing some tests
kevdevg authored Dec 27, 2024
1 parent 60c6295 commit c1f56b9
Showing 9 changed files with 765 additions and 5 deletions.
11 changes: 6 additions & 5 deletions README.md
@@ -16,11 +16,12 @@ pip install scope3ai
## Library and SDK support Matrix

| Library/SDK | Text generation | TTS | STT | Image Generation | Translation |
|-------------|-----------------|-----|-----|------------------|-------------|
| Anthropic | ✅ | | | | |
| Cohere (v1) | ✅ | | | | |
| OpenAI | ✅ | | | | |
| Huggingface | ✅ | ✅ | ✅ | ✅ | ✅ |
|-------------|-----------------|----|-----|------------------|-----------|
| Anthropic | ✅ | | | | |
| Cohere (v1) | ✅ | | | | |
| OpenAI | ✅ | | | | |
| Huggingface | ✅ | ✅ | ✅ | ✅ | ✅ |
| LiteLLM | ✅ | | | | |

Roadmap:
- Cohere (client v2)
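As a rough sketch of how this new LiteLLM support might be exercised (assuming the usual `Scope3AI.init()` entry point used by the other integrations; the exact init arguments and the `impact` attribute shown here are assumptions, not confirmed by this diff):

```python
# Hedged usage sketch: init argument names and the `impact` attribute are assumed.
import litellm
from scope3ai import Scope3AI

scope3ai = Scope3AI.init(api_key="YOUR_SCOPE3_KEY")  # hypothetical init arguments

response = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hello World!"}],
)

# The LiteLLM tracer attaches a Scope3AIContext to the returned completion.
if response.scope3ai is not None:
    print(response.scope3ai.impact)  # assumed attribute holding the impact estimate
```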
10 changes: 10 additions & 0 deletions scope3ai/lib.py
@@ -13,6 +13,7 @@
from .api.tracer import Tracer
from .api.types import ImpactRow, ImpactResponse, Scope3AIContext
from .api.defaults import DEFAULT_API_URL

from .worker import BackgroundWorker

logger = logging.getLogger("scope3ai.lib")
@@ -50,11 +51,20 @@ def init_huggingface_hub_instrumentor() -> None:
        instrumentor.instrument()


def init_litellm_instrumentor() -> None:
    if importlib.util.find_spec("litellm") is not None:
        from scope3ai.tracers.litellm.instrument import LiteLLMInstrumentor

        instrumentor = LiteLLMInstrumentor()
        instrumentor.instrument()


_INSTRUMENTS = {
    "anthropic": init_anthropic_instrumentor,
    "cohere": init_cohere_instrumentor,
    "openai": init_openai_instrumentor,
    "huggingface_hub": init_huggingface_hub_instrumentor,
    "litellm": init_litellm_instrumentor,
}


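The `_INSTRUMENTS` mapping keys each optional dependency to its setup function, and each `init_*` function silently no-ops when the package is not installed. A hypothetical sketch of how the library could iterate this registry at init time (the actual `Scope3AI.init()` logic is outside this diff, and `init_instruments` is an illustrative name):

```python
# Hypothetical helper showing how _INSTRUMENTS could be consumed.
def init_instruments(*names: str) -> None:
    for name in names or _INSTRUMENTS.keys():
        init_fn = _INSTRUMENTS.get(name)
        if init_fn is None:
            raise ValueError(f"Unknown instrument: {name}")
        init_fn()  # skips instrumentation if the provider package is absent
```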
153 changes: 153 additions & 0 deletions scope3ai/tracers/litellm/chat.py
@@ -0,0 +1,153 @@
import time
from typing import Any, Callable, Optional, Union

from litellm import AsyncCompletions, Completions
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from scope3ai import Scope3AI
from scope3ai.api.types import Scope3AIContext, Model, ImpactRow


PROVIDER = "litellm"


class ChatCompletion(ModelResponse):
    scope3ai: Optional[Scope3AIContext] = None


class ChatCompletionChunk(ModelResponse):
    scope3ai: Optional[Scope3AIContext] = None


def litellm_chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any
) -> Union[ChatCompletion, CustomStreamWrapper]:
    if kwargs.get("stream", False):
        return litellm_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return litellm_chat_wrapper_non_stream(wrapped, instance, args, kwargs)


def litellm_chat_wrapper_stream(  # type: ignore[misc]
    wrapped: Callable,
    instance: Completions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> CustomStreamWrapper:
    timer_start = time.perf_counter()
    stream = wrapped(*args, **kwargs)
    token_count = 0
    for i, chunk in enumerate(stream):
        if i > 0 and chunk.choices[0].finish_reason is None:
            token_count += 1
        request_latency = time.perf_counter() - timer_start

        model = chunk.model
        if model is not None:
            scope3_row = ImpactRow(
                model=Model(id=model),
                output_tokens=token_count,
                request_duration_ms=float(request_latency) * 1000,
                managed_service_id=PROVIDER,
            )
            scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
            if scope3ai_ctx is not None:
                yield ChatCompletionChunk(**chunk.model_dump(), scope3ai=scope3ai_ctx)
            else:
                yield chunk
        else:
            yield chunk


def litellm_chat_wrapper_non_stream(
    wrapped: Callable,
    instance: Completions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletion:
    timer_start = time.perf_counter()
    response = wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model = response.model
    if model is None:
        return response
    scope3_row = ImpactRow(
        model=Model(id=model),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,  # completion tokens only, not prompt+completion
        request_duration_ms=float(request_latency) * 1000,
        managed_service_id=PROVIDER,
    )
    scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    if scope3ai_ctx is not None:
        return ChatCompletion(**response.model_dump(), scope3ai=scope3ai_ctx)
    else:
        return response


async def litellm_async_chat_wrapper(
    wrapped: Callable, instance: AsyncCompletions, args: Any, kwargs: Any
) -> Union[ChatCompletion, CustomStreamWrapper]:
    if kwargs.get("stream", False):
        return litellm_async_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return await litellm_async_chat_wrapper_base(wrapped, instance, args, kwargs)


async def litellm_async_chat_wrapper_base(
    wrapped: Callable,
    instance: AsyncCompletions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletion:
    timer_start = time.perf_counter()
    response = await wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model = response.model
    if model is None:
        return response
    scope3_row = ImpactRow(
        model=Model(id=model),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,  # completion tokens only, not prompt+completion
        request_duration_ms=float(request_latency) * 1000,
        managed_service_id=PROVIDER,
    )
    scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    if scope3ai_ctx is not None:
        return ChatCompletion(**response.model_dump(), scope3ai=scope3ai_ctx)
    else:
        return response


async def litellm_async_chat_wrapper_stream(  # type: ignore[misc]
    wrapped: Callable,
    instance: AsyncCompletions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> CustomStreamWrapper:
    timer_start = time.perf_counter()
    stream = await wrapped(*args, **kwargs)
    i = 0
    token_count = 0
    async for chunk in stream:
        if i > 0 and chunk.choices[0].finish_reason is None:
            token_count += 1
        request_latency = time.perf_counter() - timer_start
        model = chunk.model
        if model is not None:
            scope3_row = ImpactRow(
                model=Model(id=model),
                output_tokens=token_count,
                request_duration_ms=float(request_latency) * 1000,
                managed_service_id=PROVIDER,
            )
            scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
            if scope3ai_ctx is not None:
                yield ChatCompletionChunk(**chunk.model_dump(), scope3ai=scope3ai_ctx)
            else:
                yield chunk
        else:
            yield chunk
        i += 1
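To illustrate what the streaming wrappers above do from a caller's point of view, a hedged sketch (chunk fields follow LiteLLM's OpenAI-compatible shape; only the `scope3ai` attribute is specific to this tracer, and it is attached only to chunks that carry a model name and a submitted impact):

```python
# Hedged caller sketch for the streaming path, assuming instrumentation is active.
import litellm

stream = litellm.completion(
    model="command-r",
    messages=[{"role": "user", "content": "Hello World!"}],
    stream=True,
)

for chunk in stream:
    # Most chunks pass through unchanged; chunks re-wrapped by the tracer
    # carry a `scope3ai` attribute with the submitted impact context.
    ctx = getattr(chunk, "scope3ai", None)
    if ctx is not None:
        print("impact context attached:", ctx)
```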
29 changes: 29 additions & 0 deletions scope3ai/tracers/litellm/instrument.py
@@ -0,0 +1,29 @@
import litellm
from wrapt import wrap_function_wrapper # type: ignore[import-untyped]

from scope3ai.tracers.litellm.chat import (
    litellm_chat_wrapper,
    litellm_async_chat_wrapper,
)


class LiteLLMInstrumentor:
    def __init__(self) -> None:
        self.wrapped_methods = [
            {
                "module": litellm,
                "name": "completion",
                "wrapper": litellm_chat_wrapper,
            },
            {
                "module": litellm,
                "name": "acompletion",
                "wrapper": litellm_async_chat_wrapper,
            },
        ]

    def instrument(self) -> None:
        for wrapper in self.wrapped_methods:
            wrap_function_wrapper(
                wrapper["module"], wrapper["name"], wrapper["wrapper"]
            )
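The instrumentor uses `wrapt.wrap_function_wrapper` to patch `litellm.completion` and `litellm.acompletion` in place. A small sketch of invoking it directly, outside of `Scope3AI.init()` (which is how the library normally triggers it via `init_litellm_instrumentor`):

```python
# Sketch: manual instrumentation, assuming litellm is installed.
from scope3ai.tracers.litellm.instrument import LiteLLMInstrumentor

# After this call, every litellm.completion / litellm.acompletion invocation
# is timed and its token usage submitted to Scope3AI by the chat wrappers.
LiteLLMInstrumentor().instrument()
```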
71 changes: 71 additions & 0 deletions tests/cassettes/test_litellm_async_chat.yaml
@@ -0,0 +1,71 @@
interactions:
- request:
    body: '{"model": "command-r", "chat_history": [], "message": "Hello World!"}'
    headers:
      Accept-Encoding:
      - gzip, deflate
      Connection:
      - keep-alive
      Content-Length:
      - '69'
      Request-Source:
      - unspecified:litellm
      User-Agent:
      - python-requests/2.32.3
      accept:
      - application/json
      authorization:
      - DUMMY
      content-type:
      - application/json
    method: POST
    uri: https://api.cohere.ai/v1/chat
  response:
    body:
      string: "{\"response_id\":\"e33dc171-063e-4d14-8fa9-1fc92d3f36bc\",\"text\":\"Hello!
        How's it going? I hope you're having a fantastic day today! \U0001F60A\",\"generation_id\":\"15afd773-5ac1-40dd-96df-56537933c248\",\"chat_history\":[{\"role\":\"USER\",\"message\":\"Hello
        World!\"},{\"role\":\"CHATBOT\",\"message\":\"Hello! How's it going? I hope
        you're having a fantastic day today! \U0001F60A\"}],\"finish_reason\":\"COMPLETE\",\"meta\":{\"api_version\":{\"version\":\"1\"},\"billed_units\":{\"input_tokens\":3,\"output_tokens\":19},\"tokens\":{\"input_tokens\":69,\"output_tokens\":19}}}"
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
      Content-Length:
      - '518'
      Via:
      - 1.1 google
      access-control-expose-headers:
      - X-Debug-Trace-ID
      cache-control:
      - no-cache, no-store, no-transform, must-revalidate, private, max-age=0
      content-type:
      - application/json
      date:
      - Fri, 27 Dec 2024 18:35:06 GMT
      expires:
      - Thu, 01 Jan 1970 00:00:00 UTC
      num_chars:
      - '430'
      num_tokens:
      - '22'
      pragma:
      - no-cache
      server:
      - envoy
      vary:
      - Origin
      x-accel-expires:
      - '0'
      x-debug-trace-id:
      - e2debaa1e7de7fe4e007e6e250e616b4
      x-endpoint-monthly-call-limit:
      - '1000'
      x-envoy-upstream-service-time:
      - '202'
      x-trial-endpoint-call-limit:
      - '10'
      x-trial-endpoint-call-remaining:
      - '9'
    status:
      code: 200
      message: OK
version: 1
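This cassette records a Cohere `command-r` call routed through LiteLLM. The test module that replays it is not part of this excerpt, so the following is only a hedged sketch of what such a VCR-backed test might look like (the `tracer_init` fixture name and the pytest markers are assumptions):

```python
# Illustrative test sketch; fixture and marker setup are assumed, not from this diff.
import litellm
import pytest


@pytest.mark.vcr  # replays tests/cassettes/test_litellm_async_chat.yaml
@pytest.mark.asyncio
async def test_litellm_async_chat(tracer_init):  # assumed fixture that calls Scope3AI.init()
    response = await litellm.acompletion(
        model="command-r",
        messages=[{"role": "user", "content": "Hello World!"}],
    )
    assert len(response.choices) > 0
    assert response.scope3ai is not None  # context attached by the LiteLLM tracer
```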