Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: End to End test for vision and audio #386

Merged
merged 12 commits into from
Sep 29, 2024
4 changes: 3 additions & 1 deletion libs/infinity_emb/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@
pytest.DEFAULT_BERT_MODEL = "michaelfeil/bge-small-en-v1.5"
pytest.DEFAULT_RERANKER_MODEL = "mixedbread-ai/mxbai-rerank-xsmall-v1"
pytest.DEFAULT_CLASSIFIER_MODEL = "SamLowe/roberta-base-go_emotions"
pytest.DEFAULT_AUDIO_MODEL = "laion/clap-htsat-unfused"
pytest.DEFAULT_VISION_MODEL = "wkcn/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M"

pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank"]
pytest.ENGINE_METHODS = ["embed", "image_embed", "classify", "rerank", "audio_embed"]
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nice spot



@pytest.fixture
Expand Down
122 changes: 122 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import pytest
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
from httpx import AsyncClient

from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

# Route prefix for this test server.
# NOTE(review): "ct2" suggests CTranslate2, but the engine below is torch — confirm intended.
PREFIX = "/v1_ct2"
# Audio model under test, injected via conftest.py (pytest.DEFAULT_AUDIO_MODEL).
MODEL: str = pytest.DEFAULT_AUDIO_MODEL  # type: ignore[assignment]
# Larger batches when a CUDA GPU is available; small batches for CPU-only CI.
batch_size = 32 if torch.cuda.is_available() else 8

# Module-level FastAPI app serving the audio embedding model under PREFIX.
app = create_server(
    url_prefix=PREFIX,
    engine_args_list=[
        EngineArgs(
            model_name_or_path=MODEL,
            batch_size=batch_size,
            engine=InferenceEngine.torch,
            # Avoid Apple MPS (presumably unsupported/flaky for this model); fall back to CPU there.
            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
        )
    ],
)


@pytest.fixture()
async def client():
    """Yield an async HTTP client bound to the module-level app, with lifespan events running."""
    async with AsyncClient(
        app=app, base_url="http://test", timeout=20
    ) as http_client:
        async with LifespanManager(app):
            yield http_client


@pytest.mark.anyio
async def test_model_route(client):
    """The /models route lists the audio model with stats and the audio_embed capability."""
    resp = await client.get(f"{PREFIX}/models")
    assert resp.status_code == 200
    payload = resp.json()
    assert "data" in payload
    entry = payload["data"][0]
    assert entry.get("id", "") == MODEL
    assert isinstance(entry.get("stats"), dict)
    assert "audio_embed" in entry["capabilities"]


@pytest.mark.anyio
async def test_audio_single(client):
    """A single audio URL embeds to one non-empty embedding vector."""
    url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": url},
    )
    assert resp.status_code == 200
    body = resp.json()
    for key in ("model", "usage"):
        assert key in body
    first = body["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_audio_single_text_only(client):
    """A plain-text input (no audio URL) embeds via the audio endpoint (CLAP-style text tower).

    Currently skipped; kept for when text-only input is supported on this route.
    """
    # Fix: original fixture text was the nonsensical typo "a sound of a at".
    text = "a sound of a cat"

    response = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": text},
    )
    assert response.status_code == 200
    rdata = response.json()
    assert "model" in rdata
    assert "usage" in rdata
    rdata_results = rdata["data"]
    assert rdata_results[0]["object"] == "embedding"
    assert len(rdata_results[0]["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_audios", [1, 5, 10])
async def test_audio_multiple(client, no_of_audios):
    """A batch of N identical audio URLs returns exactly N embeddings."""
    url = "https://github.com/michaelfeil/infinity/raw/3b72eb7c14bae06e68ddd07c1f23fe0bf403f220/libs/infinity_emb/tests/data/audio/beep.wav"

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": [url] * no_of_audios},
    )
    assert resp.status_code == 200
    body = resp.json()
    results = body["data"]
    assert len(results) == no_of_audios
    assert "model" in body
    assert "usage" in body
    assert results[0]["object"] == "embedding"
    assert len(results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_audio_fail(client):
    """A URL that does not resolve to downloadable audio is rejected with 400."""
    bad_url = "https://www.google.com/404"

    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": bad_url},
    )
    assert resp.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.anyio
