Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Support audio model #929

Merged
merged 13 commits into from
Jan 25, 2024
5 changes: 4 additions & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -134,12 +134,15 @@ jobs:
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/image/tests/test_stable_diffusion.py
${{ env.SELF_HOST_PYTHON }} -m pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper.py
else
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/client/tests/test_client.py
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py xinference
--cov-config=setup.cfg --cov-report=xml --cov=xinference --ignore xinference/client/tests/test_client.py --ignore xinference/model/image/tests/test_stable_diffusion.py --ignore xinference/model/audio/tests/test_whisper.py xinference
fi
working-directory: .
110 changes: 108 additions & 2 deletions xinference/api/restful_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,22 @@ def serve(self, logging_conf: Optional[dict] = None):
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/audio/transcriptions",
self.create_transcriptions,
methods=["POST"],
dependencies=[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/audio/translations",
self.create_translations,
methods=["POST"],
dependencies=[Security(self._auth_service, scopes=["models:read"])]
if self.is_authenticated()
else None,
)
self._router.add_api_route(
"/v1/images/generations",
self.create_images,
Expand Down Expand Up @@ -879,6 +895,94 @@ async def rerank(self, request: RerankRequest) -> Response:
await self._report_error_event(model_uid, str(e))
raise HTTPException(status_code=500, detail=str(e))

async def create_transcriptions(
    self,
    model: str = Form(...),
    file: UploadFile = File(media_type="application/octet-stream"),
    language: Optional[str] = Form(None),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Transcribe an uploaded audio file with the requested audio model.

    Mirrors OpenAI's ``/v1/audio/transcriptions`` endpoint: every parameter
    arrives as a multipart form field.

    Parameters
    ----------
    model: the model UID to route the request to.
    file: the raw audio bytes to transcribe.
    language: optional ISO-639-1 language hint for the audio.
    prompt: optional text to guide the model's style.
    response_format: output format, defaults to ``json``.
    temperature: sampling temperature, defaults to 0.
    kwargs: optional JSON-encoded string of extra model arguments.

    Raises
    ------
    HTTPException: 400 for unknown model UID, malformed ``kwargs`` JSON or
        a model-side RuntimeError; 500 for unexpected errors.
    """
    model_uid = model
    try:
        model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
    except ValueError as ve:
        # An unknown model UID is a client error, not a server failure.
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(str(e), exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    # Parse client-supplied kwargs separately: malformed JSON is a 400
    # (bad request), not a 500 swallowed by the broad handler below.
    try:
        parsed_kwargs = json.loads(kwargs) if kwargs is not None else {}
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON in `kwargs`: {e}"
        )

    try:
        transcription = await model_ref.transcriptions(
            audio=await file.read(),
            language=language,
            prompt=prompt,
            response_format=response_format,
            temperature=temperature,
            **parsed_kwargs,
        )
        return Response(content=transcription, media_type="application/json")
    except RuntimeError as re:
        logger.error(str(re), exc_info=True)
        await self._report_error_event(model_uid, str(re))
        raise HTTPException(status_code=400, detail=str(re))
    except Exception as e:
        logger.error(str(e), exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_translations(
    self,
    model: str = Form(...),
    file: UploadFile = File(media_type="application/octet-stream"),
    prompt: Optional[str] = Form(None),
    response_format: Optional[str] = Form("json"),
    temperature: Optional[float] = Form(0),
    kwargs: Optional[str] = Form(None),
) -> Response:
    """Translate an uploaded audio file into English with the requested model.

    Mirrors OpenAI's ``/v1/audio/translations`` endpoint: every parameter
    arrives as a multipart form field.

    Parameters
    ----------
    model: the model UID to route the request to.
    file: the raw audio bytes to translate.
    prompt: optional text to guide the model's style.
    response_format: output format, defaults to ``json``.
    temperature: sampling temperature, defaults to 0.
    kwargs: optional JSON-encoded string of extra model arguments.

    Raises
    ------
    HTTPException: 400 for unknown model UID, malformed ``kwargs`` JSON or
        a model-side RuntimeError; 500 for unexpected errors.
    """
    model_uid = model
    try:
        model_ref = await (await self._get_supervisor_ref()).get_model(model_uid)
    except ValueError as ve:
        # An unknown model UID is a client error, not a server failure.
        logger.error(str(ve), exc_info=True)
        await self._report_error_event(model_uid, str(ve))
        raise HTTPException(status_code=400, detail=str(ve))
    except Exception as e:
        logger.error(str(e), exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

    # Parse client-supplied kwargs separately: malformed JSON is a 400
    # (bad request), not a 500 swallowed by the broad handler below.
    try:
        parsed_kwargs = json.loads(kwargs) if kwargs is not None else {}
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400, detail=f"Invalid JSON in `kwargs`: {e}"
        )

    try:
        translation = await model_ref.translations(
            audio=await file.read(),
            prompt=prompt,
            response_format=response_format,
            temperature=temperature,
            **parsed_kwargs,
        )
        return Response(content=translation, media_type="application/json")
    except RuntimeError as re:
        logger.error(str(re), exc_info=True)
        await self._report_error_event(model_uid, str(re))
        raise HTTPException(status_code=400, detail=str(re))
    except Exception as e:
        logger.error(str(e), exc_info=True)
        await self._report_error_event(model_uid, str(e))
        raise HTTPException(status_code=500, detail=str(e))

async def create_images(self, request: TextToImageRequest) -> Response:
model_uid = request.model
try:
Expand Down Expand Up @@ -937,15 +1041,17 @@ async def create_variations(

try:
if kwargs is not None:
kwargs = json.loads(kwargs)
parsed_kwargs = json.loads(kwargs)
else:
parsed_kwargs = {}
image_list = await model_ref.image_to_image(
image=Image.open(image.file),
prompt=prompt,
negative_prompt=negative_prompt,
n=n,
size=size,
response_format=response_format,
**kwargs,
**parsed_kwargs,
)
return Response(content=image_list, media_type="application/json")
except RuntimeError as re:
Expand Down
113 changes: 113 additions & 0 deletions xinference/client/restful/restful_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,115 @@ def generate(
return response_data


class RESTfulAudioModelHandle(RESTfulModelHandle):
    """Client-side handle for audio models served over the RESTful API."""

    @staticmethod
    def _build_form_fields(params: Dict[str, Any], audio: bytes) -> List[Any]:
        """Build the multipart form-field list for an audio request.

        Multipart part bodies must be str/bytes: requests/urllib3 only
        auto-stringify ints, so a float value (e.g. temperature=0.5) would
        break the encoding — coerce every value to str. Unset optionals are
        skipped entirely so the server-side Form(...) defaults apply.
        """
        fields: List[Any] = [
            (key, (None, str(value)))
            for key, value in params.items()
            if value is not None
        ]
        fields.append(("file", ("file", audio, "application/octet-stream")))
        return fields

    def transcriptions(
        self,
        audio: bytes,
        language: Optional[str] = None,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
    ):
        """
        Transcribes audio into the input language.

        Parameters
        ----------

        audio: bytes
            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        language: Optional[str]
            The language of the input audio. Supplying the input language in ISO-639-1
            (https://en.wikipedia.org/wiki/List_of_ISO_639_language_codes) format will improve accuracy and latency.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.

        Returns
        -------
        The transcribed text.
        """
        url = f"{self._base_url}/v1/audio/transcriptions"
        params = {
            "model": self._model_uid,
            "language": language,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
        }
        files = self._build_form_fields(params, audio)
        response = requests.post(url, files=files, headers=self.auth_headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to transcribe the audio, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data

    def translations(
        self,
        audio: bytes,
        prompt: Optional[str] = None,
        response_format: Optional[str] = "json",
        temperature: Optional[float] = 0,
    ):
        """
        Translates audio into English.

        Parameters
        ----------

        audio: bytes
            The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg,
            mpga, m4a, ogg, wav, or webm.
        prompt: Optional[str]
            An optional text to guide the model's style or continue a previous audio segment.
            The prompt should match the audio language.
        response_format: Optional[str], defaults to json
            The format of the transcript output, in one of these options: json, text, srt, verbose_json, or vtt.
        temperature: Optional[float], defaults to 0
            The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random,
            while lower values like 0.2 will make it more focused and deterministic.
            If set to 0, the model will use log probability to automatically increase the temperature
            until certain thresholds are hit.

        Returns
        -------
        The translated text.
        """
        url = f"{self._base_url}/v1/audio/translations"
        params = {
            "model": self._model_uid,
            "prompt": prompt,
            "response_format": response_format,
            "temperature": temperature,
        }
        files = self._build_form_fields(params, audio)
        response = requests.post(url, files=files, headers=self.auth_headers)
        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to translate the audio, detail: {_get_error_string(response)}"
            )

        response_data = response.json()
        return response_data


class Client:
def __init__(self, base_url):
self.base_url = base_url
Expand Down Expand Up @@ -803,6 +912,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulRerankModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
elif desc["model_type"] == "audio":
return RESTfulAudioModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
else:
raise ValueError(f"Unknown model type:{desc['model_type']}")

Expand Down
4 changes: 2 additions & 2 deletions xinference/client/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,15 +116,15 @@ def test_client_for_embedding(setup):
assert len(client.list_models()) == 0

model_uid = client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
model_name="bge-small-en-v1.5", model_type="embedding"
)
assert len(client.list_models()) == 1

model = client.get_model(model_uid=model_uid)
assert isinstance(model, EmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 512
assert len(completion["data"][0]["embedding"]) == 384

client.terminate_model(model_uid=model_uid)
assert len(client.list_models()) == 0
Expand Down
8 changes: 3 additions & 5 deletions xinference/client/tests/test_client_with_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,17 @@ def test_client_auth(setup_with_auth):
assert len(client.list_models()) == 0

with pytest.raises(RuntimeError):
client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
)
client.launch_model(model_name="bge-small-en-v1.5", model_type="embedding")

client.login("user3", "pass3")
model_uid = client.launch_model(
model_name="jina-embeddings-v2-small-en", model_type="embedding"
model_name="bge-small-en-v1.5", model_type="embedding"
)
model = client.get_model(model_uid=model_uid)
assert isinstance(model, RESTfulEmbeddingModelHandle)

completion = model.create_embedding("write a poem.")
assert len(completion["data"][0]["embedding"]) == 512
assert len(completion["data"][0]["embedding"]) == 384

with pytest.raises(RuntimeError):
client.terminate_model(model_uid=model_uid)
Expand Down
44 changes: 44 additions & 0 deletions xinference/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,50 @@ async def rerank(
)
raise AttributeError(f"Model {self._model.model_spec} is not for reranking.")

@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
@request_limit
async def transcriptions(
    self,
    audio: bytes,
    language: Optional[str] = None,
    prompt: Optional[str] = None,
    response_format: str = "json",
    temperature: float = 0,
):
    """Delegate an audio transcription request to the wrapped model.

    Raises AttributeError when the underlying model does not expose a
    ``transcriptions`` method.
    """
    # Guard clause: reject models that cannot transcribe before doing work.
    if not hasattr(self._model, "transcriptions"):
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating transcriptions."
        )
    return await self._call_wrapper(
        self._model.transcriptions,
        audio,
        language,
        prompt,
        response_format,
        temperature,
    )

@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio"))
@request_limit
async def translations(
    self,
    audio: bytes,
    prompt: Optional[str] = None,
    response_format: str = "json",
    temperature: float = 0,
):
    """Delegate an audio translation request to the wrapped model.

    Raises AttributeError when the underlying model does not expose a
    ``translations`` method.
    """
    # Guard clause: reject models that cannot translate before doing work.
    if not hasattr(self._model, "translations"):
        raise AttributeError(
            f"Model {self._model.model_spec} is not for creating translations."
        )
    return await self._call_wrapper(
        self._model.translations,
        audio,
        prompt,
        response_format,
        temperature,
    )

@log_async(logger=logger)
@request_limit
async def text_to_image(
Expand Down
Loading
Loading