feat: Hugging Face chat and image generation first iteration (#7)
* feat: initial pytest implementation for openai (still use scope3 api)
* ci: include pytest with python matrix
* feat: autogen types from openapi scope3ai yaml
* fix: adding more libraries
* feat: huggingface_hub tracer
* feat: hugging face image generation
* fix: update directory
* fix: update directory
* feat: huggingface text-to-image tracer
* fix: .gitignore updated
* fix: Files removed from .gitignore
* fix: add .idea to global gitignore
* fix: fix output images for text_to_image task

Co-authored-by: Mathieu Virbel <mat@meltingrocks.com>
Showing 18 changed files with 204,524 additions and 71 deletions.
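Before the per-file diffs, a minimal usage sketch of what this commit wires up. The instrumentor class and the `scope3ai` attribute on responses appear in the diff below; the instrumentor's import path and the assumption that the Scope3AI library is already configured are illustrative, not shown in this commit.

```python
# Hypothetical usage sketch; assumes the Scope3AI library is already
# configured (its setup API is not part of this diff).
from huggingface_hub import InferenceClient

# Assumed import path for the instrumentor defined in this commit.
from scope3ai.tracers.huggingface.instrument import HuggingfaceInstrumentor

# Patch InferenceClient.chat_completion / text_to_image / translation
# (and AsyncInferenceClient.chat_completion) in place.
HuggingfaceInstrumentor().instrument()

client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")
response = client.chat_completion(
    messages=[{"role": "user", "content": "Hello!"}], max_tokens=32
)

# The tracer attaches the submitted impact context to the response.
print(response.scope3ai)
```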
scope3ai/tracers/huggingface/chat.py (new file, +140 lines):

```python
import time
from collections.abc import AsyncIterable, Iterable
from dataclasses import asdict, dataclass
from typing import Any, Callable, Optional, Union

from huggingface_hub import AsyncInferenceClient, InferenceClient  # type: ignore[import-untyped]
from huggingface_hub import ChatCompletionOutput as _ChatCompletionOutput
from huggingface_hub import ChatCompletionStreamOutput as _ChatCompletionStreamOutput

from scope3ai.api.types import Scope3AIContext, Model, ImpactRow
from scope3ai.lib import Scope3AI

PROVIDER = "huggingface_hub"


@dataclass
class ChatCompletionOutput(_ChatCompletionOutput):
    scope3ai: Optional[Scope3AIContext] = None


@dataclass
class ChatCompletionStreamOutput(_ChatCompletionStreamOutput):
    scope3ai: Optional[Scope3AIContext] = None


def huggingface_chat_wrapper(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> Union[ChatCompletionOutput, Iterable[ChatCompletionStreamOutput]]:
    if kwargs.get("stream", False):
        return huggingface_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return huggingface_chat_wrapper_non_stream(wrapped, instance, args, kwargs)


def huggingface_chat_wrapper_non_stream(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> ChatCompletionOutput:
    timer_start = time.perf_counter()
    response = wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model_requested = instance.model
    model_used = response.model
    scope3_row = ImpactRow(
        model=Model(id=model_requested),
        model_used=Model(id=model_used),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,
        request_duration_ms=request_latency * 1000,
        managed_service_id=PROVIDER,
    )
    scope3ai_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    return ChatCompletionOutput(**asdict(response), scope3ai=scope3ai_ctx)


def huggingface_chat_wrapper_stream(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> Iterable[ChatCompletionStreamOutput]:
    timer_start = time.perf_counter()
    if "stream_options" not in kwargs:
        kwargs["stream_options"] = {}
    if "include_usage" not in kwargs["stream_options"]:
        kwargs["stream_options"]["include_usage"] = True
    elif not kwargs["stream_options"]["include_usage"]:
        raise ValueError("stream_options include_usage must be True")
    stream = wrapped(*args, **kwargs)
    token_count = 0
    model_request = instance.model
    model_used = instance.model
    for chunk in stream:
        token_count += 1
        request_latency = time.perf_counter() - timer_start
        scope3_row = ImpactRow(
            model=Model(id=model_request),
            model_used=Model(id=model_used),
            input_tokens=chunk.usage.prompt_tokens,
            output_tokens=chunk.usage.completion_tokens,
            request_duration_ms=request_latency
            * 1000,  # TODO: can we get the header that has the processing time
            managed_service_id=PROVIDER,
        )
        scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
        yield ChatCompletionStreamOutput(**asdict(chunk), scope3ai=scope3_ctx)


async def huggingface_async_chat_wrapper(
    wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any
) -> Union[ChatCompletionOutput, AsyncIterable[ChatCompletionStreamOutput]]:
    if kwargs.get("stream", False):
        return huggingface_async_chat_wrapper_stream(wrapped, instance, args, kwargs)
    else:
        return await huggingface_async_chat_wrapper_non_stream(
            wrapped, instance, args, kwargs
        )


async def huggingface_async_chat_wrapper_non_stream(
    wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any
) -> ChatCompletionOutput:
    timer_start = time.perf_counter()
    response = await wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    model_requested = kwargs["model"]
    model_used = response.model

    scope3_row = ImpactRow(
        model=Model(id=model_requested),
        model_used=Model(id=model_used),
        input_tokens=response.usage.prompt_tokens,
        output_tokens=response.usage.completion_tokens,
        request_duration_ms=request_latency
        * 1000,  # TODO: can we get the header that has the processing time
        managed_service_id=PROVIDER,
    )

    scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    return ChatCompletionOutput(**asdict(response), scope3ai=scope3_ctx)


async def huggingface_async_chat_wrapper_stream(
    wrapped: Callable, instance: AsyncInferenceClient, args: Any, kwargs: Any
) -> AsyncIterable[ChatCompletionStreamOutput]:
    timer_start = time.perf_counter()
    stream = await wrapped(*args, **kwargs)
    token_count = 0
    model_request = kwargs["model"]
    model_used = instance.model
    async for chunk in stream:
        token_count += 1
        request_latency = time.perf_counter() - timer_start
        scope3_row = ImpactRow(
            model=Model(id=model_request),
            model_used=Model(id=model_used),
            input_tokens=chunk.usage.prompt_tokens,
            output_tokens=chunk.usage.completion_tokens,
            request_duration_ms=request_latency
            * 1000,  # TODO: can we get the header that has the processing time
            managed_service_id=PROVIDER,
        )
        scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
        yield ChatCompletionStreamOutput(**asdict(chunk), scope3ai=scope3_ctx)
```
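For the streaming path, the wrapper forces `stream_options.include_usage` and yields the augmented dataclass per chunk, so consumption looks the same as with the unwrapped client. A sketch, reusing the patched `client` from the earlier example:

```python
# Sketch: iterating the patched streaming chat API. Each chunk is the
# ChatCompletionStreamOutput defined above, carrying a per-chunk scope3ai
# context computed from that chunk's usage figures.
for chunk in client.chat_completion(
    messages=[{"role": "user", "content": "Tell me a joke"}],
    stream=True,
    max_tokens=64,
):
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)
    latest_ctx = chunk.scope3ai  # impact context as of this chunk
```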
New file (+42 lines), the Hugging Face instrumentor:

```python
from wrapt import wrap_function_wrapper  # type: ignore[import-untyped]

from scope3ai.tracers.huggingface.chat import (
    huggingface_chat_wrapper,
    huggingface_async_chat_wrapper,
)
from scope3ai.tracers.huggingface.text_to_image import huggingface_text_to_image_wrapper
from scope3ai.tracers.huggingface.translation import (
    huggingface_translation_wrapper_non_stream,
)


class HuggingfaceInstrumentor:
    def __init__(self) -> None:
        self.wrapped_methods = [
            {
                "module": "huggingface_hub.inference._client",
                "name": "InferenceClient.chat_completion",
                "wrapper": huggingface_chat_wrapper,
            },
            {
                "module": "huggingface_hub.inference._client",
                "name": "InferenceClient.text_to_image",
                "wrapper": huggingface_text_to_image_wrapper,
            },
            {
                "module": "huggingface_hub.inference._client",
                "name": "InferenceClient.translation",
                "wrapper": huggingface_translation_wrapper_non_stream,
            },
            {
                "module": "huggingface_hub.inference._generated._async_client",
                "name": "AsyncInferenceClient.chat_completion",
                "wrapper": huggingface_async_chat_wrapper,
            },
        ]

    def instrument(self) -> None:
        for wrapper in self.wrapped_methods:
            wrap_function_wrapper(
                wrapper["module"], wrapper["name"], wrapper["wrapper"]
            )
```
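The instrumentor relies on wrapt's calling convention: `wrap_function_wrapper` patches a method in place, and the installed wrapper receives `(wrapped, instance, args, kwargs)`, which is exactly the signature the tracers above implement. A standalone illustration (the logging wrapper itself is hypothetical):

```python
from wrapt import wrap_function_wrapper


def logging_wrapper(wrapped, instance, args, kwargs):
    # wrapt hands us the original bound method, the instance it was called
    # on, and the call's args/kwargs: the same shape as the tracers above.
    print(f"calling {wrapped.__name__} on {type(instance).__name__}")
    return wrapped(*args, **kwargs)


# The same mechanism HuggingfaceInstrumentor.instrument() uses, for one method:
wrap_function_wrapper(
    "huggingface_hub.inference._client",
    "InferenceClient.chat_completion",
    logging_wrapper,
)
```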
scope3ai/tracers/huggingface/text_to_image.py (new file, +59 lines):

```python
import time
import tiktoken
from dataclasses import dataclass
from typing import Any, Callable, Optional

from huggingface_hub import InferenceClient  # type: ignore[import-untyped]
from huggingface_hub import TextToImageOutput as _TextToImageOutput

from scope3ai.api.types import Scope3AIContext, Model, ImpactRow
from scope3ai.api.typesgen import Task
from scope3ai.lib import Scope3AI

PROVIDER = "huggingface_hub"


@dataclass
class TextToImageOutput(_TextToImageOutput):
    scope3ai: Optional[Scope3AIContext] = None


def huggingface_text_to_image_wrapper_non_stream(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> TextToImageOutput:
    timer_start = time.perf_counter()
    response = wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    if kwargs.get("model"):
        model_requested = kwargs.get("model")
        model_used = kwargs.get("model")
    else:
        recommended_model = instance.get_recommended_model("text-to-image")
        model_requested = recommended_model
        model_used = recommended_model
    encoder = tiktoken.get_encoding("cl100k_base")
    if len(args) > 0:
        prompt = args[0]
    else:
        prompt = kwargs["prompt"]
    input_tokens = len(encoder.encode(prompt))
    width, height = response.size
    scope3_row = ImpactRow(
        model=Model(id=model_requested),
        model_used=Model(id=model_used),
        input_tokens=input_tokens,
        task=Task.text_to_image,
        output_images=["{width}x{height}".format(width=width, height=height)],
        request_duration_ms=request_latency
        * 1000,  # TODO: can we get the header that has the processing time
        managed_service_id=PROVIDER,
    )

    scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    return TextToImageOutput(response, scope3ai=scope3_ctx)


def huggingface_text_to_image_wrapper(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> TextToImageOutput:
    return huggingface_text_to_image_wrapper_non_stream(wrapped, instance, args, kwargs)
```
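Since text-to-image responses carry no usage data, the tracer approximates input tokens by encoding the prompt with tiktoken's `cl100k_base` (an OpenAI tokenizer, so only a rough proxy for arbitrary Hugging Face models). The estimate in isolation:

```python
import tiktoken

# Rough prompt-token estimate as used by the tracer above; cl100k_base is
# an OpenAI encoding, so counts are approximate for non-OpenAI models.
encoder = tiktoken.get_encoding("cl100k_base")
prompt = "a watercolor fox in a snowy forest"
print(len(encoder.encode(prompt)))  # estimated input token count
```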
scope3ai/tracers/huggingface/translation.py (new file, +59 lines; the trailing wrapper, originally copy-pasted as `huggingface_text_to_image_wrapper`, is renamed to match what it wraps):

```python
import time
import tiktoken
from dataclasses import dataclass, asdict
from typing import Any, Callable, Optional

from huggingface_hub import InferenceClient  # type: ignore[import-untyped]
from huggingface_hub import TranslationOutput as _TranslationOutput

from scope3ai.api.types import Scope3AIContext, Model, ImpactRow
from scope3ai.api.typesgen import Task
from scope3ai.lib import Scope3AI

PROVIDER = "huggingface_hub"


@dataclass
class TranslationOutput(_TranslationOutput):
    scope3ai: Optional[Scope3AIContext] = None


def huggingface_translation_wrapper_non_stream(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> TranslationOutput:
    timer_start = time.perf_counter()
    response = wrapped(*args, **kwargs)
    request_latency = time.perf_counter() - timer_start
    if kwargs.get("model"):
        model_requested = kwargs.get("model")
        model_used = kwargs.get("model")
    else:
        recommended_model = instance.get_recommended_model("translation")
        model_requested = recommended_model
        model_used = recommended_model
    encoder = tiktoken.get_encoding("cl100k_base")
    if len(args) > 0:
        prompt = args[0]
    else:
        prompt = kwargs["text"]
    input_tokens = len(encoder.encode(prompt))
    output_tokens = len(encoder.encode(response.translation_text))
    scope3_row = ImpactRow(
        model=Model(id=model_requested),
        model_used=Model(id=model_used),
        task=Task.translation,
        input_tokens=input_tokens,
        output_tokens=output_tokens,  # TODO: how can we calculate the output tokens of a translation?
        request_duration_ms=request_latency
        * 1000,  # TODO: can we get the header that has the processing time
        managed_service_id=PROVIDER,
    )

    scope3_ctx = Scope3AI.get_instance().submit_impact(scope3_row)
    return TranslationOutput(**asdict(response), scope3ai=scope3_ctx)


def huggingface_translation_wrapper(
    wrapped: Callable, instance: InferenceClient, args: Any, kwargs: Any
) -> TranslationOutput:
    return huggingface_translation_wrapper_non_stream(wrapped, instance, args, kwargs)
```
One added file is empty.
New file (+25 lines), the OpenAI instrumentor:

```python
from wrapt import wrap_function_wrapper

from scope3ai.tracers.openai.chat import openai_chat_wrapper, openai_async_chat_wrapper


class OpenAIInstrumentor:
    def __init__(self) -> None:
        self.wrapped_methods = [
            {
                "module": "openai.resources.chat.completions",
                "name": "Completions.create",
                "wrapper": openai_chat_wrapper,
            },
            {
                "module": "openai.resources.chat.completions",
                "name": "AsyncCompletions.create",
                "wrapper": openai_async_chat_wrapper,
            },
        ]

    def instrument(self) -> None:
        for wrapper in self.wrapped_methods:
            wrap_function_wrapper(
                wrapper["module"], wrapper["name"], wrapper["wrapper"]
            )
```