async def test_audio_empty(client):
    """An empty input list fails request validation with 422."""
    resp = await client.post(
        f"{PREFIX}/embeddings_audio",
        json={"model": MODEL, "input": []},
    )
    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
121 changes: 121 additions & 0 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
import pytest
import torch
from asgi_lifespan import LifespanManager
from fastapi import status
from httpx import AsyncClient

from infinity_emb import create_server
from infinity_emb.args import EngineArgs
from infinity_emb.primitives import Device, InferenceEngine

# Route prefix for this test server.
# NOTE(review): "ct2" suggests CTranslate2, but the engine below is torch — confirm intended.
PREFIX = "/v1_ct2"
# Vision model under test, injected via conftest.py (pytest.DEFAULT_VISION_MODEL).
MODEL: str = pytest.DEFAULT_VISION_MODEL  # type: ignore[assignment]
# Larger batches when a CUDA GPU is available; small batches for CPU-only CI.
batch_size = 32 if torch.cuda.is_available() else 8

# Module-level FastAPI app serving the vision embedding model under PREFIX.
app = create_server(
    url_prefix=PREFIX,
    engine_args_list=[
        EngineArgs(
            model_name_or_path=MODEL,
            batch_size=batch_size,
            engine=InferenceEngine.torch,
            # Avoid Apple MPS (presumably unsupported/flaky for this model); fall back to CPU there.
            device=Device.auto if not torch.backends.mps.is_available() else Device.cpu,
        )
    ],
)


@pytest.fixture()
async def client():
    """Yield an async HTTP client bound to the module-level app, with lifespan events running."""
    async with AsyncClient(
        app=app, base_url="http://test", timeout=20
    ) as http_client:
        async with LifespanManager(app):
            yield http_client


@pytest.mark.anyio
async def test_model_route(client):
    """The /models route lists the vision model with stats and the image_embed capability."""
    resp = await client.get(f"{PREFIX}/models")
    assert resp.status_code == 200
    payload = resp.json()
    assert "data" in payload
    entry = payload["data"][0]
    assert entry.get("id", "") == MODEL
    assert isinstance(entry.get("stats"), dict)
    assert "image_embed" in entry["capabilities"]


@pytest.mark.anyio
async def test_vision_single(client):
    """A single image URL embeds to one non-empty embedding vector."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": url},
    )
    assert resp.status_code == 200
    body = resp.json()
    for key in ("model", "usage"):
        assert key in body
    first = body["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.skip("text only")
async def test_vision_single_text_only(client):
    """A plain-text input (no image URL) embeds via the image endpoint (CLIP-style text tower).

    Currently skipped; kept for when text-only input is supported on this route.
    """
    caption = "a image of a cat"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": caption},
    )
    assert resp.status_code == 200
    body = resp.json()
    for key in ("model", "usage"):
        assert key in body
    first = body["data"][0]
    assert first["object"] == "embedding"
    assert len(first["embedding"]) > 0


@pytest.mark.anyio
@pytest.mark.parametrize("no_of_images", [1, 5, 10])
async def test_vision_multiple(client, no_of_images):
    """A batch of N identical image URLs returns exactly N embeddings."""
    url = "http://images.cocodataset.org/val2017/000000039769.jpg"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": [url] * no_of_images},
    )
    assert resp.status_code == 200
    body = resp.json()
    results = body["data"]
    assert len(results) == no_of_images
    assert "model" in body
    assert "usage" in body
    assert results[0]["object"] == "embedding"
    assert len(results[0]["embedding"]) > 0


@pytest.mark.anyio
async def test_vision_fail(client):
    """A URL that does not resolve to a downloadable image is rejected with 400."""
    bad_url = "https://www.google.com/404"

    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": bad_url},
    )
    assert resp.status_code == status.HTTP_400_BAD_REQUEST


@pytest.mark.anyio
async def test_vision_empty(client):
    """An empty input list fails request validation with 422."""
    resp = await client.post(
        f"{PREFIX}/embeddings_image",
        json={"model": MODEL, "input": []},
    )
    assert resp.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
Loading