From 9b83c9f433d51bde4c248ca45da4ce2775fd6e3d Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 14:27:17 +0300 Subject: [PATCH 1/8] Added tests --- deepinfra/clients/deepinfra.py | 27 +++++------------- deepinfra/constants/__init__.py | 1 + deepinfra/constants/client.py | 2 +- deepinfra/constants/models.py | 1 + deepinfra/exceptions/__init__.py | 1 - deepinfra/exceptions/max_retries_exceeded.py | 12 -------- deepinfra/models/base/sdxl.py | 27 ++++++++++++++++++ deepinfra/models/base/text_to_image.py | 8 ++++-- .../automatic_speech_recognition/response.py | 1 - deepinfra/types/sdxl/__init__.py | 0 deepinfra/types/sdxl/request.py | 23 +++++++++++++++ deepinfra/types/sdxl/response.py | 16 +++++++++++ requirements.txt | 4 +-- tests/test_sdxl.py | 28 +++++++++++++++++++ 14 files changed, 111 insertions(+), 40 deletions(-) create mode 100644 deepinfra/constants/models.py delete mode 100644 deepinfra/exceptions/__init__.py delete mode 100644 deepinfra/exceptions/max_retries_exceeded.py create mode 100644 deepinfra/models/base/sdxl.py create mode 100644 deepinfra/types/sdxl/__init__.py create mode 100644 deepinfra/types/sdxl/request.py create mode 100644 deepinfra/types/sdxl/response.py create mode 100644 tests/test_sdxl.py diff --git a/deepinfra/clients/deepinfra.py b/deepinfra/clients/deepinfra.py index d1f6ca1..ce5a7ee 100644 --- a/deepinfra/clients/deepinfra.py +++ b/deepinfra/clients/deepinfra.py @@ -1,7 +1,7 @@ +import json import time import requests -from deepinfra.exceptions import MaxRetriesExceededError from deepinfra.constants import ( MAX_RETRIES, USER_AGENT, @@ -19,10 +19,6 @@ def __init__(self, url, auth_token): self.url = url self.auth_token = auth_token - def backoff_delay(self, attempt): - delay = self.initial_backoff if attempt == 1 else self.subsequent_backoff - time.sleep(delay) - def post(self, data, config=None): """ Performs a POST request. @@ -40,18 +36,9 @@ def post(self, data, config=None): "User-Agent": USER_AGENT, "Authorization": f"Bearer {self.auth_token}", } - for attempt in range(self.max_retries + 1): - try: - response = requests.post(self.url, data=data, headers=headers) - response.raise_for_status() - return response - except requests.RequestException as error: - if attempt < self.max_retries: - print( - f"Request failed, retrying... Attempt {attempt + 1}/{self.max_retries}" - ) - self.backoff_delay(attempt + 1) - else: - raise error - - raise MaxRetriesExceededError() + try: + response = requests.post(self.url, data=data, headers=headers) + response.raise_for_status() + return response + except requests.RequestException as error: + raise error diff --git a/deepinfra/constants/__init__.py b/deepinfra/constants/__init__.py index a27a84b..e657dab 100644 --- a/deepinfra/constants/__init__.py +++ b/deepinfra/constants/__init__.py @@ -1 +1,2 @@ from .client import MAX_RETRIES, USER_AGENT, INITIAL_BACKOFF, SUBSEQUENT_BACKOFF +from .models import SDXL \ No newline at end of file diff --git a/deepinfra/constants/client.py b/deepinfra/constants/client.py index 2d57565..6d9cc07 100644 --- a/deepinfra/constants/client.py +++ b/deepinfra/constants/client.py @@ -2,7 +2,7 @@ This module contains the constants used by the DeepInfra API client. """ -MAX_RETRIES = 5 +MAX_RETRIES = 0 INITIAL_BACKOFF = 5000 SUBSEQUENT_BACKOFF = 2000 USER_AGENT = "DeepInfra Python API Client" diff --git a/deepinfra/constants/models.py b/deepinfra/constants/models.py new file mode 100644 index 0000000..f53d9d3 --- /dev/null +++ b/deepinfra/constants/models.py @@ -0,0 +1 @@ +SDXL = "stability-ai/sdxl" diff --git a/deepinfra/exceptions/__init__.py b/deepinfra/exceptions/__init__.py deleted file mode 100644 index 03df37f..0000000 --- a/deepinfra/exceptions/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .max_retries_exceeded import MaxRetriesExceededError diff --git a/deepinfra/exceptions/max_retries_exceeded.py b/deepinfra/exceptions/max_retries_exceeded.py deleted file mode 100644 index 113e399..0000000 --- a/deepinfra/exceptions/max_retries_exceeded.py +++ /dev/null @@ -1,12 +0,0 @@ -""" -This error is raised when the maximum number of retries is exceeded by DeepInfraClient. -""" - - -class MaxRetriesExceededError(Exception): - """ - This error is raised when the maximum number of retries is exceeded by DeepInfraClient. - """ - - def __init__(self, message="Maximum retries exceeded"): - super().__init__(message) diff --git a/deepinfra/models/base/sdxl.py b/deepinfra/models/base/sdxl.py new file mode 100644 index 0000000..f3ee9c2 --- /dev/null +++ b/deepinfra/models/base/sdxl.py @@ -0,0 +1,27 @@ +import json + +from black import Optional + +from deepinfra import BaseModel +from deepinfra.constants import SDXL +from deepinfra.types.sdxl.request import SdxlRequest +from deepinfra.types.sdxl.response import SdxlResponse + + +class Sdxl(BaseModel): + """ + Class for the SDXL model. + @docs Check the model at https://deepinfra.com/stability-ai/sdxl/api + """ + + def __init__(self, api_token: Optional[str] = None): + super().__init__(SDXL, api_token) + + def generate(self, body: dict) -> SdxlResponse: + """ + Generates an image. + :param input: + :return: + """ + response = self.client.post(json.dumps(body)) + return response.json() diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 42733ba..38c4f85 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -1,3 +1,5 @@ +import json + from deepinfra.models.base import BaseModel from deepinfra.types.text_to_image import TextToImageResponse @@ -8,12 +10,12 @@ class TextToImage(BaseModel): @docs Check the available models at https://deepinfra.com/models/text-to-image """ - def generate(self, input): + def generate(self, body: dict): """ Generates an image. :param input: :return: """ - body = {"input": input} - response = self.client.post(body) + + response = self.client.post(json.dumps(body)) return response.json() diff --git a/deepinfra/types/automatic_speech_recognition/response.py b/deepinfra/types/automatic_speech_recognition/response.py index 5db05e6..a4ffbe2 100644 --- a/deepinfra/types/automatic_speech_recognition/response.py +++ b/deepinfra/types/automatic_speech_recognition/response.py @@ -1,5 +1,4 @@ from typing import List - from dataclasses import dataclass from deepinfra.types.common.inference_status import InferenceStatus diff --git a/deepinfra/types/sdxl/__init__.py b/deepinfra/types/sdxl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinfra/types/sdxl/request.py b/deepinfra/types/sdxl/request.py new file mode 100644 index 0000000..5335aa9 --- /dev/null +++ b/deepinfra/types/sdxl/request.py @@ -0,0 +1,23 @@ +from typing import Optional + +from dataclasses import dataclass + + +@dataclass +class SdxlRequest: + prompt: str + negative_prompt: Optional[str] = None + image: Optional[str] = None + mask: Optional[str] = None + width: Optional[int] = None + height: Optional[int] = None + num_outputs: Optional[int] = None + scheduler: Optional[str] = None + num_inference_steps: Optional[int] = None + guidance_scale: Optional[float] = None + prompt_strength: Optional[float] = None + seed: Optional[int] = None + refine: Optional[str] = None + high_noise_frac: Optional[float] = None + refine_steps: Optional[int] = None + apply_watermark: Optional[bool] = True diff --git a/deepinfra/types/sdxl/response.py b/deepinfra/types/sdxl/response.py new file mode 100644 index 0000000..2cd033c --- /dev/null +++ b/deepinfra/types/sdxl/response.py @@ -0,0 +1,16 @@ +from typing import List + +from dataclasses import dataclass + + +@dataclass +class SdxlResponseItem: + format: str + type: str + + +@dataclass +class SdxlResponse: + items: List[SdxlResponseItem] + title: str + type: str diff --git a/requirements.txt b/requirements.txt index 135d793..0f29974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ setuptools setuptools-scm requests -dataclasses -requests-toolbelt \ No newline at end of file +requests-toolbelt +pydantic \ No newline at end of file diff --git a/tests/test_sdxl.py b/tests/test_sdxl.py new file mode 100644 index 0000000..c701780 --- /dev/null +++ b/tests/test_sdxl.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import patch + +from deepinfra import TextToImage +from deepinfra.constants import SDXL +from deepinfra.models.base.sdxl import Sdxl + +model_name = SDXL +api_key = "API KEY" + + +class TestSdxl(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"image": "image data"} + + text_to_image = Sdxl(api_key) + body = {"input": {"prompt": "Hello, World!"}} + response = text_to_image.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["image"], "image data") + self.assertEqual(header["Authorization"], f"Bearer {api_key}") From 833de66e27bdb3fd1fb84de244bbd2b27792d1ab Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 15:59:38 +0300 Subject: [PATCH 2/8] feat(base models): SDXL and Text to image --- README.md | 53 +++++++++++++++++++ deepinfra/models/base/__init__.py | 2 + .../base/automatic_speech_recognition.py | 2 +- deepinfra/models/base/embeddings.py | 3 +- deepinfra/models/base/sdxl.py | 6 +-- deepinfra/models/base/text_generation.py | 5 +- 6 files changed, 62 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index 6bf57c7..bd9bd15 100644 --- a/README.md +++ b/README.md @@ -55,4 +55,57 @@ transcription = asr.generate(body) print(transcription["text"]) ``` +#### Generate an image using SDXL +```python +from deepinfra import Sdxl +import base64 + +model = Sdxl() +body = { + "input":{ + "prompt": "A happy little cloud" + } +} +resp = model.generate(body) +image = resp["images"][0] +image = image.replace("data:image/png;base64,", "") + +with open("image.png", "wb") as fp: + fp.write(base64.b64decode(image)) + +``` + +#### Generate an image using a text-to-image model + +```python +from deepinfra import TextToImage +import base64 + +model_name = "CompVis/stable-diffusion-v1-4" +model = TextToImage(model_name) +body = { + "prompt": "A happy little cloud" +} +resp = model.generate(body) +image = resp["image"] +image = image.replace("data:image/png;base64,", "") + +with open("image.png", "wb") as fp: + fp.write(base64.b64decode(image)) +``` + +#### Generate text using LLM + +```python +from deepinfra import TextGeneration + +model_name = "meta-llama/Meta-Llama-3-8B-Instruct" +model = TextGeneration(model_name) +body = { + "input": "Write a story about a happy little cloud" +} +resp = model.generate(body) +result = resp["results"][0]["generated_text"] +print(result) +``` \ No newline at end of file diff --git a/deepinfra/models/base/__init__.py b/deepinfra/models/base/__init__.py index ba188d4..53f530e 100644 --- a/deepinfra/models/base/__init__.py +++ b/deepinfra/models/base/__init__.py @@ -2,3 +2,5 @@ from .text_to_image import TextToImage from .embeddings import Embeddings from .automatic_speech_recognition import AutomaticSpeechRecognition +from .sdxl import Sdxl +from .text_generation import TextGeneration \ No newline at end of file diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 95d596b..30ef369 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -14,7 +14,7 @@ class AutomaticSpeechRecognition(BaseModel): @docs Check the available models at https://deepinfra.com/models/automatic-speech-recognition """ - def generate(self, body) -> AutomaticSpeechRecognitionResponse: + def generate(self, body): """ Generates the automatic speech recognition response. @param body: The request body. diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index 9343153..302b119 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -1,5 +1,4 @@ from deepinfra.models.base import BaseModel -from deepinfra.types.embeddings.request import EmbeddingsRequest from deepinfra.types.embeddings.response import EmbeddingsResponse @@ -8,7 +7,7 @@ class Embeddings(BaseModel): @docs Check the available models at https://deepinfra.com/models/embeddings """ - def generate(self, body) -> EmbeddingsResponse: + def generate(self, body): """ Generates embeddings. :param body: diff --git a/deepinfra/models/base/sdxl.py b/deepinfra/models/base/sdxl.py index f3ee9c2..ce515d1 100644 --- a/deepinfra/models/base/sdxl.py +++ b/deepinfra/models/base/sdxl.py @@ -2,10 +2,8 @@ from black import Optional -from deepinfra import BaseModel +from deepinfra.models.base import BaseModel from deepinfra.constants import SDXL -from deepinfra.types.sdxl.request import SdxlRequest -from deepinfra.types.sdxl.response import SdxlResponse class Sdxl(BaseModel): @@ -17,7 +15,7 @@ class Sdxl(BaseModel): def __init__(self, api_token: Optional[str] = None): super().__init__(SDXL, api_token) - def generate(self, body: dict) -> SdxlResponse: + def generate(self, body: dict): """ Generates an image. :param input: diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index 55ab472..2627f5e 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -3,6 +3,7 @@ which is the base class for all text generation models. """ +import json from typing import Union from deepinfra.models.base import BaseModel @@ -16,11 +17,11 @@ class TextGeneration(BaseModel): @docs Check the available models at https://deepinfra.com/models/text-generation """ - def generate(self, body: dict) -> TextGenerationResponse: + def generate(self, body: dict): """ Generates text. :param body: :return: """ - response = self.client.post(body) + response = self.client.post(json.dumps(body)) return response.json() From 60c181f27a9fc59d3ce854b4f68d34e2aa4037d7 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 16:01:40 +0300 Subject: [PATCH 3/8] refactor(SDXL): better reusability --- deepinfra/types/sdxl/request.py | 19 +++---------------- 1 file changed, 3 insertions(+), 16 deletions(-) diff --git a/deepinfra/types/sdxl/request.py b/deepinfra/types/sdxl/request.py index 5335aa9..89a0e4c 100644 --- a/deepinfra/types/sdxl/request.py +++ b/deepinfra/types/sdxl/request.py @@ -2,22 +2,9 @@ from dataclasses import dataclass +from deepinfra.types.text_to_image import TextToImageRequest + @dataclass class SdxlRequest: - prompt: str - negative_prompt: Optional[str] = None - image: Optional[str] = None - mask: Optional[str] = None - width: Optional[int] = None - height: Optional[int] = None - num_outputs: Optional[int] = None - scheduler: Optional[str] = None - num_inference_steps: Optional[int] = None - guidance_scale: Optional[float] = None - prompt_strength: Optional[float] = None - seed: Optional[int] = None - refine: Optional[str] = None - high_noise_frac: Optional[float] = None - refine_steps: Optional[int] = None - apply_watermark: Optional[bool] = True + input: TextToImageRequest From 8322a6550111b519aa70e934f189b76eb3f676fa Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 16:02:25 +0300 Subject: [PATCH 4/8] refactor(sdxl): removed redundant indent --- deepinfra/types/sdxl/request.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/deepinfra/types/sdxl/request.py b/deepinfra/types/sdxl/request.py index 89a0e4c..5e30627 100644 --- a/deepinfra/types/sdxl/request.py +++ b/deepinfra/types/sdxl/request.py @@ -1,5 +1,3 @@ -from typing import Optional - from dataclasses import dataclass from deepinfra.types.text_to_image import TextToImageRequest From 198f49ae5bb65fe7b7264792fde77ae2e11b5c28 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 30 Apr 2024 16:07:19 +0300 Subject: [PATCH 5/8] doc(SDXL): fixed --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index bd9bd15..7e7f70e 100644 --- a/README.md +++ b/README.md @@ -59,7 +59,7 @@ print(transcription["text"]) ```python from deepinfra import Sdxl -import base64 +import requests model = Sdxl() body = { @@ -69,10 +69,11 @@ body = { } resp = model.generate(body) image = resp["images"][0] -image = image.replace("data:image/png;base64,", "") +image_url = resp["output"][0] -with open("image.png", "wb") as fp: - fp.write(base64.b64decode(image)) +# Write image_url to image.png +with open("image.png", "wb") as file: + file.write(requests.get(image_url).content) ``` From 60401bf4e48593b103e4130185a7d49a68c69d22 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 17:10:40 +0300 Subject: [PATCH 6/8] revert(backoff) --- deepinfra/clients/deepinfra.py | 27 +++++++++++++++----- deepinfra/constants/client.py | 2 +- deepinfra/exceptions/__init__.py | 0 deepinfra/exceptions/max_retries_exceeded.py | 12 +++++++++ 4 files changed, 33 insertions(+), 8 deletions(-) create mode 100644 deepinfra/exceptions/__init__.py create mode 100644 deepinfra/exceptions/max_retries_exceeded.py diff --git a/deepinfra/clients/deepinfra.py b/deepinfra/clients/deepinfra.py index ce5a7ee..0ca7667 100644 --- a/deepinfra/clients/deepinfra.py +++ b/deepinfra/clients/deepinfra.py @@ -1,7 +1,7 @@ -import json import time import requests +from deepinfra.exceptions import MaxRetriesExceededError from deepinfra.constants import ( MAX_RETRIES, USER_AGENT, @@ -19,6 +19,10 @@ def __init__(self, url, auth_token): self.url = url self.auth_token = auth_token + def backoff_delay(self, attempt): + delay = self.initial_backoff if attempt == 1 else self.subsequent_backoff + time.sleep(delay) + def post(self, data, config=None): """ Performs a POST request. @@ -36,9 +40,18 @@ def post(self, data, config=None): "User-Agent": USER_AGENT, "Authorization": f"Bearer {self.auth_token}", } - try: - response = requests.post(self.url, data=data, headers=headers) - response.raise_for_status() - return response - except requests.RequestException as error: - raise error + for attempt in range(self.max_retries + 1): + try: + response = requests.post(self.url, data=data, headers=headers) + response.raise_for_status() + return response + except requests.RequestException as error: + if attempt < self.max_retries: + print( + f"Request failed, retrying... Attempt {attempt + 1}/{self.max_retries}" + ) + self.backoff_delay(attempt + 1) + else: + raise error + + raise MaxRetriesExceededError() \ No newline at end of file diff --git a/deepinfra/constants/client.py b/deepinfra/constants/client.py index 6d9cc07..2d57565 100644 --- a/deepinfra/constants/client.py +++ b/deepinfra/constants/client.py @@ -2,7 +2,7 @@ This module contains the constants used by the DeepInfra API client. """ -MAX_RETRIES = 0 +MAX_RETRIES = 5 INITIAL_BACKOFF = 5000 SUBSEQUENT_BACKOFF = 2000 USER_AGENT = "DeepInfra Python API Client" diff --git a/deepinfra/exceptions/__init__.py b/deepinfra/exceptions/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinfra/exceptions/max_retries_exceeded.py b/deepinfra/exceptions/max_retries_exceeded.py new file mode 100644 index 0000000..ce51107 --- /dev/null +++ b/deepinfra/exceptions/max_retries_exceeded.py @@ -0,0 +1,12 @@ +""" +This error is raised when the maximum number of retries is exceeded by DeepInfraClient. +""" + + +class MaxRetriesExceededError(Exception): + """ + This error is raised when the maximum number of retries is exceeded by DeepInfraClient. + """ + + def __init__(self, message="Maximum retries exceeded"): + super().__init__(message) \ No newline at end of file From 831bdfbd536b911e5054079a0bbc14b6396881d6 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Sat, 4 May 2024 17:14:42 +0300 Subject: [PATCH 7/8] ref(lint) --- deepinfra/clients/deepinfra.py | 2 +- deepinfra/exceptions/__init__.py | 1 + deepinfra/exceptions/max_retries_exceeded.py | 2 +- 3 files changed, 3 insertions(+), 2 deletions(-) diff --git a/deepinfra/clients/deepinfra.py b/deepinfra/clients/deepinfra.py index 0ca7667..d1f6ca1 100644 --- a/deepinfra/clients/deepinfra.py +++ b/deepinfra/clients/deepinfra.py @@ -54,4 +54,4 @@ def post(self, data, config=None): else: raise error - raise MaxRetriesExceededError() \ No newline at end of file + raise MaxRetriesExceededError() diff --git a/deepinfra/exceptions/__init__.py b/deepinfra/exceptions/__init__.py index e69de29..9a4b0e0 100644 --- a/deepinfra/exceptions/__init__.py +++ b/deepinfra/exceptions/__init__.py @@ -0,0 +1 @@ +from .max_retries_exceeded import MaxRetriesExceededError \ No newline at end of file diff --git a/deepinfra/exceptions/max_retries_exceeded.py b/deepinfra/exceptions/max_retries_exceeded.py index ce51107..113e399 100644 --- a/deepinfra/exceptions/max_retries_exceeded.py +++ b/deepinfra/exceptions/max_retries_exceeded.py @@ -9,4 +9,4 @@ class MaxRetriesExceededError(Exception): """ def __init__(self, message="Maximum retries exceeded"): - super().__init__(message) \ No newline at end of file + super().__init__(message) From dd1bb95b7ae79a0af3222e5bcbf13507f756b5a4 Mon Sep 17 00:00:00 2001 From: Oguz Vuruskaner Date: Tue, 7 May 2024 15:42:37 +0300 Subject: [PATCH 8/8] fix(annottation) --- deepinfra/models/base/text_generation.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index 68ad86b..670ba9b 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -4,20 +4,18 @@ """ import json -from typing import Union from deepinfra.models.base import BaseModel -from deepinfra.types.text_generation.request import TextGenerationRequest from deepinfra.types.text_generation.response import TextGenerationResponse -class TextGeneration(BaseModel) -> TextGenerationResponse: +class TextGeneration(BaseModel): """ Initializes one of the DeepInfra text generation models. @docs Check the available models at https://deepinfra.com/models/text-generation """ - def generate(self, body: dict): + def generate(self, body: dict) -> TextGenerationResponse: """ Generates text. :param body: