Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

add simpler modality arg #401

Merged
merged 1 commit into from
Oct 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/assets/openapi.json

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions libs/infinity_emb/infinity_emb/fastapi_schemas/pymodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ class _OpenAIEmbeddingInput_Text(_OpenAIEmbeddingInput):
),
Annotated[str, INPUT_STRING],
]
infinity_extra_modality: Literal[Modality.text] = Modality.text # type: ignore
modality: Literal[Modality.text] = Modality.text # type: ignore


class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):
Expand All @@ -115,21 +115,21 @@ class _OpenAIEmbeddingInput_URI(_OpenAIEmbeddingInput):


class OpenAIEmbeddingInput_Audio(_OpenAIEmbeddingInput_URI):
infinity_extra_modality: Literal[Modality.audio] = Modality.audio # type: ignore
modality: Literal[Modality.audio] = Modality.audio # type: ignore


class OpenAIEmbeddingInput_Image(_OpenAIEmbeddingInput_URI):
infinity_extra_modality: Literal[Modality.image] = Modality.image # type: ignore
modality: Literal[Modality.image] = Modality.image # type: ignore


def get_infinity_extra_modality(obj: dict) -> str:
def get_modality(obj: dict) -> str:
"""resolve the modality of the extra_body.
If not present, default to text

Function name is used to return error message, keep it explicit
"""
michaelfeil marked this conversation as resolved.
Show resolved Hide resolved
try:
return obj.get("infinity_extra_modality", Modality.text.value)
return obj.get("modality", Modality.text.value)
except AttributeError:
# in case a very weird request is sent, validate it against the default
return Modality.text.value
Expand All @@ -142,7 +142,7 @@ class MultiModalOpenAIEmbedding(RootModel):
Annotated[OpenAIEmbeddingInput_Audio, Tag(Modality.audio.value)],
Annotated[OpenAIEmbeddingInput_Image, Tag(Modality.image.value)],
],
Discriminator(get_infinity_extra_modality),
Discriminator(get_modality),
]


Expand Down
14 changes: 7 additions & 7 deletions libs/infinity_emb/infinity_emb/infinity_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ async def _embeddings(data: MultiModalOpenAIEmbedding):
# can also be base64 encoded
],
# set extra modality to image to process as image
"infinity_extra_modality": "image"
"modality": "image"
)
```

Expand All @@ -271,7 +271,7 @@ def url_to_base64(url, modality = "image"):
url, url_to_base64(url, "audio")
],
# set extra modality to audio to process as audio
"infinity_extra_modality": "audio"
"modality": "audio"
}
)
```
Expand All @@ -285,7 +285,7 @@ def url_to_base64(url, modality = "image"):
input=[url_to_base64(url, "audio")],
encoding_format= "base64",
extra_body={
"infinity_extra_modality": "audio"
"modality": "audio"
}
)

Expand All @@ -294,7 +294,7 @@ def url_to_base64(url, modality = "image"):
input=["the sound of a beep", "the sound of a cat"],
encoding_format= "base64",
extra_body={
"infinity_extra_modality": "text"
"modality": "text"
}
)
```
Expand All @@ -305,7 +305,7 @@ def url_to_base64(url, modality = "image"):
```
"""

modality = data.root.infinity_extra_modality
modality = data.root.modality
data_root = data.root
engine = _resolve_engine(data_root.model)

Expand Down Expand Up @@ -471,7 +471,7 @@ async def _classify(data: ClassifyInput):
dependencies=route_dependencies,
operation_id="embeddings_image",
deprecated=True,
summary="Deprecated: Use `embeddings` with `infinity_extra_modality` set to `image`",
summary="Deprecated: Use `embeddings` with `modality` set to `image`",
)
async def _embeddings_image(data: ImageEmbeddingInput):
"""Encode Embeddings from Image files
Expand Down Expand Up @@ -530,7 +530,7 @@ async def _embeddings_image(data: ImageEmbeddingInput):
dependencies=route_dependencies,
operation_id="embeddings_audio",
deprecated=True,
summary="Deprecated: Use `embeddings` with `infinity_extra_modality` set to `audio`",
summary="Deprecated: Use `embeddings` with `modality` set to `audio`",
)
async def _embeddings_audio(data: AudioEmbeddingInput):
"""Encode Embeddings from Audio files
Expand Down
16 changes: 8 additions & 8 deletions libs/infinity_emb/tests/end_to_end/test_openapi_client_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,46 +79,46 @@ async def test_openai(client: AsyncClient):
"the sound of a bird",
],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)
emb1_audio = await client_oai.embeddings.create(
model=pytest.DEFAULT_AUDIO_MODEL,
input=[url_to_base64(pytest.AUDIO_SAMPLE_URL, "audio")],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
emb1_1_audio = await client_oai.embeddings.create(
model=pytest.DEFAULT_AUDIO_MODEL,
input=[pytest.AUDIO_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
# test: image
emb_1_image_from_text = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=["a cat", "a dog", "a bird"],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)
emb_1_image = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=[url_to_base64(pytest.IMAGE_SAMPLE_URL, "image")], # image is a cat
encoding_format="float",
extra_body={"infinity_extra_modality": "image"},
extra_body={"modality": "image"},
)
emb_1_1_image = await client_oai.embeddings.create(
model=pytest.DEFAULT_IMAGE_MODEL,
input=[pytest.IMAGE_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "image"},
extra_body={"modality": "image"},
)

# test: text
emb_1_text = await client_oai.embeddings.create(
model=pytest.DEFAULT_BERT_MODEL,
input=["a cat", "a cat", "a bird"],
encoding_format="float",
extra_body={"infinity_extra_modality": "text"},
extra_body={"modality": "text"},
)

# test AUDIO: cosine distance of beep to cat and dog
Expand Down Expand Up @@ -156,5 +156,5 @@ async def test_openai(client: AsyncClient):
model=pytest.DEFAULT_AUDIO_MODEL,
input=[pytest.AUDIO_SAMPLE_URL],
encoding_format="float",
extra_body={"infinity_extra_modality": "audio"},
extra_body={"modality": "audio"},
)
6 changes: 3 additions & 3 deletions libs/infinity_emb/tests/end_to_end/test_torch_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ async def test_audio_multiple(client):
json={
"model": MODEL,
"input": audio_urls,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response.status_code == 200
Expand All @@ -151,7 +151,7 @@ async def test_audio_fail(client):
json={
"model": MODEL,
"input": audio_url,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response.status_code == status.HTTP_400_BAD_REQUEST
Expand All @@ -166,7 +166,7 @@ async def test_audio_empty(client):
json={
"model": MODEL,
"input": audio_url_empty,
"infinity_extra_modality": "audio",
"modality": "audio",
},
)
assert response_empty.status_code == status.HTTP_422_UNPROCESSABLE_ENTITY
Expand Down
4 changes: 2 additions & 2 deletions libs/infinity_emb/tests/end_to_end/test_torch_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ async def test_meta(client, helpers):
json={
"model": MODEL,
"input": image_input,
"infinity_extra_modality": "image",
"modality": "image",
},
)

Expand Down Expand Up @@ -166,7 +166,7 @@ async def test_vision_multiple(client):
json={
"model": MODEL,
"input": image_urls,
"infinity_extra_modality": "image",
"modality": "image",
},
)
assert response.status_code == 200
Expand Down
Loading