Skip to content

Commit

Permalink
FEAT: Support audio model (#929)
Browse files Browse the repository at this point in the history
  • Loading branch information
codingl2k1 authored Jan 25, 2024
1 parent fbe3f8a commit 8069552
Show file tree
Hide file tree
Showing 19 changed files with 830 additions and 12 deletions.
5 changes: 4 additions & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,15 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper.py
else
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py xinference
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/audio/tests/test_whisper.py xinference
fi
working-directory: .
110 changes: 108 additions & 2 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,22 @@ def serve(self, logging_conf: Optional[dict] = None):
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/audio/transcriptions",
self.create_transcriptions,
methods=["POST"],
dependencies=[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/audio/translations",
self.create_translations,
methods=["POST"],
dependencies=[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/images/generations",
self.create_images,
Expand Down Expand Up @@ -879,6 +895,94 @@ async def rerank(self, request: RerankRequest) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_transcriptions(
    self,
    model: str = Form(...),
    file: UploadFile = File(media_type="application/octet-stream"),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Transcribe an uploaded audio file with the model identified by `model`.

    Resolves the model actor through the supervisor, forwards the raw audio
    bytes plus decoding options, and returns the model's payload as JSON.
    Lookup failures map to 400 (bad uid) or 500; inference-time RuntimeError
    maps to 400, anything else to 500. All failures are reported as events.
    """
    model_uid = model
    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_ref = await supervisor_ref.get_model(model_uid)
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    try:
        # Extra model-specific options arrive as a JSON-encoded form field.
        parsed_kwargs = json.loads(kwargs) if kwargs is not None else {}
        audio_bytes = await file.read()
        transcription = await model_ref.transcriptions(
            audio=audio_bytes,
            language=language,
            prompt=prompt,
            response_format=response_format,
            temperature=temperature,
            **parsed_kwargs,
        )
        return Response(content=transcription, media_type="application/json")
    except RuntimeError as rte:
        logger.error(rte, exc_info=True)
        await self._report_error_event(model_uid, str(rte))
        raise HTTPException(status_code=400, detail=str(rte))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_translations(
    self,
    model: str = Form(...),
    file: UploadFile = File(media_type="application/octet-stream"),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Translate an uploaded audio file into English via the model `model`.

    Resolves the model actor through the supervisor, forwards the raw audio
    bytes plus decoding options, and returns the model's payload as JSON.
    Lookup failures map to 400 (bad uid) or 500; inference-time RuntimeError
    maps to 400, anything else to 500. All failures are reported as events.
    """
    model_uid = model
    try:
        supervisor_ref = await self._get_supervisor_ref()
        model_ref = await supervisor_ref.get_model(model_uid)
    except ValueError as ve:
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    try:
        # Extra model-specific options arrive as a JSON-encoded form field.
        parsed_kwargs = json.loads(kwargs) if kwargs is not None else {}
        audio_bytes = await file.read()
        translation = await model_ref.translations(
            audio=audio_bytes,
            prompt=prompt,
            response_format=response_format,
            temperature=temperature,
            **parsed_kwargs,
        )
        return Response(content=translation, media_type="application/json")
    except RuntimeError as rte:
        logger.error(rte, exc_info=True)
        await self._report_error_event(model_uid, str(rte))
        raise HTTPException(status_code=400, detail=str(rte))
    except Exception as e:
        logger.error(e, exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_images(self, request: TextToImageRequest) -> Response:
model_uid = request.model
try:
Expand Down Expand Up @@ -937,15 +1041,17 @@ async def create_variations(

try:
if kwargs is not None:
kwargs = json.loads(kwargs)
parsed_kwargs = json.loads(kwargs)
else:
parsed_kwargs = {}
image_list = await model_ref.image_to_image(
image=Image.open(image.file),
prompt=prompt,
negative_prompt=negative_prompt,
n=n,
size=size,
response_format=response_format,
**kwargs,
**parsed_kwargs,
)
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
Expand Down
113 changes: 113 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,115 @@ def generate(
return response_data


class RESTfulAudioModelHandle(RESTfulModelHandle):
    """Client-side handle for audio models served over the RESTful API."""

    def transcriptions(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
    ):
        """
        Transcribes audio into the input language.
        Parameters
        ----------
        audio: bytes
            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        language: Optional[str]
            The language of the input audio. Supplying the input language in ISO-639-1
            (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format will improve accuracy and latency.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.
        Returns
        -------
        The transcribed text.
        """
        form_fields = {
            "model": self._model_uid,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
        }
        # Every option rides along as a multipart form part next to the audio.
        multipart: List[Any] = [
            (name, (None, value)) for name, value in form_fields.items()
        ]
        multipart.append(("file", ("file", audio, "application/octet-stream")))
        response = requests.post(
            f"{self._base_url}/v1/audio/transcriptions",
            files=multipart,
            headers=self.auth_headers,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to transcribe the audio, detail: {_get_error_string(response)}"
            )
        return response.json()

    def translations(
        self,
        audio: bytes,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
    ):
        """
        Translates audio into English.
        Parameters
        ----------
        audio: bytes
            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.
        Returns
        -------
        The translated text.
        """
        form_fields = {
            "model": self._model_uid,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
        }
        # Every option rides along as a multipart form part next to the audio.
        multipart: List[Any] = [
            (name, (None, value)) for name, value in form_fields.items()
        ]
        multipart.append(("file", ("file", audio, "application/octet-stream")))
        response = requests.post(
            f"{self._base_url}/v1/audio/translations",
            files=multipart,
            headers=self.auth_headers,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to translate the audio, detail: {_get_error_string(response)}"
            )
        return response.json()


class Client:
def __init__(self, base_url):
self.base_url = base_url
Expand Down Expand Up @@ -803,6 +912,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulRerankModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "audio":
return RESTfulAudioModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
else:
raise ValueError(f"Unknown model type:{desc['model_type']}")

Expand Down
4 changes: 2 additions & 2 deletions xinference/client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ def test_client_for_embedding(setup):
assert len(client.list_models()) == 0

model_uid = client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
model_name="bge-small-en-v1.5", model_type="embedding"
)
assert len(client.list_models()) == 1

model = client.get_model(model_uid=model_uid)
assert isinstance(model, EmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 512
assert len(completion["data"][0]["embedding"]) == 384

client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0
Expand Down
8 changes: 3 additions & 5 deletions xinference/client/tests/test_client_with_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,17 @@ def test_client_auth(setup_with_auth):
assert len(client.list_models()) == 0

with pytest.raises(RuntimeError):
client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
)
client.launch_model(model_name="bge-small-en-v1.5", model_type="embedding")

client.login("user3", "pass3")
model_uid = client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
model_name="bge-small-en-v1.5", model_type="embedding"
)
model = client.get_model(model_uid=model_uid)
assert isinstance(model, RESTfulEmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 512
assert len(completion["data"][0]["embedding"]) == 384

with pytest.raises(RuntimeError):
client.terminate_model(model_uid=model_uid)
Expand Down
44 changes: 44 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,50 @@ async def rerank(
)
raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")

@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
@request_limit
async def transcriptions(
    self,
    audio: bytes,
    language: Optional[str] = None,
    prompt: Optional[str] = None,
    response_format: str = "json",
    temperature: float = 0,
):
    """Forward a speech-to-text transcription request to the wrapped model.

    Raises AttributeError when the underlying model does not expose a
    `transcriptions` capability. The audio payload is stripped from the
    decorator's log output via `args_formatter`.
    """
    # Guard clause: fail fast for models without transcription support.
    if not hasattr(self._model, "transcriptions"):
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating transcriptions."
        )
    return await self._call_wrapper(
        self._model.transcriptions,
        audio,
        language,
        prompt,
        response_format,
        temperature,
    )

@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
@request_limit
async def translations(
    self,
    audio: bytes,
    prompt: Optional[str] = None,
    response_format: str = "json",
    temperature: float = 0,
):
    """Forward a speech-to-English translation request to the wrapped model.

    Raises AttributeError when the underlying model does not expose a
    `translations` capability. The audio payload is stripped from the
    decorator's log output via `args_formatter`.
    """
    # Guard clause: fail fast for models without translation support.
    if not hasattr(self._model, "translations"):
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating translations."
        )
    return await self._call_wrapper(
        self._model.translations,
        audio,
        prompt,
        response_format,
        temperature,
    )

@log_async(logger=logger)
@request_limit
async def text_to_image(
Expand Down
Loading

0 comments on commit 8069552

Please sign in to comment.