Skip to content

Commit

Permalink
Merge pull request #13 from ovuruska/11-unit-tests-pipeline
Browse files Browse the repository at this point in the history
Tests are implemented
  • Loading branch information
ovuruska authored May 6, 2024
2 parents cc121f6 + ccee46b commit f8731c2
Show file tree
Hide file tree
Showing 15 changed files with 187 additions and 22 deletions.
6 changes: 5 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,4 +38,8 @@ jobs:
- name: Run lint check
run: |
black --check --verbose deepinfra
black --check --verbose deepinfra
- name: Run unit tests
run: |
pytest tests
19 changes: 18 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,23 @@ body = {
"audio": file_path
}
transcription = asr.generate(body)
print(transcription)
print(transcription["text"])
```

#### Transcribe an audio URL

```python
from deepinfra import AutomaticSpeechRecognition

model_name = "openai/whisper-base"
asr = AutomaticSpeechRecognition(model_name)

url = "https://path/to/audio/file"
body = {
"audio": url
}
transcription = asr.generate(body)
print(transcription["text"])
```


4 changes: 2 additions & 2 deletions deepinfra/constants/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

MAX_RETRIES = 5
INITIAL_BACKOFF = 5000
SUBSEQUENT_BACKOFF = 2000
INITIAL_BACKOFF = 5
SUBSEQUENT_BACKOFF = 2
USER_AGENT = "DeepInfra Python API Client"
ROOT_URL = "https://api.deepinfra.com/v1/inference/"
5 changes: 3 additions & 2 deletions deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.embeddings.request import EmbeddingsRequest
from deepinfra.types.embeddings.response import EmbeddingsResponse


Expand All @@ -14,5 +15,5 @@ def generate(self, body) -> EmbeddingsResponse:
:param body:
:return:
"""
response = self.client.post(body)
response = self.client.post(json.dumps(body))
return EmbeddingsResponse(**response.json())
3 changes: 2 additions & 1 deletion deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which is the base class for all text generation models.
"""

import json
from typing import Union

from deepinfra.models.base import BaseModel
Expand All @@ -22,5 +23,5 @@ def generate(self, body: dict) -> TextGenerationResponse:
:param body:
:return:
"""
response = self.client.post(body)
response = self.client.post(json.dumps(body))
return TextGenerationResponse(**response.json())
6 changes: 4 additions & 2 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.text_to_image import TextToImageResponse

Expand All @@ -8,12 +10,12 @@ class TextToImage(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-to-image
"""

def generate(self, input):
def generate(self, input) -> TextToImageResponse:
"""
Generates an image.
:param input:
:return:
"""
body = {"input": input}
response = self.client.post(body)
response = self.client.post(json.dumps(input))
return TextToImageResponse(**response.json())
3 changes: 1 addition & 2 deletions deepinfra/types/text_generation/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import List, Optional

from deepinfra.types.common.inference_status import InferenceStatus

Expand All @@ -11,7 +11,6 @@ class GeneratedText:

@dataclass
class TextGenerationResponse:
request_id: str
inference_status: InferenceStatus
results: List[GeneratedText]
num_tokens: int
Expand Down
13 changes: 3 additions & 10 deletions deepinfra/types/text_to_image/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,9 @@ class Metrics:
class TextToImageResponse:
request_id: str
inference_status: InferenceStatus
input: Dict
output: List[str]
id: str
started_at: str
completed_at: str
logs: str
status: str
metrics: Metrics
webhook_events_filter: List[str]
output_file_prefix: str
images: List[str]
nsfw_content_detected: bool
seed: str
version: Optional[str] = None
created_at: Optional[str] = None
error: Optional[str] = None
Expand Down
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
black==23.3.0
mypy
types-requests
types-requests
coverage
pytest
4 changes: 4 additions & 0 deletions run_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
import unittest

if __name__ == '__main__':
unittest.main()
Empty file added tests/__init__.py
Empty file.
35 changes: 35 additions & 0 deletions tests/test_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import unittest
from unittest.mock import patch

from deepinfra import AutomaticSpeechRecognition

model_name = "openai/whisper-base"
api_key = "API KEY"


class TestAutomaticSpeechRecognition(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"text": "Hello, World!",
"segments": [{"start": 0, "end": 1, "text": "Hello"}],
"language": "en",
"input_length_ms": 1000,
"request_id": "123",
"inference_status": None,
}

audio_data = b"audio data"
asr = AutomaticSpeechRecognition(model_name, api_key)
body = {"audio": audio_data}
response = asr.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

called_headers = called_kwargs["headers"]
self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}")

self.assertEqual(response.text, "Hello, World!")
33 changes: 33 additions & 0 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import json
import unittest
from unittest.mock import patch

from deepinfra import Embeddings

model_name = "BAAI/bge-large-en-v1.5"
api_key = "API KEY"


class TestEmbeddings(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"embeddings": [1, 2, 3],
"input_tokens": 123,
"inference_status": None,
}

embeddings = Embeddings(model_name, api_key)
body = {"text": "Hello, World!"}
response = embeddings.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
data = called_kwargs["data"]

self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")
self.assertEqual(data, json.dumps(body))
self.assertEqual(response.embeddings, [1, 2, 3])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
35 changes: 35 additions & 0 deletions tests/test_text_generation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
import json
import unittest
from unittest.mock import patch

from deepinfra.models.base.text_generation import TextGeneration

model_name = "mistralai/Mistral-7B-Instruct-v0.2"
api_key = "API KEY"


class TestTextGeneration(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):

mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {
"results": [],
"num_tokens": 0,
"num_input_tokens": 0,
"inference_status": None,
}

text_generation = TextGeneration(model_name, api_key)
body = {"text": "Hello, World!"}
response = text_generation.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
data = called_kwargs["data"]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response.results, [])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
self.assertEqual(data, json.dumps(body))
39 changes: 39 additions & 0 deletions tests/test_text_to_image.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
import json
import unittest
from unittest.mock import patch

from deepinfra import TextToImage

model_name = "CompVis/stable-diffusion-v1-4"
api_key = "API KEY"


class TestTextToImage(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
images = ["image data"]

mock_post.return_value.json.return_value = {
"request_id": "123",
"inference_status": None,
"images": images,
"nsfw_content_detected": False,
"seed": "seed",
"version": "1.0",
"created_at": "2022-01-01",
}

text_to_image = TextToImage(model_name, api_key)
body = {"text": "Hello, World!"}
response = text_to_image.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
data = called_kwargs["data"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response.images, images)
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
self.assertEqual(data, json.dumps(body))

0 comments on commit f8731c2

Please sign in to comment.