Skip to content

Commit

Permalink
Merge branch 'main' into revamp_documentation
Browse files Browse the repository at this point in the history
  • Loading branch information
erastorgueva-nv authored Oct 11, 2023
2 parents bd7564e + 0174752 commit 66682fc
Showing 1 changed file with 23 additions and 6 deletions.
29 changes: 23 additions & 6 deletions examples/asr/speech_to_text_finetune.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,25 +49,25 @@
For documentation on fine-tuning this model, please visit:
https://docs.nvidia.com/deeplearning/nemo/user-guide/docs/en/main/asr/configs.html#fine-tuning-configurations
"""

import time
import pytorch_lightning as pl
from omegaconf import OmegaConf
from pytorch_lightning.utilities import rank_zero_only

from nemo.collections.asr.models import ASRModel
from nemo.core.config import hydra_runner
from nemo.utils import logging, model_utils
from nemo.utils.exp_manager import exp_manager
from nemo.utils.get_rank import is_global_rank_zero


@rank_zero_only
def get_base_model(cfg):
def get_base_model(trainer, cfg):
"""
Returns the base model to be fine-tuned.
Currently supports two types of initializations:
1) `init_from_nemo_model`, and
2) `init_from_pretrained_model`.
Args:
trainer: PyTorch Lightning Trainer
cfg: config
Returns:
asr_model: ASRModel instance
Expand All @@ -84,7 +84,24 @@ def get_base_model(cfg):
elif nemo_model_path is not None:
asr_model = ASRModel.restore_from(restore_path=nemo_model_path)
elif pretrained_name is not None:
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)
# Due to potential first time download of the model on the cluster, we need to make sure that only one
# rank downloads the model and the others wait for the download to finish.
num_ranks = trainer.num_devices * trainer.num_devices

if num_ranks > 1 and is_global_rank_zero():
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)
else:
# Sleep on all ranks for at least 60 seconds
wait_time = int(cfg.get('exp_manager', {}).get('seconds_to_sleep', 60))
if wait_time < 60:
wait_time = 60

logging.info(f"Sleeping for at least {wait_time} seconds to wait for model download to finish.")

time.sleep(wait_time)

# restore model from cached model dir
asr_model = ASRModel.from_pretrained(model_name=pretrained_name)

return asr_model

Expand Down Expand Up @@ -180,7 +197,7 @@ def main(cfg):
"Currently for simplicity of single script for all model types, we only support `init_from_nemo_model` and `init_from_pretrained_model`"
)

asr_model = get_base_model(cfg)
asr_model = get_base_model(trainer, cfg)

# Check vocabulary type and update if needed
asr_model = check_vocabulary(asr_model, cfg)
Expand Down

0 comments on commit 66682fc

Please sign in to comment.