forked from NVIDIA/NeMo
Add support for finetuning with huggingface datasets (NVIDIA#7834)
* add finetune with huggingface dataset (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* [pre-commit.ci] auto fixes from pre-commit.com hooks (for more information, see https://pre-commit.ci)
* update yaml (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* update (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* update and refactor (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* add extract hf text and update (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* update and refactor (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* move dataset dependency to common (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* add docstring (Signed-off-by: stevehuang52 <heh@nvidia.com>)
* Add to Docs (Signed-off-by: Nithin Rao Koluguri <nithinraok>)
* add ci test (Signed-off-by: Nithin Rao Koluguri <nithinraok>)
* add max steps in jenkins (Signed-off-by: Nithin Rao Koluguri <nithinraok>)
* reduce max steps (Signed-off-by: Nithin Rao Koluguri <nithinraok>)
* jenkins test (Signed-off-by: Nithin Rao Koluguri <nithinraok>)
* add bs=2 (Signed-off-by: Nithin Rao Koluguri <nithinraok>)

Signed-off-by: stevehuang52 <heh@nvidia.com>
Signed-off-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
1 parent 8f9bcb8 · commit 84321e0

Showing 14 changed files with 1,232 additions and 4 deletions.
File renamed without changes.
examples/asr/conf/asr_finetune/speech_to_text_hf_finetune.yaml (new file, 189 additions, 0 deletions)
name: "Speech_To_Text_HF_Finetuning_using_HF_Datasets" | ||
|
||
# use `init_from_nemo_model` or `init_from_pretrained_model` to initialize the model | ||
# We do not currently support `init_from_ptl_ckpt` to create a single script for all types of models. | ||
init_from_nemo_model: null # path to nemo model | ||
init_from_pretrained_model: null # name of pretrained NeMo model, e.g., `stt_en_fastconformer_transducer_large` | ||
|
||
model:
  sample_rate: 16000
  compute_eval_loss: false # eval samples can be very long and exhaust memory. Disable computation of transducer loss during validation/testing with this flag.
  log_prediction: true # enables logging sample predictions in the output during training
  rnnt_reduction: 'mean_volume'
  skip_nan_grad: false

  # configs for the HuggingFace load_dataset() function
  data_path: "librispeech_asr"
  data_name: null # name of the specific dataset config to load, e.g., 'en' for MCV datasets; some datasets don't require this field.
  streaming: false # set to True to use streaming mode, which starts training without waiting for the full download, but makes each training step slower during the first epoch. If True, specify trainer.max_steps instead of trainer.max_epochs.

  # keys for audio, sample rate, and transcription in the HuggingFace dataset; keys are separated by `.` for nested fields. See the example at the bottom of this file.
  audio_key: "audio.array"
  sample_rate_key: "audio.sampling_rate"
  text_key: "text" # the key for the ground-truth transcription; e.g., MCV usually uses "sentence" while some other datasets use "text"

  # simple text cleaning; by default, converts all characters to lower-case and keeps only alpha-numeric characters.
  normalize_text: true
  symbols_to_keep: ["'"] # a list of symbols to keep during text cleaning.
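  # A minimal sketch of the cleaning rule described above (an assumption; the
  # exact implementation may differ):
  #   text = text.lower()
  #   text = "".join(c for c in text if c.isalnum() or c.isspace() or c in symbols_to_keep)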

  train_ds:
    manifest_filepath: "huggingface" # set to any non-null value to avoid breaking existing code
    streaming: ${model.streaming}
    normalize_text: ${model.normalize_text}
    symbols_to_keep: ${model.symbols_to_keep}
    audio_key: ${model.audio_key}
    sample_rate_key: ${model.sample_rate_key}
    text_key: ${model.text_key}
    hf_data_cfg: # hf_data_cfg can be a ListConfig or a DictConfig. The params for each dataset are passed to HuggingFace load_dataset(); add more params if needed (see the sketch after this list).
      - path: ${model.data_path}
        name: ${model.data_name}
        split: 'train.clean.360'
        streaming: ${model.streaming}
      - path: ${model.data_path}
        name: ${model.data_name}
        split: 'train.clean.100'
        streaming: ${model.streaming}
      - path: ${model.data_path}
        name: ${model.data_name}
        split: 'train.other.500'
        streaming: ${model.streaming}
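    # Each entry above is expected to become one HuggingFace call, roughly:
    #   datasets.load_dataset(path=..., name=..., split=..., streaming=...)
    # so any additional load_dataset() keyword argument can be added per entry.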

    sample_rate: ${model.sample_rate}
    batch_size: 16 # you may increase batch_size if your memory allows
    shuffle: true
    shuffle_n: 2048
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  validation_ds:
    manifest_filepath: "huggingface" # set to any non-null value to avoid breaking existing code
    streaming: ${model.streaming}
    normalize_text: ${model.normalize_text}
    symbols_to_keep: ${model.symbols_to_keep}
    audio_key: ${model.audio_key}
    sample_rate_key: ${model.sample_rate_key}
    text_key: ${model.text_key}
    hf_data_cfg: # an example of using only one dataset
      path: ${model.data_path}
      name: ${model.data_name}
      split: 'validation.other'
      streaming: ${model.streaming}

    sample_rate: ${model.sample_rate}
    batch_size: 8
    shuffle: false
    shuffle_n: 2048
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  test_ds:
    manifest_filepath: "huggingface" # set to any non-null value to avoid breaking existing code
    streaming: ${model.streaming}
    normalize_text: ${model.normalize_text}
    symbols_to_keep: ${model.symbols_to_keep}
    audio_key: ${model.audio_key}
    sample_rate_key: ${model.sample_rate_key}
    text_key: ${model.text_key}
    hf_data_cfg: # hf_data_cfg can be a ListConfig or a DictConfig. The params for each dataset are passed to HuggingFace load_dataset(); add more params if needed.
      - path: ${model.data_path}
        name: ${model.data_name}
        split: 'test.other'
        streaming: ${model.streaming}
      - path: ${model.data_path}
        name: ${model.data_name}
        split: 'test.clean'
        streaming: ${model.streaming}

    sample_rate: ${model.sample_rate}
    batch_size: 8
    shuffle: false
    shuffle_n: 2048
    num_workers: 8
    pin_memory: true
    use_start_end_token: false

  char_labels: # use for char-based models
    update_labels: false
    labels: null # example list config: [' ', 'a', 'b', 'c']

  tokenizer: # use for spe/bpe-based tokenizer models
    update_tokenizer: false
    dir: null # path to a directory that contains either tokenizer.model (for bpe) or vocab.txt (for wpe)
    type: bpe # can be either bpe (SentencePiece tokenizer) or wpe (WordPiece tokenizer)

  spec_augment:
    _target_: nemo.collections.asr.modules.SpectrogramAugmentation
    freq_masks: 2 # set to zero to disable it
    time_masks: 10 # set to zero to disable it
    freq_width: 27
    time_width: 0.05

  optim:
    name: adamw
    lr: 1e-4
    # optimizer arguments
    betas: [0.9, 0.98]
    weight_decay: 1e-3

    # scheduler setup
    sched:
      name: CosineAnnealing
      # scheduler config override
      warmup_steps: 5000
      warmup_ratio: null
      min_lr: 5e-6
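      # Assumed behavior: with CosineAnnealing, the learning rate warms up linearly
      # for `warmup_steps`, then decays along a cosine curve toward `min_lr`.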

trainer:
  devices: -1 # number of GPUs; -1 uses all available GPUs
  num_nodes: 1
  max_epochs: 100
  max_steps: -1 # computed at runtime if not set
  val_check_interval: 1.0 # set to 0.25 to run validation 4 times per epoch, or to an integer for a fixed number of iterations
  accelerator: auto
  strategy: ddp
  accumulate_grad_batches: 1
  gradient_clip_val: 0.0
  precision: 32 # 16, 32, or bf16
  log_every_n_steps: 10 # interval of logging
  enable_progress_bar: True
  num_sanity_val_steps: 0 # number of validation steps to run as a sanity check before training starts; setting this to 0 disables it
  check_val_every_n_epoch: 1 # run validation every n epochs
  sync_batchnorm: true
  enable_checkpointing: False # provided by exp_manager
  logger: false # provided by exp_manager
  benchmark: false # needs to be false for models with variable-length speech input, as it slows down training

exp_manager:
  exp_dir: null
  name: ${name}
  create_tensorboard_logger: true
  create_checkpoint_callback: true
  checkpoint_callback_params:
    # in case of multiple validation sets, the first one is used
    monitor: "val_wer"
    mode: "min"
    save_top_k: 5
    always_save_nemo: True # saves the checkpoints as .nemo files along with the PTL checkpoints
  resume_if_exists: false
  resume_ignore_no_checkpoint: false

  create_wandb_logger: false
  wandb_logger_kwargs:
    name: null
    project: null

# An example item in the HuggingFace `librispeech_asr` dataset:
# {'chapter_id': 141231,
#  'file': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
#  'audio': {
#     'path': '/home/patrick/.cache/huggingface/datasets/downloads/extracted/b7ded9969e09942ab65313e691e6fc2e12066192ee8527e21d634aca128afbe2/dev_clean/1272/141231/1272-141231-0000.flac',
#     'array': array([-0.00048828, -0.00018311, -0.00137329, ..., 0.00079346, 0.00091553, 0.00085449], dtype=float32),
#     'sampling_rate': 16000
#  },
#  'id': '1272-141231-0000',
#  'speaker_id': 1272,
#  'text': 'A MAN SAID TO THE UNIVERSE SIR I EXIST'}
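As a rough illustration, the fields above line up with the HuggingFace `datasets` API as sketched below. This is a sketch of the mapping only, not NeMo's actual dataloader code; the YAML sets data_name to null, and name="all" is an assumption here, since `librispeech_asr` groups these splits under its "all" config.

    # Sketch: how one hf_data_cfg entry and the dotted keys map onto load_dataset().
    from datasets import load_dataset

    # One load_dataset() call per hf_data_cfg entry; name="all" is an assumption
    # (the YAML uses data_name: null). streaming=True avoids a full download here.
    ds = load_dataset("librispeech_asr", name="all", split="validation.other", streaming=True)

    sample = next(iter(ds))                  # an item shaped like the example above
    audio = sample["audio"]["array"]         # audio_key: "audio.array"
    sr = sample["audio"]["sampling_rate"]    # sample_rate_key: "audio.sampling_rate"
    text = sample["text"]                    # text_key: "text"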
New file (13 additions):
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.