Skip to content

Commit

Permalink
Merge branch 'main' into 5-high-priosdxl
Browse files Browse the repository at this point in the history
  • Loading branch information
ovuruska authored May 7, 2024
2 parents 831bdfb + f8731c2 commit fe44774
Show file tree
Hide file tree
Showing 11 changed files with 63 additions and 31 deletions.
4 changes: 2 additions & 2 deletions deepinfra/constants/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""

MAX_RETRIES = 5
INITIAL_BACKOFF = 5000
SUBSEQUENT_BACKOFF = 2000
INITIAL_BACKOFF = 5
SUBSEQUENT_BACKOFF = 2
USER_AGENT = "DeepInfra Python API Client"
ROOT_URL = "https://api.deepinfra.com/v1/inference/"
2 changes: 1 addition & 1 deletion deepinfra/models/base/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,4 +26,4 @@ def generate(self, body):
response = self.client.post(
form_data, {"headers": {"content-type": form_data.content_type}}
)
return response.json()
return AutomaticSpeechRecognitionResponse(**response.json())
6 changes: 4 additions & 2 deletions deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.embeddings.response import EmbeddingsResponse

Expand All @@ -13,5 +15,5 @@ def generate(self, body):
:param body:
:return:
"""
response = self.client.post(body)
return response.json()
response = self.client.post(json.dumps(body))
return EmbeddingsResponse(**response.json())
4 changes: 2 additions & 2 deletions deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from deepinfra.types.text_generation.response import TextGenerationResponse


class TextGeneration(BaseModel):
class TextGeneration(BaseModel) -> TextGenerationResponse:
"""
Initializes one of the DeepInfra text generation models.
@docs Check the available models at https://deepinfra.com/models/text-generation
Expand All @@ -24,4 +24,4 @@ def generate(self, body: dict):
:return:
"""
response = self.client.post(json.dumps(body))
return response.json()
return TextGenerationResponse(**response.json())
7 changes: 4 additions & 3 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ class TextToImage(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-to-image
"""

def generate(self, body: dict):
def generate(self, input) -> TextToImageResponse:
"""
Generates an image.
:param input:
:return:
"""

response = self.client.post(json.dumps(body))
return response.json()
body = {"input": input}
response = self.client.post(json.dumps(input))
return TextToImageResponse(**response.json())
3 changes: 1 addition & 2 deletions deepinfra/types/text_generation/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import List, Optional

from deepinfra.types.common.inference_status import InferenceStatus

Expand All @@ -11,7 +11,6 @@ class GeneratedText:

@dataclass
class TextGenerationResponse:
request_id: str
inference_status: InferenceStatus
results: List[GeneratedText]
num_tokens: int
Expand Down
13 changes: 3 additions & 10 deletions deepinfra/types/text_to_image/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,9 @@ class Metrics:
class TextToImageResponse:
request_id: str
inference_status: InferenceStatus
input: Dict
output: List[str]
id: str
started_at: str
completed_at: str
logs: str
status: str
metrics: Metrics
webhook_events_filter: List[str]
output_file_prefix: str
images: List[str]
nsfw_content_detected: bool
seed: str
version: Optional[str] = None
created_at: Optional[str] = None
error: Optional[str] = None
Expand Down
12 changes: 10 additions & 2 deletions tests/test_automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,15 @@ class TestAutomaticSpeechRecognition(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"text": "Hello, World!"}
mock_post.return_value.json.return_value = {
"text": "Hello, World!",
"segments": [{"start": 0, "end": 1, "text": "Hello"}],
"language": "en",
"input_length_ms": 1000,
"request_id": "123",
"inference_status": None,
}

audio_data = b"audio data"
asr = AutomaticSpeechRecognition(model_name, api_key)
body = {"audio": audio_data}
Expand All @@ -24,4 +32,4 @@ def test_generate(self, mock_post):
called_headers = called_kwargs["headers"]
self.assertEqual(called_headers["Authorization"], f"Bearer {api_key}")

self.assertEqual(response["text"], "Hello, World!")
self.assertEqual(response.text, "Hello, World!")
13 changes: 10 additions & 3 deletions tests/test_embeddings.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from unittest.mock import patch

Expand All @@ -11,7 +12,11 @@ class TestEmbeddings(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"embeddings": [1, 2, 3]}
mock_post.return_value.json.return_value = {
"embeddings": [1, 2, 3],
"input_tokens": 123,
"inference_status": None,
}

embeddings = Embeddings(model_name, api_key)
body = {"text": "Hello, World!"}
Expand All @@ -20,7 +25,9 @@ def test_generate(self, mock_post):
called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")
data = called_kwargs["data"]

self.assertEqual(response["embeddings"], [1, 2, 3])
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")
self.assertEqual(data, json.dumps(body))
self.assertEqual(response.embeddings, [1, 2, 3])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
13 changes: 11 additions & 2 deletions tests/test_text_generation.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from unittest.mock import patch

Expand All @@ -10,17 +11,25 @@
class TestTextGeneration(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):

mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"text": "Hello, World!"}
mock_post.return_value.json.return_value = {
"results": [],
"num_tokens": 0,
"num_input_tokens": 0,
"inference_status": None,
}

text_generation = TextGeneration(model_name, api_key)
body = {"text": "Hello, World!"}
response = text_generation.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
data = called_kwargs["data"]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["text"], "Hello, World!")
self.assertEqual(response.results, [])
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
self.assertEqual(data, json.dumps(body))
17 changes: 15 additions & 2 deletions tests/test_text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import unittest
from unittest.mock import patch

Expand All @@ -11,7 +12,17 @@ class TestTextToImage(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"image": "image data"}
images = ["image data"]

mock_post.return_value.json.return_value = {
"request_id": "123",
"inference_status": None,
"images": images,
"nsfw_content_detected": False,
"seed": "seed",
"version": "1.0",
"created_at": "2022-01-01",
}

text_to_image = TextToImage(model_name, api_key)
body = {"text": "Hello, World!"}
Expand All @@ -20,7 +31,9 @@ def test_generate(self, mock_post):
called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
data = called_kwargs["data"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["image"], "image data")
self.assertEqual(response.images, images)
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
self.assertEqual(data, json.dumps(body))

0 comments on commit fe44774

Please sign in to comment.