Skip to content

Commit

Permalink
Merge pull request #12 from ovuruska/10-type-annotations
Browse files Browse the repository at this point in the history
Type annotations
  • Loading branch information
ovuruska authored Apr 29, 2024
2 parents 513364b + b03bb15 commit cc121f6
Show file tree
Hide file tree
Showing 20 changed files with 51 additions and 64 deletions.
4 changes: 4 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ jobs:
pip install -r requirements.txt
pip install -r requirements-dev.txt
- name: Run type check
run: |
mypy --install-types deepinfra --non-interactive --verbose
- name: Run lint check
run: |
black --check --verbose deepinfra
12 changes: 2 additions & 10 deletions deepinfra/models/base/automatic_speech_recognition.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
The automatic speech recognition model.
"""

from deepinfra.models.base import BaseModel
from deepinfra.types.automatic_speech_recognition.response import (
AutomaticSpeechRecognitionResponse,
Expand All @@ -9,19 +10,10 @@


class AutomaticSpeechRecognition(BaseModel):

"""
@docs Check the available models at https://deepinfra.com/models/automatic-speech-recognition
"""

def __init__(self, endpoint: str, auth_token: str = None):
"""
Initializes the automatic speech recognition model.
@param endpoint: The endpoint of the model or the model name.
@param auth_token: The API key to authenticate the requests.
"""
super().__init__(endpoint, auth_token)

def generate(self, body) -> AutomaticSpeechRecognitionResponse:
"""
Generates the automatic speech recognition response.
Expand All @@ -34,4 +26,4 @@ def generate(self, body) -> AutomaticSpeechRecognitionResponse:
response = self.client.post(
form_data, {"headers": {"content-type": form_data.content_type}}
)
return response.json()
return AutomaticSpeechRecognitionResponse(**response.json())
4 changes: 3 additions & 1 deletion deepinfra/models/base/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
Base class for all models.
"""

import os
from typing import Optional

from deepinfra.clients import DeepInfraClient
from deepinfra.constants.client import ROOT_URL
Expand All @@ -15,7 +17,7 @@ class BaseModel:
@param auth_token: The API key to authenticate the requests.
"""

def __init__(self, endpoint, auth_token: str = None):
def __init__(self, endpoint, auth_token: Optional[str] = None):
if URLUtils.is_valid_url(endpoint):
self.endpoint = endpoint
else:
Expand Down
12 changes: 2 additions & 10 deletions deepinfra/models/base/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,19 +8,11 @@ class Embeddings(BaseModel):
@docs Check the available models at https://deepinfra.com/models/embeddings
"""

def __init__(self, endpoint: str, auth_token: str):
"""
Initializes the embeddings model.
@param endpoint: The endpoint of the model or the model name.
@param auth_token: The API key to authenticate the requests.
"""
super().__init__(endpoint, auth_token)

def generate(self, body: EmbeddingsRequest) -> EmbeddingsResponse:
def generate(self, body) -> EmbeddingsResponse:
"""
Generates embeddings.
:param body:
:return:
"""
response = self.client.post(body)
return response
return EmbeddingsResponse(**response.json())
15 changes: 5 additions & 10 deletions deepinfra/models/base/text_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
This module contains the TextGeneration class,
which is the base class for all text generation models.
"""

from typing import Union

from deepinfra.models.base import BaseModel
from deepinfra.types.text_generation.request import TextGenerationRequest
from deepinfra.types.text_generation.response import TextGenerationResponse
Expand All @@ -13,19 +16,11 @@ class TextGeneration(BaseModel):
@docs Check the available models at https://deepinfra.com/models/text-generation
"""

def __init__(self, endpoint: str, auth_token: str):
"""
Initializes the text generation model.
:param endpoint:
:param auth_token:
"""
super().__init__(endpoint, auth_token)

def generate(self, body: TextGenerationRequest) -> TextGenerationResponse:
def generate(self, body: dict) -> TextGenerationResponse:
"""
Generates text.
:param body:
:return:
"""
response = self.client.post(body)
return response
return TextGenerationResponse(**response.json())
12 changes: 2 additions & 10 deletions deepinfra/models/base/text_to_image.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,13 @@
from deepinfra.models.base import BaseModel
from deepinfra.types.text_to_image import TextToImageResponse


class TextToImage(BaseModel):

"""
Initializes one of the DeepInfra image generation models.
@docs Check the available models at https://deepinfra.com/models/text-to-image
"""

def __init__(self, endpoint: str, auth_token: str):
"""
Initializes the image generation model.
@param endpoint: The endpoint of the model or the model name.
@param auth_token: The API key to authenticate the requests. If not provided, it will be fetched from the environment.
"""
super().__init__(endpoint, auth_token)

def generate(self, input):
"""
Generates an image.
Expand All @@ -24,4 +16,4 @@ def generate(self, input):
"""
body = {"input": input}
response = self.client.post(body)
return response
return TextToImageResponse(**response.json())
2 changes: 1 addition & 1 deletion deepinfra/types/automatic_speech_recognition/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional


@dataclass(kw_only=True)
@dataclass
class AutomaticSpeechRecognitionRequest:
audio: str
task: Optional[str]
Expand Down
6 changes: 3 additions & 3 deletions deepinfra/types/automatic_speech_recognition/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,15 @@
from deepinfra.types.common.inference_status import InferenceStatus


@dataclass(kw_only=True)
@dataclass
class AutomaticSpeechRecognitionWord:
text: str
start: int
end: int
confidence: float


@dataclass(kw_only=True)
@dataclass
class AutomaticSpeechRecognitionSegment:
id: int
seek: int
Expand All @@ -29,7 +29,7 @@ class AutomaticSpeechRecognitionSegment:
words: List[AutomaticSpeechRecognitionWord]


@dataclass(kw_only=True)
@dataclass
class AutomaticSpeechRecognitionResponse:
text: str
segments: List[AutomaticSpeechRecognitionSegment]
Expand Down
4 changes: 2 additions & 2 deletions deepinfra/types/common/inference_status.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from dataclasses import dataclass


@dataclass(kw_only=True)
@dataclass
class InferenceStatus:
status: str
runtime_ms: int
Expand All @@ -10,7 +10,7 @@ class InferenceStatus:
tokens_input: int


@dataclass(kw_only=True)
@dataclass
class Status:
UNKNOWN = "unknown"
QUEUED = "queued"
Expand Down
2 changes: 1 addition & 1 deletion deepinfra/types/embeddings/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, List


@dataclass(kw_only=True)
@dataclass
class EmbeddingsRequest:
inputs: List[str]
normalize: Optional[bool] = None
Expand Down
4 changes: 2 additions & 2 deletions deepinfra/types/embeddings/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, List


@dataclass(kw_only=True)
@dataclass
class EmbeddingStatus:
status: str
runtime_ms: int
Expand All @@ -11,7 +11,7 @@ class EmbeddingStatus:
tokens_input: int


@dataclass(kw_only=True)
@dataclass
class EmbeddingsResponse:
embeddings: List[List[float]]
input_tokens: int
Expand Down
2 changes: 0 additions & 2 deletions deepinfra/types/image_generation/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion deepinfra/types/text_generation/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Optional, List, Dict


@dataclass(kw_only=True)
@dataclass
class TextGenerationRequest:
input: str
stream: Optional[bool] = None
Expand Down
4 changes: 2 additions & 2 deletions deepinfra/types/text_generation/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
from deepinfra.types.common.inference_status import InferenceStatus


@dataclass(kw_only=True)
@dataclass
class GeneratedText:
generated_text: str


@dataclass(kw_only=True)
@dataclass
class TextGenerationResponse:
request_id: str
inference_status: InferenceStatus
Expand Down
2 changes: 2 additions & 0 deletions deepinfra/types/text_to_image/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .request import TextToImageRequest
from .response import TextToImageResponse
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@
from typing import Optional


@dataclass(kw_only=True)
class ImageGenerationRequest:
@dataclass
class TextToImageRequest:
prompt: str
negative_prompt: Optional[str] = None
image: Optional[str] = None
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
from deepinfra.types.common.inference_status import InferenceStatus


@dataclass(kw_only=True)
@dataclass
class Metrics:
predict_time: int


@dataclass(kw_only=True)
class ImageGenerationResponse:
@dataclass
class TextToImageResponse:
request_id: str
inference_status: InferenceStatus
input: Dict
Expand Down
8 changes: 5 additions & 3 deletions deepinfra/utils/form_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import json
from typing import List, Optional

from requests_toolbelt import MultipartEncoder

Expand All @@ -12,19 +12,21 @@ class FormDataUtils:
"""

@staticmethod
def get_form_data(data, blob_keys=()):
def get_form_data(data: dict, blob_keys: Optional[List[str]] = None):
"""
Creates a MultipartEncoder object from the data.
:param data:
:param blob_keys:
:return:
"""
if blob_keys is None:
blob_keys = list()
body = {}

for key, value in data.items():
if key in blob_keys:
body[key] = (key, ReadStreamUtils.get_read_stream(value))
else:
body[key] = json.dumps(value)
body[key] = value

return MultipartEncoder(fields=body)
6 changes: 6 additions & 0 deletions mypy.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
[mypy]
check_untyped_defs = False

# requests_toolbelt is not typed
[mypy-requests_toolbelt.*]
ignore_missing_imports = True
4 changes: 3 additions & 1 deletion requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1 +1,3 @@
black==23.3.0
black==23.3.0
mypy
types-requests

0 comments on commit cc121f6

Please sign in to comment.