From 0a6e7da222aac992934b9d51b50b97d32788dc08 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Mon, 29 Apr 2024 21:06:24 +0300 Subject: [PATCH 1/7] Tests are implemented --- README.md | 19 +++++++++++- .../base/automatic_speech_recognition.py | 2 +- deepinfra/models/base/embeddings.py | 2 +- deepinfra/models/base/text_generation.py | 2 +- deepinfra/models/base/text_to_image.py | 2 +- requirements-dev.txt | 3 +- run_tests.py | 4 +++ tests/__init__.py | 0 tests/test_automatic_speech_recognition.py | 31 +++++++++++++++++++ 9 files changed, 59 insertions(+), 6 deletions(-) create mode 100644 run_tests.py create mode 100644 tests/__init__.py create mode 100644 tests/test_automatic_speech_recognition.py diff --git a/README.md b/README.md index dfbc109..6bf57c7 100644 --- a/README.md +++ b/README.md @@ -36,6 +36,23 @@ body = { "audio": file_path } transcription = asr.generate(body) -print(transcription) +print(transcription["text"]) ``` +#### Transcribe an audio URL + +```python +from deepinfra import AutomaticSpeechRecognition + +model_name = "openai/whisper-base" +asr = AutomaticSpeechRecognition(model_name) + +url = "https://path/to/audio/file" +body = { + "audio": url +} +transcription = asr.generate(body) +print(transcription["text"]) +``` + + diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 98803b7..95d596b 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -26,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse: response = self.client.post( form_data, {"headers": {"content-type": form_data.content_type}} ) - return AutomaticSpeechRecognitionResponse(**response.json()) + return response.json() diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index bc2d771..9343153 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -15,4 +15,4 @@ def generate(self, body) -> EmbeddingsResponse: :return: """ response = self.client.post(body) - return EmbeddingsResponse(**response.json()) + return response.json() diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index b105a79..55ab472 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -23,4 +23,4 @@ def generate(self, body: dict) -> TextGenerationResponse: :return: """ response = self.client.post(body) - return TextGenerationResponse(**response.json()) + return response.json() diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 432462e..42733ba 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -16,4 +16,4 @@ def generate(self, input): """ body = {"input": input} response = self.client.post(body) - return TextToImageResponse(**response.json()) + return response.json() diff --git a/requirements-dev.txt b/requirements-dev.txt index 9d519af..92fc163 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,3 +1,4 @@ black==23.3.0 mypy -types-requests \ No newline at end of file +types-requests +coverage \ No newline at end of file diff --git a/run_tests.py b/run_tests.py new file mode 100644 index 0000000..1c35782 --- /dev/null +++ b/run_tests.py @@ -0,0 +1,4 @@ +import unittest + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/test_automatic_speech_recognition.py b/tests/test_automatic_speech_recognition.py new file mode 100644 index 0000000..457e2c5 --- /dev/null +++ b/tests/test_automatic_speech_recognition.py @@ -0,0 +1,31 @@ +import unittest +from unittest.mock import patch + +from deepinfra import AutomaticSpeechRecognition + +model_name = "openai/whisper-base" +api_key = "API KEY" + +class TestAutomaticSpeechRecognition(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "text": "Hello, World!" + } + audio_data = b"audio data" + asr = AutomaticSpeechRecognition(model_name,api_key) + body = { + "audio": audio_data + } + response = asr.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + called_headers = called_kwargs["headers"] + self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}") + + self.assertEqual(response["text"], "Hello, World!") + From f70a2ba00fac23ec72f06bec84b8f55d7c394a39 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Mon, 29 Apr 2024 21:08:05 +0300 Subject: [PATCH 2/7] ci(ci): Unit test execution is added. --- .github/workflows/ci.yml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 764fe3b..05dd4f1 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,4 +38,8 @@ jobs: - name: Run lint check run: | - black --check --verbose deepinfra \ No newline at end of file + black --check --verbose deepinfra + + - name: Run unit tests + run: | + pytest tests \ No newline at end of file From e82e908dcc08d8ea291119f371056ea0fe35e9d3 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 12:02:49 +0300 Subject: [PATCH 3/7] Added missing requirement. --- requirements-dev.txt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/requirements-dev.txt b/requirements-dev.txt index 92fc163..d2c80bf 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,4 +1,5 @@ black==23.3.0 mypy types-requests -coverage \ No newline at end of file +coverage +pytest \ No newline at end of file From 7c6b4825c409a901a05b8f2dd4452a733be16354 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 12:12:03 +0300 Subject: [PATCH 4/7] Added tests --- tests/test_embeddings.py | 28 +++++++++++++++++ tests/test_text_generation.py | 58 +++++++++++++++++++++++++++++++++++ tests/test_text_to_image.py | 28 +++++++++++++++++ 3 files changed, 114 insertions(+) create mode 100644 tests/test_embeddings.py create mode 100644 tests/test_text_generation.py create mode 100644 tests/test_text_to_image.py diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py new file mode 100644 index 0000000..3c7f7cf --- /dev/null +++ b/tests/test_embeddings.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import patch + +from deepinfra import Embeddings + + +class TestEmbeddings(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "embeddings": [1, 2, 3] + } + model_name = "BAAI/bge-large-en-v1.5" + api_key = "API KEY" + embeddings = Embeddings(model_name, api_key) + body = { + "text": "Hello, World!" + } + response = embeddings.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["embeddings"], [1, 2, 3]) + self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py new file mode 100644 index 0000000..2f1cb21 --- /dev/null +++ b/tests/test_text_generation.py @@ -0,0 +1,58 @@ +""" +import unittest +from unittest.mock import patch + +from deepinfra import Embeddings + + +class TestEmbeddings(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "embeddings": [1, 2, 3] + } + model_name = "BAAI/bge-large-en-v1.5" + api_key = "API KEY" + embeddings = Embeddings(model_name, api_key) + body = { + "text": "Hello, World!" + } + response = embeddings.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["embeddings"], [1, 2, 3]) + self.assertEqual(header["Authorization"], f"Bearer {api_key}") +""" +import unittest +from unittest.mock import patch + +from deepinfra.models.base.text_generation import TextGeneration + + +class TestTextGeneration(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "text": "Hello, World!" + } + model_name = "mistralai/Mistral-7B-Instruct-v0.2" + api_key = "API KEY" + text_generation = TextGeneration(model_name, api_key) + body = { + "text": "Hello, World!" + } + response = text_generation.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["text"], "Hello, World!") + self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py new file mode 100644 index 0000000..805aeae --- /dev/null +++ b/tests/test_text_to_image.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import patch + +from deepinfra import TextToImage + + +class TestTextToImage(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = { + "image": "image data" + } + model_name = "CompVis/stable-diffusion-v1-4" + api_key = "API KEY" + text_to_image = TextToImage(model_name, api_key) + body = { + "text": "Hello, World!" + } + response = text_to_image.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["image"], "image data") + self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file From e473707110f32284049ebdae63b5726d91edb104 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 12:17:26 +0300 Subject: [PATCH 5/7] ref(black): lint --- tests/test_automatic_speech_recognition.py | 12 ++---- tests/test_embeddings.py | 16 ++++---- tests/test_text_generation.py | 46 ++++------------------ tests/test_text_to_image.py | 16 ++++---- 4 files changed, 25 insertions(+), 65 deletions(-) diff --git a/tests/test_automatic_speech_recognition.py b/tests/test_automatic_speech_recognition.py index 457e2c5..9bd276c 100644 --- a/tests/test_automatic_speech_recognition.py +++ b/tests/test_automatic_speech_recognition.py @@ -6,18 +6,15 @@ model_name = "openai/whisper-base" api_key = "API KEY" + class TestAutomaticSpeechRecognition(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "text": "Hello, World!" - } + mock_post.return_value.json.return_value = {"text": "Hello, World!"} audio_data = b"audio data" - asr = AutomaticSpeechRecognition(model_name,api_key) - body = { - "audio": audio_data - } + asr = AutomaticSpeechRecognition(model_name, api_key) + body = {"audio": audio_data} response = asr.generate(body) called_args, called_kwargs = mock_post.call_args @@ -28,4 +25,3 @@ def test_generate(self, mock_post): self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}") self.assertEqual(response["text"], "Hello, World!") - diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 3c7f7cf..85b67bc 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -3,20 +3,18 @@ from deepinfra import Embeddings +model_name = "BAAI/bge-large-en-v1.5" +api_key = "API KEY" + class TestEmbeddings(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "embeddings": [1, 2, 3] - } - model_name = "BAAI/bge-large-en-v1.5" - api_key = "API KEY" + mock_post.return_value.json.return_value = {"embeddings": [1, 2, 3]} + embeddings = Embeddings(model_name, api_key) - body = { - "text": "Hello, World!" - } + body = {"text": "Hello, World!"} response = embeddings.generate(body) called_args, called_kwargs = mock_post.call_args @@ -25,4 +23,4 @@ def test_generate(self, mock_post): self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response["embeddings"], [1, 2, 3]) - self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file + self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 2f1cb21..816c602 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -1,52 +1,20 @@ -""" -import unittest -from unittest.mock import patch - -from deepinfra import Embeddings - - -class TestEmbeddings(unittest.TestCase): - @patch("requests.post") - def test_generate(self, mock_post): - mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "embeddings": [1, 2, 3] - } - model_name = "BAAI/bge-large-en-v1.5" - api_key = "API KEY" - embeddings = Embeddings(model_name, api_key) - body = { - "text": "Hello, World!" - } - response = embeddings.generate(body) - - called_args, called_kwargs = mock_post.call_args - url = called_args[0] - header = called_kwargs["headers"] - self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - - self.assertEqual(response["embeddings"], [1, 2, 3]) - self.assertEqual(header["Authorization"], f"Bearer {api_key}") -""" import unittest from unittest.mock import patch from deepinfra.models.base.text_generation import TextGeneration +model_name = "mistralai/Mistral-7B-Instruct-v0.2" +api_key = "API KEY" + class TestTextGeneration(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "text": "Hello, World!" - } - model_name = "mistralai/Mistral-7B-Instruct-v0.2" - api_key = "API KEY" + mock_post.return_value.json.return_value = {"text": "Hello, World!"} + text_generation = TextGeneration(model_name, api_key) - body = { - "text": "Hello, World!" - } + body = {"text": "Hello, World!"} response = text_generation.generate(body) called_args, called_kwargs = mock_post.call_args @@ -55,4 +23,4 @@ def test_generate(self, mock_post): self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response["text"], "Hello, World!") - self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file + self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py index 805aeae..39fd5de 100644 --- a/tests/test_text_to_image.py +++ b/tests/test_text_to_image.py @@ -3,20 +3,18 @@ from deepinfra import TextToImage +model_name = "CompVis/stable-diffusion-v1-4" +api_key = "API KEY" + class TestTextToImage(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = { - "image": "image data" - } - model_name = "CompVis/stable-diffusion-v1-4" - api_key = "API KEY" + mock_post.return_value.json.return_value = {"image": "image data"} + text_to_image = TextToImage(model_name, api_key) - body = { - "text": "Hello, World!" - } + body = {"text": "Hello, World!"} response = text_to_image.generate(body) called_args, called_kwargs = mock_post.call_args @@ -25,4 +23,4 @@ def test_generate(self, mock_post): self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response["image"], "image data") - self.assertEqual(header["Authorization"], f"Bearer {api_key}") \ No newline at end of file + self.assertEqual(header["Authorization"], f"Bearer {api_key}") From c1975e32a25c4aa16c3bd3baecfc9651088a77f2 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 18:03:29 +0300 Subject: [PATCH 6/7] ref(lint) --- deepinfra/constants/client.py | 4 ++-- .../models/base/automatic_speech_recognition.py | 2 +- deepinfra/models/base/embeddings.py | 7 ++++--- deepinfra/models/base/text_generation.py | 5 +++-- deepinfra/models/base/text_to_image.py | 8 +++++--- deepinfra/types/text_generation/response.py | 3 +-- deepinfra/types/text_to_image/response.py | 13 +++---------- tests/test_automatic_speech_recognition.py | 12 ++++++++++-- tests/test_embeddings.py | 12 +++++++++--- tests/test_text_generation.py | 10 ++++++++-- tests/test_text_to_image.py | 14 ++++++++++++-- 11 files changed, 58 insertions(+), 32 deletions(-) diff --git a/deepinfra/constants/client.py b/deepinfra/constants/client.py index 2d57565..7d07659 100644 --- a/deepinfra/constants/client.py +++ b/deepinfra/constants/client.py @@ -3,7 +3,7 @@ """ MAX_RETRIES = 5 -INITIAL_BACKOFF = 5000 -SUBSEQUENT_BACKOFF = 2000 +INITIAL_BACKOFF = 5 +SUBSEQUENT_BACKOFF = 2 USER_AGENT = "DeepInfra Python API Client" ROOT_URL = "https://api.deepinfra.com/v1/inference/" diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 95d596b..98803b7 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -26,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse: response = self.client.post( form_data, {"headers": {"content-type": form_data.content_type}} ) - return response.json() + return AutomaticSpeechRecognitionResponse(**response.json()) diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index 9343153..31319e9 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -1,5 +1,6 @@ +import json + from deepinfra.models.base import BaseModel -from deepinfra.types.embeddings.request import EmbeddingsRequest from deepinfra.types.embeddings.response import EmbeddingsResponse @@ -14,5 +15,5 @@ def generate(self, body) -> EmbeddingsResponse: :param body: :return: """ - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(body)) + return EmbeddingsResponse(**response.json()) diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index 55ab472..e728efc 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -3,6 +3,7 @@ which is the base class for all text generation models. """ +import json from typing import Union from deepinfra.models.base import BaseModel @@ -22,5 +23,5 @@ def generate(self, body: dict) -> TextGenerationResponse: :param body: :return: """ - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(body)) + return TextGenerationResponse(**response.json()) diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 42733ba..5944564 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -1,3 +1,5 @@ +import json + from deepinfra.models.base import BaseModel from deepinfra.types.text_to_image import TextToImageResponse @@ -8,12 +10,12 @@ class TextToImage(BaseModel): @docs Check the available models at https://deepinfra.com/models/text-to-image """ - def generate(self, input): + def generate(self, input) -> TextToImageResponse: """ Generates an image. :param input: :return: """ body = {"input": input} - response = self.client.post(body) - return response.json() + response = self.client.post(json.dumps(input)) + return TextToImageResponse(**response.json()) diff --git a/deepinfra/types/text_generation/response.py b/deepinfra/types/text_generation/response.py index d8e0ae5..716a457 100644 --- a/deepinfra/types/text_generation/response.py +++ b/deepinfra/types/text_generation/response.py @@ -1,5 +1,5 @@ from dataclasses import dataclass -from typing import List +from typing import List, Optional from deepinfra.types.common.inference_status import InferenceStatus @@ -11,7 +11,6 @@ class GeneratedText: @dataclass class TextGenerationResponse: - request_id: str inference_status: InferenceStatus results: List[GeneratedText] num_tokens: int diff --git a/deepinfra/types/text_to_image/response.py b/deepinfra/types/text_to_image/response.py index f1bf320..6bb3c67 100644 --- a/deepinfra/types/text_to_image/response.py +++ b/deepinfra/types/text_to_image/response.py @@ -13,16 +13,9 @@ class Metrics: class TextToImageResponse: request_id: str inference_status: InferenceStatus - input: Dict - output: List[str] - id: str - started_at: str - completed_at: str - logs: str - status: str - metrics: Metrics - webhook_events_filter: List[str] - output_file_prefix: str + images: List[str] + nsfw_content_detected: bool + seed: str version: Optional[str] = None created_at: Optional[str] = None error: Optional[str] = None diff --git a/tests/test_automatic_speech_recognition.py b/tests/test_automatic_speech_recognition.py index 9bd276c..0d6a0d0 100644 --- a/tests/test_automatic_speech_recognition.py +++ b/tests/test_automatic_speech_recognition.py @@ -11,7 +11,15 @@ class TestAutomaticSpeechRecognition(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"text": "Hello, World!"} + mock_post.return_value.json.return_value = { + "text": "Hello, World!", + "segments": [{"start": 0, "end": 1, "text": "Hello"}], + "language": "en", + "input_length_ms": 1000, + "request_id": "123", + "inference_status": None, + } + audio_data = b"audio data" asr = AutomaticSpeechRecognition(model_name, api_key) body = {"audio": audio_data} @@ -24,4 +32,4 @@ def test_generate(self, mock_post): called_headers = called_kwargs["headers"] self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}") - self.assertEqual(response["text"], "Hello, World!") + self.assertEqual(response.text, "Hello, World!") diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 85b67bc..56a543a 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -11,7 +11,11 @@ class TestEmbeddings(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"embeddings": [1, 2, 3]} + mock_post.return_value.json.return_value = { + "embeddings": [1, 2, 3], + "input_tokens": 123, + "inference_status": None, + } embeddings = Embeddings(model_name, api_key) body = {"text": "Hello, World!"} @@ -20,7 +24,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] - self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + body = called_kwargs["data"] - self.assertEqual(response["embeddings"], [1, 2, 3]) + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + self.assertEqual(body, '{"text": "Hello, World!"}') + self.assertEqual(response.embeddings, [1, 2, 3]) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 816c602..56501b4 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -10,8 +10,14 @@ class TestTextGeneration(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"text": "Hello, World!"} + mock_post.return_value.json.return_value = { + "results": [], + "num_tokens": 0, + "num_input_tokens": 0, + "inference_status": None, + } text_generation = TextGeneration(model_name, api_key) body = {"text": "Hello, World!"} @@ -22,5 +28,5 @@ def test_generate(self, mock_post): header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(response["text"], "Hello, World!") + self.assertEqual(response.results, []) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py index 39fd5de..5c75fd8 100644 --- a/tests/test_text_to_image.py +++ b/tests/test_text_to_image.py @@ -11,7 +11,17 @@ class TestTextToImage(unittest.TestCase): @patch("requests.post") def test_generate(self, mock_post): mock_post.return_value.status_code = 200 - mock_post.return_value.json.return_value = {"image": "image data"} + images = ["image data"] + + mock_post.return_value.json.return_value = { + "request_id": "123", + "inference_status": None, + "images": images, + "nsfw_content_detected": False, + "seed": "seed", + "version": "1.0", + "created_at": "2022-01-01", + } text_to_image = TextToImage(model_name, api_key) body = {"text": "Hello, World!"} @@ -22,5 +32,5 @@ def test_generate(self, mock_post): header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(response["image"], "image data") + self.assertEqual(response.images, images) self.assertEqual(header["Authorization"], f"Bearer {api_key}") From ccee46b8c6815c5de9aca6e3a805a917e272e269 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 18:05:58 +0300 Subject: [PATCH 7/7] test(body test): implement --- tests/test_embeddings.py | 5 +++-- tests/test_text_generation.py | 3 +++ tests/test_text_to_image.py | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/tests/test_embeddings.py b/tests/test_embeddings.py index 56a543a..67cdb14 100644 --- a/tests/test_embeddings.py +++ b/tests/test_embeddings.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -24,9 +25,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] - body = called_kwargs["data"] + data = called_kwargs["data"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") - self.assertEqual(body, '{"text": "Hello, World!"}') + self.assertEqual(data, json.dumps(body)) self.assertEqual(response.embeddings, [1, 2, 3]) self.assertEqual(header["Authorization"], f"Bearer {api_key}") diff --git a/tests/test_text_generation.py b/tests/test_text_generation.py index 56501b4..233c572 100644 --- a/tests/test_text_generation.py +++ b/tests/test_text_generation.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -25,8 +26,10 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] + data = called_kwargs["data"] header = called_kwargs["headers"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response.results, []) self.assertEqual(header["Authorization"], f"Bearer {api_key}") + self.assertEqual(data, json.dumps(body)) diff --git a/tests/test_text_to_image.py b/tests/test_text_to_image.py index 5c75fd8..6088534 100644 --- a/tests/test_text_to_image.py +++ b/tests/test_text_to_image.py @@ -1,3 +1,4 @@ +import json import unittest from unittest.mock import patch @@ -30,7 +31,9 @@ def test_generate(self, mock_post): called_args, called_kwargs = mock_post.call_args url = called_args[0] header = called_kwargs["headers"] + data = called_kwargs["data"] self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") self.assertEqual(response.images, images) self.assertEqual(header["Authorization"], f"Bearer {api_key}") + self.assertEqual(data, json.dumps(body))