From b03bb15f8df1d40f270f8cc871c9dc8581a0e18e Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Mon, 29 Apr 2024 14:52:25 +0300 Subject: [PATCH] feat(base models): now all models encapsulates response object. --- .github/workflows/ci.yml | 8 ++++---- deepinfra/models/base/automatic_speech_recognition.py | 2 +- deepinfra/models/base/embeddings.py | 4 ++-- deepinfra/models/base/text_generation.py | 4 ++-- deepinfra/models/base/text_to_image.py | 3 ++- deepinfra/types/image_generation/__init__.py | 2 -- deepinfra/types/text_to_image/__init__.py | 2 ++ .../types/{image_generation => text_to_image}/request.py | 2 +- .../types/{image_generation => text_to_image}/response.py | 2 +- deepinfra/utils/form_data.py | 2 +- 10 files changed, 16 insertions(+), 15 deletions(-) delete mode 100644 deepinfra/types/image_generation/__init__.py create mode 100644 deepinfra/types/text_to_image/__init__.py rename deepinfra/types/{image_generation => text_to_image}/request.py (95%) rename deepinfra/types/{image_generation => text_to_image}/response.py (95%) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index b5a695c..764fe3b 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -32,10 +32,10 @@ jobs: pip install -r requirements.txt pip install -r requirements-dev.txt - - name: Run lint check + - name: Run type check run: | - black --check --verbose deepinfra + mypy --install-types deepinfra --non-interactive --verbose - - name: Run type check + - name: Run lint check run: | - mypy --install-types deepinfra --non-interactive --verbose \ No newline at end of file + black --check --verbose deepinfra \ No newline at end of file diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 95d596b..98803b7 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -26,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse: response = self.client.post( form_data, {"headers": {"content-type": form_data.content_type}} ) - return response.json() + return AutomaticSpeechRecognitionResponse(**response.json()) diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index 8c270c3..bc2d771 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -8,11 +8,11 @@ class Embeddings(BaseModel): @docs Check the available models at https://deepinfra.com/models/embeddings """ - def generate(self, body: EmbeddingsRequest) -> EmbeddingsResponse: + def generate(self, body) -> EmbeddingsResponse: """ Generates embeddings. :param body: :return: """ response = self.client.post(body) - return response + return EmbeddingsResponse(**response.json()) diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index 16319a2..b105a79 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -16,11 +16,11 @@ class TextGeneration(BaseModel): @docs Check the available models at https://deepinfra.com/models/text-generation """ - def generate(self, body: TextGenerationRequest) -> TextGenerationResponse: + def generate(self, body: dict) -> TextGenerationResponse: """ Generates text. :param body: :return: """ response = self.client.post(body) - return response + return TextGenerationResponse(**response.json()) diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 6247f78..432462e 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -1,4 +1,5 @@ from deepinfra.models.base import BaseModel +from deepinfra.types.text_to_image import TextToImageResponse class TextToImage(BaseModel): @@ -15,4 +16,4 @@ def generate(self, input): """ body = {"input": input} response = self.client.post(body) - return response + return TextToImageResponse(**response.json()) diff --git a/deepinfra/types/image_generation/__init__.py b/deepinfra/types/image_generation/__init__.py deleted file mode 100644 index d3decc1..0000000 --- a/deepinfra/types/image_generation/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from .request import ImageGenerationRequest -from .response import ImageGenerationResponse diff --git a/deepinfra/types/text_to_image/__init__.py b/deepinfra/types/text_to_image/__init__.py new file mode 100644 index 0000000..1fc81a2 --- /dev/null +++ b/deepinfra/types/text_to_image/__init__.py @@ -0,0 +1,2 @@ +from .request import TextToImageRequest +from .response import TextToImageResponse diff --git a/deepinfra/types/image_generation/request.py b/deepinfra/types/text_to_image/request.py similarity index 95% rename from deepinfra/types/image_generation/request.py rename to deepinfra/types/text_to_image/request.py index 7bb423a..9eadad7 100644 --- a/deepinfra/types/image_generation/request.py +++ b/deepinfra/types/text_to_image/request.py @@ -3,7 +3,7 @@ @dataclass -class ImageGenerationRequest: +class TextToImageRequest: prompt: str negative_prompt: Optional[str] = None image: Optional[str] = None diff --git a/deepinfra/types/image_generation/response.py b/deepinfra/types/text_to_image/response.py similarity index 95% rename from deepinfra/types/image_generation/response.py rename to deepinfra/types/text_to_image/response.py index 0702b65..f1bf320 100644 --- a/deepinfra/types/image_generation/response.py +++ b/deepinfra/types/text_to_image/response.py @@ -10,7 +10,7 @@ class Metrics: @dataclass -class ImageGenerationResponse: +class TextToImageResponse: request_id: str inference_status: InferenceStatus input: Dict diff --git a/deepinfra/utils/form_data.py b/deepinfra/utils/form_data.py index 2021b95..abb4952 100644 --- a/deepinfra/utils/form_data.py +++ b/deepinfra/utils/form_data.py @@ -12,7 +12,7 @@ class FormDataUtils: """ @staticmethod - def get_form_data(data, blob_keys: Optional[List[str]] = None): + def get_form_data(data: dict, blob_keys: Optional[List[str]] = None): """ Creates a MultipartEncoder object from the data. :param data: