Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[MMS] Update docs with HF TTS implementation #25907

Merged
merged 3 commits into from
Sep 1, 2023
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 91 additions & 1 deletion docs/source/en/model_doc/mms.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,97 @@ To further improve performance from ASR models, language model decoding can be u

### Speech Synthesis (TTS)

Individual TTS models are available for each of the 1100+ languages. The models and inference documentation can be found [here](https://huggingface.co/facebook/mms-tts).
MMS-TTS uses the same model architecture as VITS, which was added to 🤗 Transformers in v4.33. MMS trains a separate
model checkpoint for each of the 1100+ languages in the project. All available checkpoints can be found on the Hugging
Face Hub: [facebook/mms-tts](https://huggingface.co/models?sort=trending&search=facebook%2Fmms-tts), and the inference
documentation under [VITS](https://huggingface.co/docs/transformers/main/en/model_doc/vits).

#### Inference

To use the MMS model, first update to the latest version of the Transformers library:

```
pip install --upgrade transformers accelerate
```

Since the flow-based model in VITS is non-deterministic, it is good practice to set a seed to ensure reproducibility of
the outputs. For languages with a Roman alphabet, such as English or French, the tokenizer can be used directly to
pre-process the text inputs. The following code example runs a forward pass using the MMS-TTS English checkpoint:
sanchit-gandhi marked this conversation as resolved.
Show resolved Hide resolved

```python
import torch
from transformers import VitsTokenizer, VitsModel, set_seed

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model = VitsModel.from_pretrained("facebook/mms-tts-eng")

inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

set_seed(555) # make deterministic

with torch.no_grad():
outputs = model(**inputs)

waveform = outputs.waveform[0]
```

The resulting waveform can be saved as a `.wav` file:

```python
import scipy

scipy.io.wavfile.write("synthesized_speech.wav", rate=model.config.sampling_rate, data=waveform)
```

Or displayed in a Jupyter Notebook / Google Colab:

```python
from IPython.display import Audio

Audio(waveform, rate=model.config.sampling_rate)
```

For certain languages with non-Roman alphabets, such as Arabic, Mandarin or Hindi, the [`uroman`](https://github.com/isi-nlp/uroman)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
For certain languages with non-Roman alphabets, such as Arabic, Mandarin or Hindi, the [`uroman`](https://github.com/isi-nlp/uroman)
- For certain languages with non-Roman alphabets, such as Arabic, Mandarin or Hindi, the [`uroman`](https://github.com/isi-nlp/uroman)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why add a bullet point here?

perl package is required to pre-process the text inputs to the Roman alphabet.

You can check whether you require the `uroman` package for your language by inspecting the `is_uroman` attribute of
the pre-trained `tokenizer`:

```python
from transformers import VitsTokenizer

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
print(tokenizer.is_uroman)
```

If required, you should apply the uroman package to your text inputs **prior** to passing them to the `VitsTokenizer`,
since currently the tokenizer does not support performing the pre-processing itself.
Comment on lines +233 to +234
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you think we should add this? With a if self.is_uroman: requires_backend ? Makes sense no?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you also add an example of how to do this pre-processing?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a perl package, not a Python package, so we can't import it. The only way of running uroman is through the command line, see discussion here: #24085 (comment)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an example for pre-processing: d264692


**Tips:**

* The MMS-TTS checkpoints are trained on lower-cased, un-punctuated text. By default, the `VitsTokenizer` *normalizes* the inputs by removing any casing and punctuation, to avoid passing out-of-vocabulary characters to the model. Hence, the model is agnostic to casing and punctuation, so these should be avoided in the text prompt. You can disable normalisation by setting `noramlize=False` in the call to the tokenizer, but this will lead to un-expected behaviour and is discouraged.
* The speaking rate can be varied by setting the attribute `model.speaking_rate` to a chosen value. Likewise, the randomness of the noise is controlled by `model.noise_scale`:

```python
import torch
from transformers import VitsTokenizer, VitsModel, set_seed

tokenizer = VitsTokenizer.from_pretrained("facebook/mms-tts-eng")
model = VitsModel.from_pretrained("facebook/mms-tts-eng")

inputs = tokenizer(text="Hello - my dog is cute", return_tensors="pt")

# make deterministic
set_seed(555)

# make speech faster and more noisy
model.speaking_rate = 1.5
model.noise_scale = 0.8

with torch.no_grad():
outputs = model(**inputs)
```


### Language Identification (LID)

Expand Down
Loading