Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

5 high priosdxl #15

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
54 changes: 54 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,58 @@ transcription = asr.generate(body)
print(transcription["text"])
```

#### Generate an image using SDXL

```python
from deepinfra import Sdxl
import requests

model = Sdxl()
body = {
"input":{
"prompt": "A happy little cloud"
}
}
resp = model.generate(body)
image = resp["images"][0]
image_url = resp["output"][0]

# Write image_url to image.png
with open("image.png", "wb") as file:
file.write(requests.get(image_url).content)

```

#### Generate an image using a text-to-image model

```python
from deepinfra import TextToImage
import base64

model_name = "CompVis/stable-diffusion-v1-4"
model = TextToImage(model_name)
body = {
"prompt": "A happy little cloud"
}
resp = model.generate(body)
image = resp["image"]
image = image.replace("data:image/png;base64,", "")

with open("image.png", "wb") as fp:
fp.write(base64.b64decode(image))
```

#### Generate text using LLM

```python
from deepinfra import TextGeneration

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = TextGeneration(model_name)
body = {
"input": "Write a story about a happy little cloud"
}
resp = model.generate(body)
result = resp["results"][0]["generated_text"]
print(result)
```
1 change: 1 addition & 0 deletions deepinfra/constants/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1,2 @@
from .client import MAX_RETRIES, USER_AGENT, INITIAL_BACKOFF, SUBSEQUENT_BACKOFF
from .models import SDXL
1 change: 1 addition & 0 deletions deepinfra/constants/models.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
SDXL = "stability-ai/sdxl"
2 changes: 1 addition & 1 deletion deepinfra/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .max_retries_exceeded import MaxRetriesExceededError
from .max_retries_exceeded import MaxRetriesExceededError
2 changes: 2 additions & 0 deletions deepinfra/models/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
from .text_to_image import TextToImage
from .embeddings import Embeddings
from .automatic_speech_recognition import AutomaticSpeechRecognition
from .sdxl import Sdxl
from .text_generation import TextGeneration
2 changes: 1 addition & 1 deletion deepinfra/models/base/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AutomaticSpeechRecognition(BaseModel):
@docs Check the available models at https://deepinfra.com/models/automatic-speech-recognition
"""

def generate(self, body) -> AutomaticSpeechRecognitionResponse:
def generate(self, body):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess this steps on top of the other PR, so the same arguments apply. The generate method absolutely needs annotations on the input (body) and response.

Now I see the code is blocking (i.e not async). We should provide sync and async variants in that case.

Copy link
Owner Author

@ovuruska ovuruska May 4, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

for annotation, I got an idea which may be scope of another PR.

Example usage
`
request = EmbeddingsRequest(inputs=["Hello World!"])
model = Embeddings()
model.generate(request)

`

By creating a request class for each base class, we can
1-) Validate the input before request is sent.
2-) Increase iteration speed for our users by improving code completion aspects.

"""
Generates the automatic speech recognition response.
@param body: The request body.
Expand Down
2 changes: 1 addition & 1 deletion deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ class Embeddings(BaseModel):
@docs Check the available models at https://deepinfra.com/models/embeddings
"""

def generate(self, body) -> EmbeddingsResponse:
def generate(self, body):
"""
Generates embeddings.
:param body:
Expand Down
25 changes: 25 additions & 0 deletions deepinfra/models/base/sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import json

from black import Optional

from deepinfra.models.base import BaseModel
from deepinfra.constants import SDXL


class Sdxl(BaseModel):
"""
Class for the SDXL model.
@docs Check the model at https://deepinfra.com/stability-ai/sdxl/api
"""

def __init__(self, api_token: Optional[str] = None):
super().__init__(SDXL, api_token)

def generate(self, body: dict):
"""
Generates an image.
:param input:
:return:
"""
response = self.client.post(json.dumps(body))
return response.json()
2 changes: 0 additions & 2 deletions deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
"""

import json
from typing import Union

from deepinfra.models.base import BaseModel
from deepinfra.types.text_generation.request import TextGenerationRequest
from deepinfra.types.text_generation.response import TextGenerationResponse


Expand Down
1 change: 1 addition & 0 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ def generate(self, input) -> TextToImageResponse:
:param input:
:return:
"""

body = {"input": input}
response = self.client.post(json.dumps(input))
return TextToImageResponse(**response.json())
1 change: 0 additions & 1 deletion deepinfra/types/automatic_speech_recognition/response.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import List

from dataclasses import dataclass

from deepinfra.types.common.inference_status import InferenceStatus
Expand Down
Empty file.
8 changes: 8 additions & 0 deletions deepinfra/types/sdxl/request.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from dataclasses import dataclass

from deepinfra.types.text_to_image import TextToImageRequest


@dataclass
class SdxlRequest:
input: TextToImageRequest
16 changes: 16 additions & 0 deletions deepinfra/types/sdxl/response.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
from typing import List

from dataclasses import dataclass


@dataclass
class SdxlResponseItem:
format: str
type: str


@dataclass
class SdxlResponse:
items: List[SdxlResponseItem]
title: str
type: str
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
setuptools
setuptools-scm
requests
dataclasses
requests-toolbelt
requests-toolbelt
pydantic
28 changes: 28 additions & 0 deletions tests/test_sdxl.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest
from unittest.mock import patch

from deepinfra import TextToImage
from deepinfra.constants import SDXL
from deepinfra.models.base.sdxl import Sdxl

model_name = SDXL
api_key = "API KEY"


class TestSdxl(unittest.TestCase):
@patch("requests.post")
def test_generate(self, mock_post):
mock_post.return_value.status_code = 200
mock_post.return_value.json.return_value = {"image": "image data"}

text_to_image = Sdxl(api_key)
body = {"input": {"prompt": "Hello, World!"}}
response = text_to_image.generate(body)

called_args, called_kwargs = mock_post.call_args
url = called_args[0]
header = called_kwargs["headers"]
self.assertEqual(url, f"https://api.deepinfra.com/v1/inference/{model_name}")

self.assertEqual(response["image"], "image data")
self.assertEqual(header["Authorization"], f"Bearer {api_key}")
Loading