feat(litellm): litellm integration (#28)
* feat: litellm integration
* fix: test coherence litellm
* fix: test coherence litellm
* fix(litellm): fixing some tests
* fix(litellm): fixing some tests
Showing 9 changed files with 765 additions and 5 deletions.
@@ -0,0 +1,153 @@
import time
from typing import Any, Callable, Optional, Union

from litellm import AsyncCompletions, Completions
from litellm.types.utils import ModelResponse
from litellm.utils import CustomStreamWrapper

from scope3ai import Scope3AI
from scope3ai.api.types import Scope3AIContext, Model, ImpactRow


PROVIDER = "litellm"


class ChatCompletion(ModelResponse):
    scope3ai: Optional[Scope3AIContext] = None


class ChatCompletionChunk(ModelResponse):
    scope3ai: Optional[Scope3AIContext] = None


def litellm_chat_wrapper(
    wrapped: Callable, instance: Completions, args: Any, kwargs: Any
) -> Union[ChatCompletion, CustomStreamWrapper]:
    if kwargs.get("stream", False):
        return litellm_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return litellm_chat_wrapper_non_stream(wrapped, instance, args, kwargs)


def litellm_chat_wrapper_stream(  # type: ignore[misc]
    wrapped: Callable,
    instance: Completions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> CustomStreamWrapper:
    timer_start = time.perf_counter()
    stream = wrapped(*args, **kwargs)
    token_count = 0
    for i, chunk in enumerate(stream):
        if i > 0 and chunk.choices[0].finish_reason is None:
            token_count += 1
        request_latency = time.perf_counter() - timer_start

        model = chunk.model
        if model is not None:
            scope3_row = ImpactRow(
                model=Model(id=model),
                output_tokens=token_count,
                request_duration_ms=float(request_latency) * 1000,
                managed_service_id=PROVIDER,
            )
            scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
            if scope3ai_ctx is not None:
                yield ChatCompletionChunk(**chunk.model_dump(), scope3ai=scope3ai_ctx)
            else:
                yield chunk
        else:
            yield chunk


def litellm_chat_wrapper_non_stream(
    wrapped: Callable,
    instance: Completions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletion:
    timer_start = time.perf_counter()
    response = wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model = response.model
    if model is None:
        return response
    scope3_row = ImpactRow(
        model=Model(id=model),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.total_tokens,
        request_duration_ms=float(request_latency) * 1000,
        managed_service_id=PROVIDER,
    )
    scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    if scope3ai_ctx is not None:
        return ChatCompletion(**response.model_dump(), scope3ai=scope3ai_ctx)
    else:
        return response


async def litellm_async_chat_wrapper(
    wrapped: Callable, instance: AsyncCompletions, args: Any, kwargs: Any
) -> Union[ChatCompletion, CustomStreamWrapper]:
    if kwargs.get("stream", False):
        return litellm_async_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return await litellm_async_chat_wrapper_base(wrapped, instance, args, kwargs)


async def litellm_async_chat_wrapper_base(
    wrapped: Callable,
    instance: AsyncCompletions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> ChatCompletion:
    timer_start = time.perf_counter()
    response = await wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model = response.model
    if model is None:
        return response
    scope3_row = ImpactRow(
        model=Model(id=model),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.total_tokens,
        request_duration_ms=float(request_latency) * 1000,
        managed_service_id=PROVIDER,
    )
    scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    if scope3ai_ctx is not None:
        return ChatCompletion(**response.model_dump(), scope3ai=scope3ai_ctx)
    else:
        return response


async def litellm_async_chat_wrapper_stream(  # type: ignore[misc]
    wrapped: Callable,
    instance: AsyncCompletions,  # noqa: ARG001
    args: Any,
    kwargs: Any,
) -> CustomStreamWrapper:
    timer_start = time.perf_counter()
    stream = await wrapped(*args, **kwargs)
    i = 0
    token_count = 0
    async for chunk in stream:
        if i > 0 and chunk.choices[0].finish_reason is None:
            token_count += 1
        request_latency = time.perf_counter() - timer_start
        model = chunk.model
        if model is not None:
            scope3_row = ImpactRow(
                model=Model(id=model),
                output_tokens=token_count,
                request_duration_ms=float(request_latency) * 1000,
                managed_service_id=PROVIDER,
            )
            scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
            if scope3ai_ctx is not None:
                yield ChatCompletionChunk(**chunk.model_dump(), scope3ai=scope3ai_ctx)
            else:
                yield chunk
        else:
            yield chunk
        i += 1
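The wrappers above time each litellm call, build an ImpactRow from the model name and token counts, and return a ChatCompletion (or ChatCompletionChunk) carrying the submitted Scope3AIContext. As a rough illustration of the streaming path once litellm.completion has been patched with litellm_chat_wrapper, a caller might do something like the sketch below; it assumes the instrumentor from the next file has already been applied and that a Scope3AI instance is initialized, and the model name is only an example taken from the test cassette:

    import litellm

    # Assumes LiteLLMInstrumentor().instrument() has already patched
    # litellm.completion and that Scope3AI is initialized elsewhere.
    stream = litellm.completion(
        model="command-r",
        messages=[{"role": "user", "content": "Hello World!"}],
        stream=True,
    )

    for chunk in stream:
        # Chunks with a known model are re-emitted as ChatCompletionChunk with a
        # scope3ai context attached; otherwise the raw chunk is passed through.
        ctx = getattr(chunk, "scope3ai", None)
        if ctx is not None:
            print(ctx)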
@@ -0,0 +1,29 @@
import litellm
from wrapt import wrap_function_wrapper  # type: ignore[import-untyped]

from scope3ai.tracers.litellm.chat import (
    litellm_chat_wrapper,
    litellm_async_chat_wrapper,
)


class LiteLLMInstrumentor:
    def __init__(self) -> None:
        self.wrapped_methods = [
            {
                "module": litellm,
                "name": "completion",
                "wrapper": litellm_chat_wrapper,
            },
            {
                "module": litellm,
                "name": "acompletion",
                "wrapper": litellm_async_chat_wrapper,
            },
        ]

    def instrument(self) -> None:
        for wrapper in self.wrapped_methods:
            wrap_function_wrapper(
                wrapper["module"], wrapper["name"], wrapper["wrapper"]
            )
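The instrumentor uses wrapt to patch litellm.completion and litellm.acompletion in place with the chat wrappers. A minimal sketch of wiring it up by hand is shown below; the import path of the instrumentor module and the Scope3AI.init() entry point are assumptions for illustration, not part of this diff:

    import litellm

    from scope3ai import Scope3AI
    from scope3ai.tracers.litellm.instrument import LiteLLMInstrumentor  # path assumed

    Scope3AI.init()  # assumed SDK initialization entry point

    # Patch litellm.completion / litellm.acompletion with the Scope3 wrappers.
    LiteLLMInstrumentor().instrument()

    response = litellm.completion(
        model="command-r",
        messages=[{"role": "user", "content": "Hello World!"}],
    )
    # Scope3AIContext for this call, or None if no impact was submitted.
    print(getattr(response, "scope3ai", None))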
@@ -0,0 +1,71 @@
interactions:
- request:
    body: '{"model": "command-r", "chat_history": [], "message": "Hello World!"}'
    headers:
      Accept-Encoding:
      - gzip, deflate
      Connection:
      - keep-alive
      Content-Length:
      - '69'
      Request-Source:
      - unspecified:litellm
      User-Agent:
      - python-requests/2.32.3
      accept:
      - application/json
      authorization:
      - DUMMY
      content-type:
      - application/json
    method: POST
    uri: https://api.cohere.ai/v1/chat
  response:
    body:
      string: "{\"response_id\":\"e33dc171-063e-4d14-8fa9-1fc92d3f36bc\",\"text\":\"Hello!
        How's it going? I hope you're having a fantastic day today! \U0001F60A\",\"generation_id\":\"15afd773-5ac1-40dd-96df-56537933c248\",\"chat_history\":[{\"role\":\"USER\",\"message\":\"Hello
        World!\"},{\"role\":\"CHATBOT\",\"message\":\"Hello! How's it going? I hope
        you're having a fantastic day today! \U0001F60A\"}],\"finish_reason\":\"COMPLETE\",\"meta\":{\"api_version\":{\"version\":\"1\"},\"billed_units\":{\"input_tokens\":3,\"output_tokens\":19},\"tokens\":{\"input_tokens\":69,\"output_tokens\":19}}}"
    headers:
      Alt-Svc:
      - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000
      Content-Length:
      - '518'
      Via:
      - 1.1 google
      access-control-expose-headers:
      - X-Debug-Trace-ID
      cache-control:
      - no-cache, no-store, no-transform, must-revalidate, private, max-age=0
      content-type:
      - application/json
      date:
      - Fri, 27 Dec 2024 18:35:06 GMT
      expires:
      - Thu, 01 Jan 1970 00:00:00 UTC
      num_chars:
      - '430'
      num_tokens:
      - '22'
      pragma:
      - no-cache
      server:
      - envoy
      vary:
      - Origin
      x-accel-expires:
      - '0'
      x-debug-trace-id:
      - e2debaa1e7de7fe4e007e6e250e616b4
      x-endpoint-monthly-call-limit:
      - '1000'
      x-envoy-upstream-service-time:
      - '202'
      x-trial-endpoint-call-limit:
      - '10'
      x-trial-endpoint-call-remaining:
      - '9'
    status:
      code: 200
      message: OK
version: 1
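This cassette records a single Cohere /v1/chat exchange issued through litellm for the command-r model, so the litellm tests can replay the HTTP traffic offline. The commit does not show the test code itself; the sketch below is only one plausible way such a cassette might be replayed with vcrpy, and the cassette path, dummy API key, and assertion are assumptions:

    import os

    import litellm
    import vcr

    # Cassette directory and filename are assumed for illustration.
    my_vcr = vcr.VCR(cassette_library_dir="tests/cassettes")

    def test_litellm_chat():
        # A dummy key is assumed so litellm can construct the Cohere request;
        # the recorded response is served from the cassette, not the network.
        os.environ.setdefault("COHERE_API_KEY", "DUMMY")
        with my_vcr.use_cassette("test_litellm_chat.yaml"):
            response = litellm.completion(
                model="command-r",
                messages=[{"role": "user", "content": "Hello World!"}],
            )
            # With the tracer active, the wrapped response should carry impact data.
            assert getattr(response, "scope3ai", None) is not None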