From cd10b8fd304b131b768b452617507c537ce98737 Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:00:26 -0800 Subject: [PATCH 1/4] feat: support `image_input` for image embeddings --- src/embedding_service.py | 23 +++++++++++++++++++++++ src/handler.py | 6 ++++++ 2 files changed, 29 insertions(+) diff --git a/src/embedding_service.py b/src/embedding_service.py index a5afc44..1df260e 100644 --- a/src/embedding_service.py +++ b/src/embedding_service.py @@ -57,6 +57,7 @@ def list_models(self) -> list[str]: async def route_openai_get_embeddings( self, embedding_input: str | list[str], + image_input: str | list[str], model_name: str, return_as_list: bool = False, ): @@ -76,6 +77,28 @@ async def route_openai_get_embeddings( embeddings, model=model_name, usage=usage ) + async def route_get_image_embeddings( + self, + image_input: str | list[str], + model_name: str, + return_as_list: bool = False, + ): + """returns embeddings for the input image urls""" + if not self.is_running: + await self.start() + if not isinstance(image_input, list): + image_input = [image_input] + + embeddings, usage = await self.engine_array[model_name].embed_image(image_input) + if return_as_list: + return [ + list_embeddings_to_response(embeddings, model=model_name, usage=usage) + ] + else: + return list_embeddings_to_response( + embeddings, model=model_name, usage=usage + ) + async def infinity_rerank( self, query: str, docs: str, return_docs: str, model_name: str ): diff --git a/src/handler.py b/src/handler.py index a7bb04b..f67fab6 100644 --- a/src/handler.py +++ b/src/handler.py @@ -47,6 +47,12 @@ async def async_generator_handler(job: dict[str, Any]): "embedding_input": job_input.get("input"), "model_name": job_input.get("model"), } + # handle image urls (for image embeddings) + elif job_input.get("image_input"): + call_fn, kwargs = embedding_service.route_get_image_embeddings, { + "image_input": job_input.get("image_input"), + "model_name": 
job_input.get("model"), + } else: return create_error_response(f"Invalid input: {job}").model_dump() try: From 113c26f9df92da15953e2901a1cb4ede00e4b270 Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:14:27 -0800 Subject: [PATCH 2/4] call correct function --- src/embedding_service.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/embedding_service.py b/src/embedding_service.py index 1df260e..2f0ea94 100644 --- a/src/embedding_service.py +++ b/src/embedding_service.py @@ -89,7 +89,7 @@ async def route_get_image_embeddings( if not isinstance(image_input, list): image_input = [image_input] - embeddings, usage = await self.engine_array[model_name].embed_image(image_input) + embeddings, usage = await self.engine_array[model_name].image_embed(image_input) if return_as_list: return [ list_embeddings_to_response(embeddings, model=model_name, usage=usage) From ff362c138b41ecb414f1b45655a56660a2138daf Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:20:03 -0800 Subject: [PATCH 3/4] remove unused code --- src/embedding_service.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/embedding_service.py b/src/embedding_service.py index 2f0ea94..e64cd6a 100644 --- a/src/embedding_service.py +++ b/src/embedding_service.py @@ -57,7 +57,6 @@ def list_models(self) -> list[str]: async def route_openai_get_embeddings( self, embedding_input: str | list[str], - image_input: str | list[str], model_name: str, return_as_list: bool = False, ): From 806d56c0be6116950d93404b450c58d95abb67aa Mon Sep 17 00:00:00 2001 From: zeke <40004347+KAJdev@users.noreply.github.com> Date: Tue, 6 Aug 2024 20:22:30 -0800 Subject: [PATCH 4/4] Update README.md --- README.md | 1 + 1 file changed, 1 insertion(+) diff --git a/README.md b/README.md index c5c04b1..3e18d66 100644 --- a/README.md +++ b/README.md @@ -123,6 +123,7 @@ You may use `/run` (asynchronous, start job and return 
job ID) or `/runsync` (sy Inputs: * `model`: name of one of the deployed models. * `input`: single text string or list of texts to embed +* `image_input`: single image URL or list of image URLs to embed (for use with CLIP models) ### Reranking Inputs: