Skip to content

Commit

Permalink
fix(huggingface): fixing aiohttp not working with VCR if passing file…
Browse files Browse the repository at this point in the history
…name (#68)
  • Loading branch information
tito authored Jan 20, 2025
1 parent 48404b6 commit 0060c03
Show file tree
Hide file tree
Showing 7 changed files with 66,443 additions and 761 deletions.
4 changes: 2 additions & 2 deletions scope3ai/response_interceptor/aiohttp_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@

AIOHTTP_RESPONSE_BASEKEY = "scope3ai__aiohttp_interceptor"
AIOHTTP_RESPONSE_ENABLED = contextvars.ContextVar(
f"{AIOHTTP_RESPONSE_BASEKEY}__enabled", default=None
f"{AIOHTTP_RESPONSE_BASEKEY}__enabled", default=False
)
AIOHTTP_RESPONSE_VALUE = contextvars.ContextVar(
AIOHTTP_RESPONSE_VALUE = contextvars.ContextVar[Any](
f"{AIOHTTP_RESPONSE_BASEKEY}__value", default=None
)

Expand Down
3 changes: 2 additions & 1 deletion scope3ai/tracers/huggingface/vision/object_detection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import io
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Callable, List, Optional, Union

from aiohttp import ClientResponse
Expand Down Expand Up @@ -44,7 +45,7 @@ def _hugging_face_object_detection_get_impact_row(
compute_time = http_response.headers.get("x-compute-time") or compute_time
try:
image_param = args[0] if len(args) > 0 else kwargs["image"]
if isinstance(image_param, str):
if isinstance(image_param, (str, Path)):
input_image = Image.open(image_param)
else:
input_image = Image.open(io.BytesIO(image_param))
Expand Down
33,196 changes: 33,196 additions & 0 deletions tests/cassettes/test_huggingface_hub_object_detection_async[Path].yaml

Large diffs are not rendered by default.

Large diffs are not rendered by default.

33,196 changes: 33,196 additions & 0 deletions tests/cassettes/test_huggingface_hub_object_detection_async[str].yaml

Large diffs are not rendered by default.

23 changes: 22 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,13 +102,14 @@ def async_api_client(docker_api_info):


@pytest.fixture(autouse=True)
def fix_vcr_binary():
def fix_vcr_binary_utf8_decoding():
# this handles the httpx UTF-8 decoding issue
# https://github.com/kevin1024/vcrpy/pull/882

import warnings

import vcr # type: ignore[import-untyped]
import vcr.stubs.httpx_stubs
from vcr.request import Request as VcrRequest # type: ignore[import-untyped]
from vcr.stubs.httpx_stubs import ( # type: ignore
_make_vcr_request, # noqa: F401 this is needed for some reason so python knows this method exists
Expand All @@ -133,3 +134,23 @@ def _fixed__make_vcr_request( # type: ignore
vcr.stubs.httpx_stubs._make_vcr_request = _fixed__make_vcr_request
yield
vcr.stubs.httpx_stubs._make_vcr_request = _make_vcr_request


@pytest.fixture(autouse=True)
def fix_vcr_body_read_missing_seek():
    """Rewind file-like request bodies after vcr has consumed them.

    When aiohttp is given a file or buffered IO object as the request body,
    the ``vcr.request.Request`` constructor reads it, so the next
    ``body.read()`` performed by aiohttp returns nothing. This fixture
    patches ``Request.__init__`` for the duration of each test so the stream
    is put back where it was before vcr read it, and restores the original
    constructor afterwards.

    Bodies whose position cannot be determined (no ``tell()``, or a
    non-seekable / closed stream) are left untouched instead of raising
    from inside the patched constructor.
    """
    import vcr.request

    request_init_orig = vcr.request.Request.__init__

    def _fixed__request_init(self, method, uri, body, headers):
        # Record the stream offset *before* vcr drains it, so the rewind
        # lands on the caller's position rather than assuming offset 0.
        start = None
        tell = getattr(body, "tell", None)
        if callable(tell):
            try:
                start = tell()
            except (OSError, ValueError):
                # io.UnsupportedOperation subclasses both; ValueError also
                # covers a closed file. Nothing to rewind in either case.
                start = None

        request_init_orig(self, method, uri, body, headers)

        # `_was_file` is set by the original constructor; only rewind when
        # vcr actually treated the body as a file and we know where to go.
        if self._was_file and start is not None:
            body.seek(start)

    vcr.request.Request.__init__ = _fixed__request_init
    try:
        yield
    finally:
        # Always undo the monkeypatch so it cannot leak past this fixture.
        vcr.request.Request.__init__ = request_init_orig
27 changes: 21 additions & 6 deletions tests/test_huggingface_tracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -247,17 +247,32 @@ async def test_huggingface_hub_image_to_image_async(tracer_with_sync_init):


@pytest.mark.vcr
@pytest.mark.parametrize(
"image_type",
["str", "Path", "bytes"],
)
@pytest.mark.asyncio
async def test_huggingface_hub_object_detection_async(tracer_with_sync_init):
async def test_huggingface_hub_object_detection_async(
tracer_with_sync_init, image_type
):
client = AsyncInferenceClient()
datadir = Path(__file__).parent / "data"
street_scene_image = open((datadir / "street_scene.png").as_posix(), "rb")
street_scene_image_bytes = street_scene_image.read()
response = await client.object_detection(street_scene_image_bytes)
assert getattr(response, "scope3ai") is not None

path = datadir / "street_scene.png"
if image_type == "str":
street_scene_image = path.as_posix()
elif image_type == "Path":
street_scene_image = path
elif image_type == "bytes":
street_scene_image = path.read_bytes()
else:
assert 0
return

response = await client.object_detection(street_scene_image)
assert getattr(response, "scope3ai") is not None
assert response.scope3ai.request.input_images == [Image(root="1024x1024")]
assert response.scope3ai.request.request_duration_ms == 657
assert response.scope3ai.request.request_duration_ms > 0
assert response.scope3ai.impact is not None
assert response.scope3ai.impact.total_impact is not None
assert response.scope3ai.impact.total_impact.usage_energy_wh > 0
Expand Down

0 comments on commit 0060c03

Please sign in to comment.