Skip to content

Commit

Permalink
FEAT: support F5-TTS-MLX (#2671)
Browse files Browse the repository at this point in the history
  • Loading branch information
qinxuye authored Dec 14, 2024
1 parent b132fca commit fffcd93
Show file tree
Hide file tree
Showing 6 changed files with 329 additions and 2 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/python.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -129,7 +129,9 @@ jobs:
pip install mlx-lm
pip install mlx-vlm
pip install mlx-whisper
pip install f5-tts-mlx
pip install qwen-vl-utils
pip install tomli
else
pip install "llama-cpp-python==0.2.77" --extra-index-url https://abetlen.github.io/llama-cpp-python/whl/cpu
pip install transformers
Expand Down Expand Up @@ -245,7 +247,10 @@ jobs:
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/llm/mlx/tests/test_mlx.py && \
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper_mlx.py
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_whisper_mlx.py && \
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
--cov-config=setup.cfg --cov-report=xml --cov=xinference xinference/model/audio/tests/test_f5tts_mlx.py
else
pytest --timeout=1500 \
-W ignore::PendingDeprecationWarning \
Expand Down
4 changes: 3 additions & 1 deletion setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ all =
mlx-lm ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-vlm ; sys_platform=='darwin' and platform_machine=='arm64'
mlx-whisper ; sys_platform=='darwin' and platform_machine=='arm64'
qwen_vl_utils
f5-tts-mlx ; sys_platform=='darwin' and platform_machine=='arm64'
attrdict # For deepseek VL
timm>=0.9.16 # For deepseek VL
torchvision # For deepseek VL
Expand Down Expand Up @@ -185,7 +185,9 @@ mlx =
mlx-lm
mlx-vlm
mlx-whisper
f5-tts-mlx
qwen_vl_utils
tomli
embedding =
sentence-transformers>=3.1.0
rerank =
Expand Down
5 changes: 5 additions & 0 deletions xinference/model/audio/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from .chattts import ChatTTSModel
from .cosyvoice import CosyVoiceModel
from .f5tts import F5TTSModel
from .f5tts_mlx import F5TTSMLXModel
from .fish_speech import FishSpeechModel
from .funasr import FunASRModel
from .whisper import WhisperModel
Expand Down Expand Up @@ -171,6 +172,7 @@ def create_audio_model_instance(
CosyVoiceModel,
FishSpeechModel,
F5TTSModel,
F5TTSMLXModel,
],
AudioModelDescription,
]:
Expand All @@ -185,6 +187,7 @@ def create_audio_model_instance(
CosyVoiceModel,
FishSpeechModel,
F5TTSModel,
F5TTSMLXModel,
]
if model_spec.model_family == "whisper":
if not model_spec.engine:
Expand All @@ -201,6 +204,8 @@ def create_audio_model_instance(
model = FishSpeechModel(model_uid, model_path, model_spec, **kwargs)
elif model_spec.model_family == "F5-TTS":
model = F5TTSModel(model_uid, model_path, model_spec, **kwargs)
elif model_spec.model_family == "F5-TTS-MLX":
model = F5TTSMLXModel(model_uid, model_path, model_spec, **kwargs)
else:
raise Exception(f"Unsupported audio model family: {model_spec.model_family}")
model_description = AudioModelDescription(
Expand Down
257 changes: 257 additions & 0 deletions xinference/model/audio/f5tts_mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,257 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import datetime
import logging
import os
import tempfile
from io import BytesIO
from pathlib import Path
from typing import TYPE_CHECKING, Literal, Optional

import numpy as np
from tqdm import tqdm

if TYPE_CHECKING:
from .core import AudioModelFamilyV1

# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)


class F5TTSMLXModel:
    """Text-to-speech model wrapper built on the ``f5_tts_mlx`` package.

    ``load()`` builds the F5-TTS network from the files found under
    ``model_path`` and ``speech()`` synthesizes audio for a piece of text,
    conditioned on a reference recording and its transcript.
    """

    def __init__(
        self,
        model_uid: str,
        model_path: str,
        model_spec: "AudioModelFamilyV1",
        device: Optional[str] = None,
        **kwargs,
    ):
        """Store the configuration; the network itself is built by ``load()``.

        Args:
            model_uid: Identifier of this model instance.
            model_path: Directory holding ``vocab.txt``, ``model.safetensors``
                and, optionally, ``duration_v2.safetensors``.
            model_spec: Model family description; ``model_ability`` is read
                from it.
            device: Stored as-is; not consulted by this class.
            **kwargs: Extra options, stored as-is.
        """
        self._model_uid = model_uid
        self._model_path = model_path
        self._model_spec = model_spec
        self._device = device
        self._kwargs = kwargs
        # Set by load(); remains None until the weights have been loaded.
        self._model = None

    @property
    def model_ability(self):
        """Ability declared by the model spec."""
        return self._model_spec.model_ability

    def load(self):
        """Build the F5-TTS model and load its weights from ``model_path``.

        Raises:
            ImportError: If ``f5_tts_mlx`` or one of its companions cannot be
                imported.
            ValueError: If ``vocab.txt`` contains no usable token.
        """
        try:
            import mlx.core as mx
            from f5_tts_mlx.cfm import F5TTS
            from f5_tts_mlx.dit import DiT
            from f5_tts_mlx.duration import DurationPredictor, DurationTransformer
            from vocos_mlx import Vocos
        except ImportError as e:
            error_message = "Failed to import module 'f5_tts_mlx'"
            installation_guide = [
                "Please make sure 'f5_tts_mlx' is installed.\n",
            ]

            # Chain the original error so the actually missing module is visible.
            raise ImportError(
                f"{error_message}\n\n{''.join(installation_guide)}"
            ) from e

        path = Path(self._model_path)

        # vocab: one token per line, the line index is the token id.
        # Read explicitly as UTF-8 so the result does not depend on the locale.
        vocab_path = path / "vocab.txt"
        vocab = {
            v: i
            for i, v in enumerate(vocab_path.read_text(encoding="utf-8").split("\n"))
        }
        # ``"".split("\n")`` yields ``[""]``, so an empty file still produces one
        # (empty) entry; treat "no non-empty token" as a failed load.
        if not any(vocab):
            raise ValueError(f"Could not load vocab from {vocab_path}")

        # duration predictor (optional: only used when its weights are present)
        duration_model_path = path / "duration_v2.safetensors"
        duration_predictor = None

        if duration_model_path.exists():
            duration_predictor = DurationPredictor(
                transformer=DurationTransformer(
                    dim=512,
                    depth=8,
                    heads=8,
                    text_dim=512,
                    ff_mult=2,
                    conv_layers=2,
                    text_num_embeds=len(vocab) - 1,
                ),
                vocab_char_map=vocab,
            )
            weights = mx.load(duration_model_path.as_posix(), format="safetensors")
            duration_predictor.load_weights(list(weights.items()))

        # vocoder
        vocos = Vocos.from_pretrained("lucasnewman/vocos-mel-24khz")

        # model
        model_path = path / "model.safetensors"

        f5tts = F5TTS(
            transformer=DiT(
                dim=1024,
                depth=22,
                heads=16,
                ff_mult=2,
                text_dim=512,
                conv_layers=4,
                text_num_embeds=len(vocab) - 1,
            ),
            vocab_char_map=vocab,
            vocoder=vocos.decode,
            duration_predictor=duration_predictor,
        )

        weights = mx.load(model_path.as_posix(), format="safetensors")
        f5tts.load_weights(list(weights.items()))
        mx.eval(f5tts.parameters())

        self._model = f5tts

    def _generate_segment(self, audio, prompt_text, text, duration, sample_kwargs):
        """Synthesize one piece of text and return it without the reference audio.

        The reference transcript is prepended to ``text`` and the leading
        ``audio.shape[0]`` samples of the sampled wave are dropped afterwards.
        """
        import mlx.core as mx
        from f5_tts_mlx.generate import convert_char_to_pinyin

        generation_text = convert_char_to_pinyin([prompt_text + " " + text])

        wave, _ = self._model.sample(  # type: ignore
            mx.expand_dims(audio, axis=0),
            text=generation_text,
            duration=duration,
            **sample_kwargs,
        )

        # trim the reference audio
        wave = wave[audio.shape[0] :]
        mx.eval(wave)
        return wave

    def speech(
        self,
        input: str,
        voice: str,
        response_format: str = "mp3",
        speed: float = 1.0,
        stream: bool = False,
        **kwargs,
    ):
        """Synthesize ``input`` and return the encoded audio as bytes.

        Args:
            input: Text to speak.
            voice: Accepted for interface compatibility; not used here.
            response_format: Container/codec name, upper-cased and handed to
                ``soundfile``.
            speed: Passed through to the model's ``sample`` call.
            stream: Must be false; streaming is not supported.
            **kwargs: Optional ``prompt_speech`` (bytes of the reference
                recording), ``prompt_text`` (its transcript), ``duration``
                (seconds), ``steps``, ``cfg_strength``, ``method``,
                ``sway_sampling_coef`` and ``seed``.

        Returns:
            The generated audio encoded in ``response_format``.

        Raises:
            Exception: If ``stream`` is true.
            ValueError: If ``prompt_speech`` is given without ``prompt_text``.
        """
        import mlx.core as mx
        import soundfile as sf
        from f5_tts_mlx.generate import (
            FRAMES_PER_SEC,
            SAMPLE_RATE,
            TARGET_RMS,
            split_sentences,
        )

        if stream:
            raise Exception("F5-TTS does not support stream generation.")

        prompt_speech: Optional[bytes] = kwargs.pop("prompt_speech", None)
        prompt_text: Optional[str] = kwargs.pop("prompt_text", None)
        duration: Optional[float] = kwargs.pop("duration", None)
        steps: Optional[int] = kwargs.pop("steps", 8)
        cfg_strength: Optional[float] = kwargs.pop("cfg_strength", 2.0)
        method: Literal["euler", "midpoint", "rk4"] = kwargs.pop("method", "rk4")
        sway_sampling_coef: float = kwargs.pop("sway_sampling_coef", -1.0)
        seed: Optional[int] = kwargs.pop("seed", None)

        # Path of a temp file we created ourselves and therefore must delete.
        temp_prompt_path: Optional[str] = None

        if prompt_speech is None:
            # Only the bundled default prompt needs the TOML parser.
            import tomli

            base = os.path.join(os.path.dirname(__file__), "../../thirdparty/f5_tts")
            config = os.path.join(base, "infer/examples/basic/basic.toml")
            with open(config, "rb") as f:
                config_dict = tomli.load(f)
            prompt_speech_path = os.path.join(base, config_dict["ref_audio"])
            # The transcript must match the bundled audio, so it replaces any
            # caller-supplied ``prompt_text``.
            prompt_text = config_dict["ref_text"]
        else:
            # Validate before touching the disk so a bad request leaves no file.
            if prompt_text is None:
                raise ValueError("`prompt_text` cannot be empty")
            with tempfile.NamedTemporaryFile(delete=False) as f:  # type: ignore
                f.write(prompt_speech)
            prompt_speech_path = temp_prompt_path = f.name

        if prompt_text is None:
            raise ValueError("`prompt_text` cannot be empty")

        try:
            audio, sr = sf.read(prompt_speech_path)
        finally:
            # The temp file was created with delete=False; remove it ourselves
            # so repeated requests do not pile up files.
            if temp_prompt_path is not None:
                try:
                    os.unlink(temp_prompt_path)
                except OSError:
                    logger.warning(
                        "Failed to remove temporary prompt file %s", temp_prompt_path
                    )

        if sr != SAMPLE_RATE:
            # NOTE(review): the rest of this method measures audio in SAMPLE_RATE
            # samples, so a different rate presumably degrades the result —
            # confirm whether this should be rejected or resampled instead.
            logger.warning(
                "Reference audio sample rate is %s Hz, expected %s Hz",
                sr,
                SAMPLE_RATE,
            )

        audio = mx.array(audio)
        ref_audio_duration = audio.shape[0] / sr
        logger.debug(
            f"Got reference audio with duration: {ref_audio_duration:.2f} seconds"
        )

        rms = mx.sqrt(mx.mean(mx.square(audio)))
        # Boost quiet references only; a silent one (rms == 0) would divide by zero.
        if rms > 0 and rms < TARGET_RMS:
            audio = audio * TARGET_RMS / rms

        sentences = split_sentences(input)
        # An explicit duration applies to the whole utterance, so it forces a
        # single pass even for multi-sentence input.
        is_single_generation = len(sentences) <= 1 or duration is not None

        sample_kwargs = dict(
            steps=steps,
            method=method,
            speed=speed,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            seed=seed,
        )

        start_date = datetime.datetime.now()

        if is_single_generation:
            # Seconds -> frames, converted exactly once.
            frames = None if duration is None else int(duration * FRAMES_PER_SEC)
            wave = self._generate_segment(
                audio, prompt_text, input, frames, sample_kwargs
            )
        else:
            # ``duration`` is always None on this path (see above).
            output = [
                self._generate_segment(
                    audio, prompt_text, sentence_text, None, sample_kwargs
                )
                for sentence_text in tqdm(sentences)
            ]
            wave = mx.concatenate(output, axis=0)

        generated_duration = wave.shape[0] / SAMPLE_RATE
        logger.debug(
            f"Generated {generated_duration:.2f}s of audio in {datetime.datetime.now() - start_date}."
        )

        # Save the generated audio
        with BytesIO() as out:
            with sf.SoundFile(
                out, "w", SAMPLE_RATE, 1, format=response_format.upper()
            ) as f:
                f.write(np.array(wave))
            return out.getvalue()
8 changes: 8 additions & 0 deletions xinference/model/audio/model_spec.json
Original file line number Diff line number Diff line change
Expand Up @@ -250,5 +250,13 @@
"model_revision": "4dcc16f297f2ff98a17b3726b16f5de5a5e45672",
"model_ability": "text-to-audio",
"multilingual": true
},
{
"model_name": "F5-TTS-MLX",
"model_family": "F5-TTS-MLX",
"model_id": "lucasnewman/f5-tts-mlx",
"model_revision": "7642bb232e3fcacf92c51c786edebb8624da6b93",
"model_ability": "text-to-audio",
"multilingual": true
}
]
50 changes: 50 additions & 0 deletions xinference/model/audio/tests/test_f5tts_mlx.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Copyright 2022-2023 XProbe Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile


def test_f5tts_mlx(setup):
    """Launch F5-TTS-MLX and request speech through both client libraries."""
    endpoint, _ = setup
    from ....client import Client

    xinference_client = Client(endpoint)

    launch_options = dict(
        model_name="F5-TTS-MLX",
        model_type="audio",
        download_hub="huggingface",
    )
    model_uid = xinference_client.launch_model(**launch_options)
    tts_model = xinference_client.get_model(model_uid)

    text = "chat T T S is a text to speech model designed for dialogue applications."

    # Native client: the call returns the encoded audio directly.
    audio_bytes = tts_model.speech(text)
    assert type(audio_bytes) is bytes
    assert len(audio_bytes) > 0

    with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as mp3_file:
        mp3_file.write(audio_bytes)

    # Test openai API
    import openai

    openai_client = openai.Client(api_key="not empty", base_url=f"{endpoint}/v1")
    streaming_request = openai_client.audio.speech.with_streaming_response.create(
        model=model_uid, input=text, voice="echo"
    )
    with streaming_request as streamed:
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=True) as mp3_file:
            streamed.stream_to_file(mp3_file.name)
            assert os.stat(mp3_file.name).st_size > 0

0 comments on commit fffcd93

Please sign in to comment.