diff --git a/README.md b/README.md index 6bf57c7..7e7f70e 100644 --- a/README.md +++ b/README.md @@ -55,4 +55,58 @@ transcription = asr.generate(body) print(transcription["text"]) ``` +#### Generate an image using SDXL +```python +from deepinfra import Sdxl +import requests + +model = Sdxl() +body = { + "input":{ + "prompt": "A happy little cloud" + } +} +resp = model.generate(body) +image = resp["images"][0] +image_url = resp["output"][0] + +# Write image_url to image.png +with open("image.png", "wb") as file: + file.write(requests.get(image_url).content) + +``` + +#### Generate an image using a text-to-image model + +```python +from deepinfra import TextToImage +import base64 + +model_name = "CompVis/stable-diffusion-v1-4" +model = TextToImage(model_name) +body = { + "prompt": "A happy little cloud" +} +resp = model.generate(body) +image = resp["image"] +image = image.replace("data:image/png;base64,", "") + +with open("image.png", "wb") as fp: + fp.write(base64.b64decode(image)) +``` + +#### Generate text using LLM + +```python +from deepinfra import TextGeneration + +model_name = "meta-llama/Meta-Llama-3-8B-Instruct" +model = TextGeneration(model_name) +body = { + "input": "Write a story about a happy little cloud" +} +resp = model.generate(body) +result = resp["results"][0]["generated_text"] +print(result) +``` \ No newline at end of file diff --git a/deepinfra/constants/__init__.py b/deepinfra/constants/__init__.py index a27a84b..e657dab 100644 --- a/deepinfra/constants/__init__.py +++ b/deepinfra/constants/__init__.py @@ -1 +1,2 @@ from .client import MAX_RETRIES, USER_AGENT, INITIAL_BACKOFF, SUBSEQUENT_BACKOFF +from .models import SDXL \ No newline at end of file diff --git a/deepinfra/constants/models.py b/deepinfra/constants/models.py new file mode 100644 index 0000000..f53d9d3 --- /dev/null +++ b/deepinfra/constants/models.py @@ -0,0 +1 @@ +SDXL = "stability-ai/sdxl" diff --git a/deepinfra/exceptions/__init__.py b/deepinfra/exceptions/__init__.py index 03df37f..9a4b0e0 100644 --- a/deepinfra/exceptions/__init__.py +++ b/deepinfra/exceptions/__init__.py @@ -1 +1 @@ -from .max_retries_exceeded import MaxRetriesExceededError +from .max_retries_exceeded import MaxRetriesExceededError \ No newline at end of file diff --git a/deepinfra/models/base/__init__.py b/deepinfra/models/base/__init__.py index ba188d4..53f530e 100644 --- a/deepinfra/models/base/__init__.py +++ b/deepinfra/models/base/__init__.py @@ -2,3 +2,5 @@ from .text_to_image import TextToImage from .embeddings import Embeddings from .automatic_speech_recognition import AutomaticSpeechRecognition +from .sdxl import Sdxl +from .text_generation import TextGeneration \ No newline at end of file diff --git a/deepinfra/models/base/automatic_speech_recognition.py b/deepinfra/models/base/automatic_speech_recognition.py index 98803b7..bd8be5c 100644 --- a/deepinfra/models/base/automatic_speech_recognition.py +++ b/deepinfra/models/base/automatic_speech_recognition.py @@ -14,7 +14,7 @@ class AutomaticSpeechRecognition(BaseModel): @docs Check the available models at https://deepinfra.com/models/automatic-speech-recognition """ - def generate(self, body) -> AutomaticSpeechRecognitionResponse: + def generate(self, body): """ Generates the automatic speech recognition response. @param body: The request body. diff --git a/deepinfra/models/base/embeddings.py b/deepinfra/models/base/embeddings.py index 31319e9..99f2e3a 100644 --- a/deepinfra/models/base/embeddings.py +++ b/deepinfra/models/base/embeddings.py @@ -9,7 +9,7 @@ class Embeddings(BaseModel): @docs Check the available models at https://deepinfra.com/models/embeddings """ - def generate(self, body) -> EmbeddingsResponse: + def generate(self, body): """ Generates embeddings. :param body: diff --git a/deepinfra/models/base/sdxl.py b/deepinfra/models/base/sdxl.py new file mode 100644 index 0000000..ce515d1 --- /dev/null +++ b/deepinfra/models/base/sdxl.py @@ -0,0 +1,25 @@ +import json + +from black import Optional + +from deepinfra.models.base import BaseModel +from deepinfra.constants import SDXL + + +class Sdxl(BaseModel): + """ + Class for the SDXL model. + @docs Check the model at https://deepinfra.com/stability-ai/sdxl/api + """ + + def __init__(self, api_token: Optional[str] = None): + super().__init__(SDXL, api_token) + + def generate(self, body: dict): + """ + Generates an image. + :param input: + :return: + """ + response = self.client.post(json.dumps(body)) + return response.json() diff --git a/deepinfra/models/base/text_generation.py b/deepinfra/models/base/text_generation.py index e728efc..670ba9b 100644 --- a/deepinfra/models/base/text_generation.py +++ b/deepinfra/models/base/text_generation.py @@ -4,10 +4,8 @@ """ import json -from typing import Union from deepinfra.models.base import BaseModel -from deepinfra.types.text_generation.request import TextGenerationRequest from deepinfra.types.text_generation.response import TextGenerationResponse diff --git a/deepinfra/models/base/text_to_image.py b/deepinfra/models/base/text_to_image.py index 5944564..94ab3ac 100644 --- a/deepinfra/models/base/text_to_image.py +++ b/deepinfra/models/base/text_to_image.py @@ -16,6 +16,7 @@ def generate(self, input) -> TextToImageResponse: :param input: :return: """ + body = {"input": input} response = self.client.post(json.dumps(input)) return TextToImageResponse(**response.json()) diff --git a/deepinfra/types/automatic_speech_recognition/response.py b/deepinfra/types/automatic_speech_recognition/response.py index 5db05e6..a4ffbe2 100644 --- a/deepinfra/types/automatic_speech_recognition/response.py +++ b/deepinfra/types/automatic_speech_recognition/response.py @@ -1,5 +1,4 @@ from typing import List - from dataclasses import dataclass from deepinfra.types.common.inference_status import InferenceStatus diff --git a/deepinfra/types/sdxl/__init__.py b/deepinfra/types/sdxl/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/deepinfra/types/sdxl/request.py b/deepinfra/types/sdxl/request.py new file mode 100644 index 0000000..5e30627 --- /dev/null +++ b/deepinfra/types/sdxl/request.py @@ -0,0 +1,8 @@ +from dataclasses import dataclass + +from deepinfra.types.text_to_image import TextToImageRequest + + +@dataclass +class SdxlRequest: + input: TextToImageRequest diff --git a/deepinfra/types/sdxl/response.py b/deepinfra/types/sdxl/response.py new file mode 100644 index 0000000..2cd033c --- /dev/null +++ b/deepinfra/types/sdxl/response.py @@ -0,0 +1,16 @@ +from typing import List + +from dataclasses import dataclass + + +@dataclass +class SdxlResponseItem: + format: str + type: str + + +@dataclass +class SdxlResponse: + items: List[SdxlResponseItem] + title: str + type: str diff --git a/requirements.txt b/requirements.txt index 135d793..0f29974 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ setuptools setuptools-scm requests -dataclasses -requests-toolbelt \ No newline at end of file +requests-toolbelt +pydantic \ No newline at end of file diff --git a/tests/test_sdxl.py b/tests/test_sdxl.py new file mode 100644 index 0000000..c701780 --- /dev/null +++ b/tests/test_sdxl.py @@ -0,0 +1,28 @@ +import unittest +from unittest.mock import patch + +from deepinfra import TextToImage +from deepinfra.constants import SDXL +from deepinfra.models.base.sdxl import Sdxl + +model_name = SDXL +api_key = "API KEY" + + +class TestSdxl(unittest.TestCase): + @patch("requests.post") + def test_generate(self, mock_post): + mock_post.return_value.status_code = 200 + mock_post.return_value.json.return_value = {"image": "image data"} + + text_to_image = Sdxl(api_key) + body = {"input": {"prompt": "Hello, World!"}} + response = text_to_image.generate(body) + + called_args, called_kwargs = mock_post.call_args + url = called_args[0] + header = called_kwargs["headers"] + self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}") + + self.assertEqual(response["image"], "image data") + self.assertEqual(header["Authorization"], f"Bearer {api_key}")