Skip to content

Commit

Permalink
feat(base models): SDXL and Text to image
Browse files Browse the repository at this point in the history
  • Loading branch information
ovuruska committed Apr 30, 2024
1 parent 9b83c9f commit 833de66
Show file tree
Hide file tree
Showing 6 changed files with 62 additions and 9 deletions.
53 changes: 53 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,57 @@ transcription = asr.generate(body)
print(transcription["text"])
```

#### Generate an image using SDXL

```python
from deepinfra import Sdxl
import base64

model = Sdxl()
body = {
"input":{
"prompt": "A happy little cloud"
}
}
resp = model.generate(body)
image = resp["images"][0]
image = image.replace("data:image/png;base64,", "")

with open("image.png", "wb") as fp:
fp.write(base64.b64decode(image))

```

#### Generate an image using a text-to-image model

```python
from deepinfra import TextToImage
import base64

model_name = "CompVis/stable-diffusion-v1-4"
model = TextToImage(model_name)
body = {
"prompt": "A happy little cloud"
}
resp = model.generate(body)
image = resp["image"]
image = image.replace("data:image/png;base64,", "")

with open("image.png", "wb") as fp:
fp.write(base64.b64decode(image))
```

#### Generate text using LLM

```python
from deepinfra import TextGeneration

model_name = "meta-llama/Meta-Llama-3-8B-Instruct"
model = TextGeneration(model_name)
body = {
"input": "Write a story about a happy little cloud"
}
resp = model.generate(body)
result = resp["results"][0]["generated_text"]
print(result)
```
2 changes: 2 additions & 0 deletions deepinfra/models/base/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,5 @@
from .text_to_image import TextToImage
from .embeddings import Embeddings
from .automatic_speech_recognition import AutomaticSpeechRecognition
from .sdxl import Sdxl
from .text_generation import TextGeneration
2 changes: 1 addition & 1 deletion deepinfra/models/base/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class AutomaticSpeechRecognition(BaseModel):
@docs Check the available models at https://deepinfra.com/models/automatic-speech-recognition
"""

def generate(self, body) -> AutomaticSpeechRecognitionResponse:
def generate(self, body):
"""
Generates the automatic speech recognition response.
@param body: The request body.
Expand Down
3 changes: 1 addition & 2 deletions deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from deepinfra.models.base import BaseModel
from deepinfra.types.embeddings.request import EmbeddingsRequest
from deepinfra.types.embeddings.response import EmbeddingsResponse


Expand All @@ -8,7 +7,7 @@ class Embeddings(BaseModel):
@docs Check the available models at https://deepinfra.com/models/embeddings
"""

def generate(self, body) -> EmbeddingsResponse:
def generate(self, body):
"""
Generates embeddings.
:param body:
Expand Down
6 changes: 2 additions & 4 deletions deepinfra/models/base/sdxl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,8 @@

from black import Optional

from deepinfra import BaseModel
from deepinfra.models.base import BaseModel
from deepinfra.constants import SDXL
from deepinfra.types.sdxl.request import SdxlRequest
from deepinfra.types.sdxl.response import SdxlResponse


class Sdxl(BaseModel):
Expand All @@ -17,7 +15,7 @@ class Sdxl(BaseModel):
def __init__(self, api_token: Optional[str] = None):
super().__init__(SDXL, api_token)

def generate(self, body: dict) -> SdxlResponse:
def generate(self, body: dict):
"""
Generates an image.
:param input:
Expand Down
5 changes: 3 additions & 2 deletions deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
which is the base class for all text generation models.
"""

import json
from typing import Union

from deepinfra.models.base import BaseModel
Expand All @@ -16,11 +17,11 @@ class TextGeneration(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-generation
"""

def generate(self, body: dict) -> TextGenerationResponse:
def generate(self, body: dict):
"""
Generates text.
:param body:
:return:
"""
response = self.client.post(body)
response = self.client.post(json.dumps(body))
return response.json()

0 comments on commit 833de66

Please sign in to comment.