Skip to content

Commit

Permalink
ref(lint)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovuruska committed May 4, 2024
1 parent e473707 commit c1975e3
Show file tree
Hide file tree
Showing 11 changed files with 58 additions and 32 deletions.
4 changes: 2 additions & 2 deletions deepinfra/constants/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

MAX_RETRIES = 5
INITIAL_BACKOFF = 5000
SUBSEQUENT_BACKOFF = 2000
INITIAL_BACKOFF = 5
SUBSEQUENT_BACKOFF = 2
USER_AGENT = "DeepInfra Python API Client"
ROOT_URL = "https://api.deepinfra.com/v1/inference/"
2 changes: 1 addition & 1 deletion deepinfra/models/base/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse:
response = self.client.post(
form_data, {"headers": {"content-type": form_data.content_type}}
)
return response.json()
return AutomaticSpeechRecognitionResponse(**response.json())
7 changes: 4 additions & 3 deletions deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.embeddings.request import EmbeddingsRequest
from deepinfra.types.embeddings.response import EmbeddingsResponse


Expand All @@ -14,5 +15,5 @@ def generate(self, body) -> EmbeddingsResponse:
:param body:
:return:
"""
response = self.client.post(body)
return response.json()
response = self.client.post(json.dumps(body))
return EmbeddingsResponse(**response.json())
5 changes: 3 additions & 2 deletions deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which is the base class for all text generation models.
"""

import json
from typing import Union

from deepinfra.models.base import BaseModel
Expand All @@ -22,5 +23,5 @@ def generate(self, body: dict) -> TextGenerationResponse:
:param body:
:return:
"""
response = self.client.post(body)
return response.json()
response = self.client.post(json.dumps(body))
return TextGenerationResponse(**response.json())
8 changes: 5 additions & 3 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.text_to_image import TextToImageResponse

Expand All @@ -8,12 +10,12 @@ class TextToImage(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-to-image
"""

def generate(self, input):
def generate(self, input) -> TextToImageResponse:
"""
Generates an image.
:param input:
:return:
"""
body = {"input": input}
response = self.client.post(body)
return response.json()
response = self.client.post(json.dumps(input))
return TextToImageResponse(**response.json())
3 changes: 1 addition & 2 deletions deepinfra/types/text_generation/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import List, Optional

from deepinfra.types.common.inference_status import InferenceStatus

Expand All @@ -11,7 +11,6 @@ class GeneratedText:

@dataclass
class TextGenerationResponse:
request_id: str
inference_status: InferenceStatus
results: List[GeneratedText]
num_tokens: int
Expand Down
13 changes: 3 additions & 10 deletions deepinfra/types/text_to_image/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,9 @@ class Metrics:
class TextToImageResponse:
request_id: str
inference_status: InferenceStatus
input: Dict
output: List[str]
id: str
started_at: str
completed_at: str
logs: str
status: str
metrics: Metrics
webhook_events_filter: List[str]
output_file_prefix: str
images: List[str]
nsfw_content_detected: bool
seed: str
version: Optional[str] = None
created_at: Optional[str] = None
error: Optional[str] = None
Expand Down
12 changes: 10 additions & 2 deletions tests/test_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ class TestAutomaticSpeechRecognition(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"text": "Hello, World!"}
mock_post.return_value.json.return_value = {
"text": "Hello, World!",
"segments": [{"start": 0, "end": 1, "text": "Hello"}],
"language": "en",
"input_length_ms": 1000,
"request_id": "123",
"inference_status": None,
}

audio_data = b"audio data"
asr = AutomaticSpeechRecognition(model_name, api_key)
body = {"audio": audio_data}
Expand All @@ -24,4 +32,4 @@ def test_generate(self, mock_post):
called_headers = called_kwargs["headers"]
self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}")

self.assertEqual(response["text"], "Hello, World!")
self.assertEqual(response.text, "Hello, World!")
12 changes: 9 additions & 3 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,11 @@ class TestEmbeddings(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"embeddings": [1, 2, 3]}
mock_post.return_value.json.return_value = {
"embeddings": [1, 2, 3],
"input_tokens": 123,
"inference_status": None,
}

embeddings = Embeddings(model_name, api_key)
body = {"text": "Hello, World!"}
Expand All @@ -20,7 +24,9 @@ def test_generate(self, mock_post):
called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")
body = called_kwargs["data"]

self.assertEqual(response["embeddings"], [1, 2, 3])
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")
self.assertEqual(body, '{"text": "Hello, World!"}')
self.assertEqual(response.embeddings, [1, 2, 3])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
10 changes: 8 additions & 2 deletions tests/test_text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,14 @@
class TestTextGeneration(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):

mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"text": "Hello, World!"}
mock_post.return_value.json.return_value = {
"results": [],
"num_tokens": 0,
"num_input_tokens": 0,
"inference_status": None,
}

text_generation = TextGeneration(model_name, api_key)
body = {"text": "Hello, World!"}
Expand All @@ -22,5 +28,5 @@ def test_generate(self, mock_post):
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["text"], "Hello, World!")
self.assertEqual(response.results, [])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
14 changes: 12 additions & 2 deletions tests/test_text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,17 @@ class TestTextToImage(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"image": "image data"}
images = ["image data"]

mock_post.return_value.json.return_value = {
"request_id": "123",
"inference_status": None,
"images": images,
"nsfw_content_detected": False,
"seed": "seed",
"version": "1.0",
"created_at": "2022-01-01",
}

text_to_image = TextToImage(model_name, api_key)
body = {"text": "Hello, World!"}
Expand All @@ -22,5 +32,5 @@ def test_generate(self, mock_post):
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["image"], "image data")
self.assertEqual(response.images, images)
self.assertEqual(header["Authorization"], f"Bearer {api_key}")

0 comments on commit c1975e3

Please sign in to comment.