Skip to content

Commit

Permalink
Added tests
Browse files Browse the repository at this point in the history
  • Loading branch information
ovuruska committed Apr 30, 2024
1 parent 06c2215 commit 9b83c9f
Show file tree
Hide file tree
Showing 14 changed files with 111 additions and 40 deletions.
27 changes: 7 additions & 20 deletions deepinfra/clients/deepinfra.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import time
import requests

from deepinfra.exceptions import MaxRetriesExceededError
from deepinfra.constants import (
MAX_RETRIES,
USER_AGENT,
Expand All @@ -19,10 +19,6 @@ def __init__(self, url, auth_token):
self.url = url
self.auth_token = auth_token

def backoff_delay(self, attempt):
delay = self.initial_backoff if attempt == 1 else self.subsequent_backoff
time.sleep(delay)

def post(self, data, config=None):
"""
Performs a POST request.
Expand All @@ -40,18 +36,9 @@ def post(self, data, config=None):
"User-Agent": USER_AGENT,
"Authorization": f"Bearer {self.auth_token}",
}
for attempt in range(self.max_retries + 1):
try:
response = requests.post(self.url, data=data, headers=headers)
response.raise_for_status()
return response
except requests.RequestException as error:
if attempt < self.max_retries:
print(
f"Request failed, retrying... Attempt {attempt + 1}/{self.max_retries}"
)
self.backoff_delay(attempt + 1)
else:
raise error

raise MaxRetriesExceededError()
try:
response = requests.post(self.url, data=data, headers=headers)
response.raise_for_status()
return response
except requests.RequestException as error:
raise error
1 change: 1 addition & 0 deletions deepinfra/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import MAX_RETRIES, USER_AGENT, INITIAL_BACKOFF, SUBSEQUENT_BACKOFF
from .models import SDXL
2 changes: 1 addition & 1 deletion deepinfra/constants/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module contains the constants used by the DeepInfra API client.
"""

MAX_RETRIES = 5
MAX_RETRIES = 0
INITIAL_BACKOFF = 5000
SUBSEQUENT_BACKOFF = 2000
USER_AGENT = "DeepInfra Python API Client"
Expand Down
1 change: 1 addition & 0 deletions deepinfra/constants/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SDXL = "stability-ai/sdxl"
1 change: 0 additions & 1 deletion deepinfra/exceptions/__init__.py

This file was deleted.

12 changes: 0 additions & 12 deletions deepinfra/exceptions/max_retries_exceeded.py

This file was deleted.

27 changes: 27 additions & 0 deletions deepinfra/models/base/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
import json

from black import Optional

from deepinfra import BaseModel
from deepinfra.constants import SDXL
from deepinfra.types.sdxl.request import SdxlRequest
from deepinfra.types.sdxl.response import SdxlResponse


class Sdxl(BaseModel):
"""
Class for the SDXL model.
@docs Check the model at https://deepinfra.com/stability-ai/sdxl/api
"""

def __init__(self, api_token: Optional[str] = None):
super().__init__(SDXL, api_token)

def generate(self, body: dict) -> SdxlResponse:
"""
Generates an image.
:param input:
:return:
"""
response = self.client.post(json.dumps(body))
return response.json()
8 changes: 5 additions & 3 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import json

from deepinfra.models.base import BaseModel
from deepinfra.types.text_to_image import TextToImageResponse

Expand All @@ -8,12 +10,12 @@ class TextToImage(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-to-image
"""

def generate(self, input):
def generate(self, body: dict):
"""
Generates an image.
:param input:
:return:
"""
body = {"input": input}
response = self.client.post(body)

response = self.client.post(json.dumps(body))
return response.json()
1 change: 0 additions & 1 deletion deepinfra/types/automatic_speech_recognition/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List

from dataclasses import dataclass

from deepinfra.types.common.inference_status import InferenceStatus
Expand Down
Empty file.
23 changes: 23 additions & 0 deletions deepinfra/types/sdxl/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from typing import Optional

from dataclasses import dataclass


@dataclass
class SdxlRequest:
prompt: str
negative_prompt: Optional[str] = None
image: Optional[str] = None
mask: Optional[str] = None
width: Optional[int] = None
height: Optional[int] = None
num_outputs: Optional[int] = None
scheduler: Optional[str] = None
num_inference_steps: Optional[int] = None
guidance_scale: Optional[float] = None
prompt_strength: Optional[float] = None
seed: Optional[int] = None
refine: Optional[str] = None
high_noise_frac: Optional[float] = None
refine_steps: Optional[int] = None
apply_watermark: Optional[bool] = True
16 changes: 16 additions & 0 deletions deepinfra/types/sdxl/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List

from dataclasses import dataclass


@dataclass
class SdxlResponseItem:
format: str
type: str


@dataclass
class SdxlResponse:
items: List[SdxlResponseItem]
title: str
type: str
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
setuptools
setuptools-scm
requests
dataclasses
requests-toolbelt
requests-toolbelt
pydantic
28 changes: 28 additions & 0 deletions tests/test_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
from unittest.mock import patch

from deepinfra import TextToImage
from deepinfra.constants import SDXL
from deepinfra.models.base.sdxl import Sdxl

model_name = SDXL
api_key = "API KEY"


class TestSdxl(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"image": "image data"}

text_to_image = Sdxl(api_key)
body = {"input": {"prompt": "Hello, World!"}}
response = text_to_image.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["image"], "image data")
self.assertEqual(header["Authorization"], f"Bearer {api_key}")

0 comments on commit 9b83c9f

Please sign in to comment.