From c1975e32a25c4aa16c3bd3baecfc9651088a77f2 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 18:03:29 +0300 Subject: [PATCH 1/2] ref(lint) --- deepinfra/constants/client.py | 4 ++-- .../models/base/automatic_speech_recognition.py | 2 +- deepinfra/models/base/embeddings.py | 7 ++++--- deepinfra/models/base/text_generation.py | 5 +++-- deepinfra/models/base/text_to_image.py | 8 +++++--- deepinfra/types/text_generation/response.py | 3 +-- deepinfra/types/text_to_image/response.py | 13 +++---------- tests/test_automatic_speech_recognition.py | 12 ++++++++++-- tests/test_embeddings.py | 12 +++++++++--- tests/test_text_generation.py | 10 ++++++++-- tests/test_text_to_image.py | 14 ++++++++++++-- 11 files changed, 58 insertions(+), 32 deletions(-) diff --git a/deepinfra/constants/client.py b/deepinfra/constants/client.py index 2d57565..7d07659 100644 --- a/deepinfra/constants/client.py +++ b/deepinfra/constants/client.py @@ -3,7 +3,7 @@ """ MAX_RETRIES = 5 -INITIAL_BACKOFF = 5000 -SUBSEQUENT_BACKOFF = 2000 +INITIAL_BACKOFF = 5 +SUBSEQUENT_BACKOFF = 2 USER_AGENT = "DeepInfra Python API Client" ROOT_URL = "https://api.deepinfra.com/v1/inference/" diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 95d596b..98803b7 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -26,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse: response = self.client.post( form_data, {"headers": {"content-type": form_data.content_type}} ) - return response.json() + return AutomaticSpeechRecognitionResponse(**response.json()) diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index 9343153..31319e9 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -1,5 +1,6 @@ +import json + from deepinfra.models.base import BaseModel -from deepinfra.types.embeddings.request import EmbeddingsRequest from deepinfra.types.embeddings.response import EmbeddingsResponse @@ -14,5 +15,5 @@ def generate(self, body) -> EmbeddingsResponse: :param body: :return: """ - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(body)) + return EmbeddingsResponse(**response.json()) diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index 55ab472..e728efc 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -3,6 +3,7 @@ which is the base class for all text generation models. """ +import json from typing import Union from deepinfra.models.base import BaseModel @@ -22,5 +23,5 @@ def generate(self, body: dict) -> TextGenerationResponse: :param body: :return: """ - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(body)) + return TextGenerationResponse(**response.json()) diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 42733ba..5944564 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -1,3 +1,5 @@ +import json + from deepinfra.models.base import BaseModel from deepinfra.types.text_to_image import TextToImageResponse @@ -8,12 +10,12 @@ class TextToImage(BaseModel): @docs Check the available models at https://deepinfra.com/models/text-to-image """ - def generate(self, input): + def generate(self, input) -> TextToImageResponse: """ Generates an image. :param input: :return: """ body = {"input": input} - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(input)) + return TextToImageResponse(**response.json()) diff --git a/deepinfra/types/text_generation/response.py b/deepinfra/types/text_generation/response.py index d8e0ae5..716a457 100644 --- a/deepinfra/types/text_generation/response.py +++ b/deepinfra/types/text_generation/response.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional from deepinfra.types.common.inference_status import InferenceStatus @@ -11,7 +11,6 @@ class GeneratedText: @dataclass class TextGenerationResponse: - request_id: str inference_status: InferenceStatus results: List[GeneratedText] num_tokens: int diff --git a/deepinfra/types/text_to_image/response.py b/deepinfra/types/text_to_image/response.py index f1bf320..6bb3c67 100644 --- a/deepinfra/types/text_to_image/response.py +++ b/deepinfra/types/text_to_image/response.py @@ -13,16 +13,9 @@ class Metrics: class TextToImageResponse: request_id: str inference_status: InferenceStatus - input: Dict - output: List[str] - id: str - started_at: str - completed_at: str - logs: str - status: str - metrics: Metrics - webhook_events_filter: List[str] - output_file_prefix: str + images: List[str] + nsfw_content_detected: bool + seed: str version: Optional[str] = None created_at: Optional[str] = None error: Optional[str] = None diff --git a/tests/test_automatic_speech_recognition.py b/tests/test_automatic_speech_recognition.py index 9bd276c..0d6a0d0 100644 --- a/tests/test_automatic_speech_recognition.py +++ b/tests/test_automatic_speech_recognition.py @@ -11,7 +11,15 @@ class TestAutomaticSpeechRecognition(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"text": "Hello, World!"} + mock_post.return_value.json.return_value = { + "text": "Hello, World!", + "segments": [{"start": 0, "end": 1, "text": "Hello"}], + "language": "en", + "input_length_ms": 1000, + "request_id": "123", + "inference_status": None, + } + audio_data = b"audio data" asr = AutomaticSpeechRecognition(model_name, api_key) body = {"audio": audio_data} @@ -24,4 +32,4 @@ def test_generate(self, mock_post): called_headers = called_kwargs["headers"] self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}") - self.assertEqual(response["text"], "Hello, World!") + self.assertEqual(response.text, "Hello, World!") diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 85b67bc..56a543a 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -11,7 +11,11 @@ class TestEmbeddings(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"embeddings": [1, 2, 3]} + mock_post.return_value.json.return_value = { + "embeddings": [1, 2, 3], + "input_tokens": 123, + "inference_status": None, + } embeddings = Embeddings(model_name, api_key) body = {"text": "Hello, World!"} @@ -20,7 +24,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] - self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + body = called_kwargs["data"] - self.assertEqual(response["embeddings"], [1, 2, 3]) + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + self.assertEqual(body, '{"text": "Hello, World!"}') + self.assertEqual(response.embeddings, [1, 2, 3]) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 816c602..56501b4 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -10,8 +10,14 @@ class TestTextGeneration(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"text": "Hello, World!"} + mock_post.return_value.json.return_value = { + "results": [], + "num_tokens": 0, + "num_input_tokens": 0, + "inference_status": None, + } text_generation = TextGeneration(model_name, api_key) body = {"text": "Hello, World!"} @@ -22,5 +28,5 @@ def test_generate(self, mock_post): header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(response["text"], "Hello, World!") + self.assertEqual(response.results, []) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py index 39fd5de..5c75fd8 100644 --- a/tests/test_text_to_image.py +++ b/tests/test_text_to_image.py @@ -11,7 +11,17 @@ class TestTextToImage(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"image": "image data"} + images = ["image data"] + + mock_post.return_value.json.return_value = { + "request_id": "123", + "inference_status": None, + "images": images, + "nsfw_content_detected": False, + "seed": "seed", + "version": "1.0", + "created_at": "2022-01-01", + } text_to_image = TextToImage(model_name, api_key) body = {"text": "Hello, World!"} @@ -22,5 +32,5 @@ def test_generate(self, mock_post): header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(response["image"], "image data") + self.assertEqual(response.images, images) self.assertEqual(header["Authorization"], f"Bearer {api_key}") From ccee46b8c6815c5de9aca6e3a805a917e272e269 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 18:05:58 +0300 Subject: [PATCH 2/2] test(body test): implement --- tests/test_embeddings.py | 5 +++-- tests/test_text_generation.py | 3 +++ tests/test_text_to_image.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 56a543a..67cdb14 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -24,9 +25,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] - body = called_kwargs["data"] + data = called_kwargs["data"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(body, '{"text": "Hello, World!"}') + self.assertEqual(data, json.dumps(body)) self.assertEqual(response.embeddings, [1, 2, 3]) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 56501b4..233c572 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -25,8 +26,10 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] + data = called_kwargs["data"] header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response.results, []) self.assertEqual(header["Authorization"], f"Bearer {api_key}") + self.assertEqual(data, json.dumps(body)) diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py index 5c75fd8..6088534 100644 --- a/tests/test_text_to_image.py +++ b/tests/test_text_to_image.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -30,7 +31,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] + data = called_kwargs["data"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response.images, images) self.assertEqual(header["Authorization"], f"Bearer {api_key}") + self.assertEqual(data, json.dumps(body))