Commit

Merge branch 'main' of https://github.com/NVIDIA/NeMo into main
ekmb committed Jun 1, 2023
2 parents f4f09fa + 8672af6 commit 2b6777f
Showing 8 changed files with 1,814 additions and 137 deletions.
2 changes: 2 additions & 0 deletions Jenkinsfile
@@ -3175,6 +3175,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
 model.hidden_size=256 \
 model.num_attention_heads=8 \
 model.activations_checkpoint_method='block' \
+model.activations_checkpoint_granularity='full' \
 model.activations_checkpoint_num_layers=1 \
 model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
 model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings"
@@ -3211,6 +3212,7 @@ assert_frame_equal(training_curve, gt_curve, rtol=1e-3, atol=1e-3)"'''
 model.hidden_size=256 \
 model.num_attention_heads=8 \
 model.activations_checkpoint_method='block' \
+model.activations_checkpoint_granularity='full' \
 model.activations_checkpoint_num_layers=1 \
 model.data.data_prefix=[.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document,.5,/home/TestData/nlp/megatron_gpt/data/gpt/simple_wiki_gpt_preproc_text_document] \
 model.data.index_mapping_dir=examples/nlp/language_modeling/gpt_index_mappings"
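Both Jenkins stages gain the same single override, model.activations_checkpoint_granularity='full', next to the existing 'block' checkpointing method. As a quick orientation (not part of the commit), here is a hedged sketch of how the three activation-checkpointing overrides travel together as one OmegaConf config; the field names mirror the CLI flags, while the validation check at the end is purely illustrative.

# Illustrative only: the three activation-checkpointing overrides from the test
# command, collected into a config object the model code can read.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "activations_checkpoint_granularity": "full",  # recompute whole transformer layers
        "activations_checkpoint_method": "block",      # checkpoint the first N layers per stage
        "activations_checkpoint_num_layers": 1,        # N = 1
    }
)

# 'block' only makes sense with a layer count; 'uniform' would instead divide the
# layer stack into equally sized checkpointed chunks (see transformer.py below).
if cfg.activations_checkpoint_method == "block":
    assert cfg.activations_checkpoint_num_layers >= 1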
27 changes: 17 additions & 10 deletions nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -1247,16 +1247,20 @@ def _restore_activation_checkpointing_args(self):
             _reset_activation_checkpointing_args.
         """
         # Restore config values.
-        self.cfg.activations_checkpoint_granularity = self.last_checkpointing_granularity
-        self.cfg.activations_checkpoint_method = self.last_checkpointing_method
-        self.cfg.activations_checkpoint_num_layers = self.last_checkpointing_num_layers
+        self.cfg.activations_checkpoint_granularity = self.last_activations_checkpoint_granularity
+        self.cfg.activations_checkpoint_method = self.last_activations_checkpoint_method
+        self.cfg.activations_checkpoint_num_layers = self.last_activations_checkpoint_num_layers
+        self.cfg.activations_checkpoint_layers_per_pipeline = self.last_activations_checkpoint_layers_per_pipeline
 
         # Restore model parameters.
         for module in self.get_gpt_module_list():
-            module.language_model.encoder.activations_checkpoint_granularity = self.last_checkpointing_granularity
-            module.language_model.encoder.activations_checkpoint_method = self.last_checkpointing_method
-            module.language_model.encoder.activations_checkpoint_num_layers = self.last_checkpointing_num_layers
+            module.language_model.encoder.activations_checkpoint_granularity = (
+                self.last_activations_checkpoint_granularity
+            )
+            module.language_model.encoder.activations_checkpoint_method = self.last_activations_checkpoint_method
+            module.language_model.encoder.activations_checkpoint_num_layers = (
+                self.last_activations_checkpoint_num_layers
+            )
+            module.language_model.encoder.activations_checkpoint_layers_per_pipeline = (
+                self.last_activations_checkpoint_layers_per_pipeline
+            )
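The hunk above does two things: the stashed attributes gain a consistent last_activations_checkpoint_* prefix, and activations_checkpoint_layers_per_pipeline joins the set of values that are saved and put back around inference. A simplified sketch of that save/zero/restore pattern, with hypothetical helper names and plain objects standing in for the NeMo config and encoders (the real methods act on self.cfg and each module's language_model.encoder):

# Hypothetical helpers illustrating the pattern the _reset_/_restore_ methods follow.
_CKPT_KEYS = (
    "activations_checkpoint_granularity",
    "activations_checkpoint_method",
    "activations_checkpoint_num_layers",
    "activations_checkpoint_layers_per_pipeline",
)

def reset_checkpointing(cfg, encoders):
    """Disable checkpointing for generation and return the values to restore later."""
    saved = {key: getattr(cfg, key) for key in _CKPT_KEYS}
    for key in _CKPT_KEYS:
        setattr(cfg, key, None)
        for enc in encoders:
            setattr(enc, key, None)
    return saved

def restore_checkpointing(cfg, encoders, saved):
    """Put the saved checkpointing values back on the config and every encoder."""
    for key, value in saved.items():
        setattr(cfg, key, value)
        for enc in encoders:
            setattr(enc, key, value)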
@@ -1270,12 +1274,13 @@ def _reset_sequence_parallelism_args(self):
         self.last_sequence_parallel = self.cfg.sequence_parallel
 
         # Reset config values. Needed for calling generate.
-        self.cfg.sequence_parallel = None
+        self.cfg.sequence_parallel = False
 
         # Reset model parameters.
-
         for module in self.get_gpt_module_list():
-            module.language_model.encoder.sequence_parallel = None
+            for mod in module.modules():
+                if hasattr(mod, "sequence_parallel"):
+                    mod.sequence_parallel = False
 
     def _restore_sequence_parallelism_args(self):
         """ Restores the sequence parallelism parameters using the values saved by
@@ -1287,4 +1292,6 @@ def _restore_sequence_parallelism_args(self):
 
         # Restore model parameters.
         for module in self.get_gpt_module_list():
-            module.language_model.encoder.sequence_parallel = self.last_sequence_parallel
+            for mod in module.modules():
+                if hasattr(mod, "sequence_parallel"):
+                    mod.sequence_parallel = self.last_sequence_parallel
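Both sequence-parallelism helpers switch from writing language_model.encoder.sequence_parallel directly to walking every submodule and flipping the flag wherever it exists, and the reset path now uses False rather than None. A minimal sketch of that module-walk pattern, assuming an ordinary torch.nn.Module tree; the function name is illustrative, not a NeMo API:

import torch.nn as nn

def set_sequence_parallel(root: nn.Module, enabled: bool) -> None:
    # modules() yields root itself plus every descendant, so any layer that
    # carries a sequence_parallel attribute (not just the encoder) is updated.
    for mod in root.modules():
        if hasattr(mod, "sequence_parallel"):
            mod.sequence_parallel = enabled

Reset would call this with False before generation; restore would call it with the previously saved cfg.sequence_parallel value afterwards.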
3 changes: 0 additions & 3 deletions nemo/collections/nlp/modules/common/megatron/transformer.py
@@ -1268,9 +1268,6 @@ def custom_forward(*inputs):
 
             return custom_forward
 
-        # Make sure memory is freed.
-        tensor_parallel.reset_checkpointed_activations_memory_buffer()
-
         if self.activations_checkpoint_method == 'uniform':
             # Uniformly divide the total number of Transformer layers and checkpoint
             # the input activation of each divided chunk.
16 changes: 6 additions & 10 deletions nemo/collections/tts/modules/submodules.py
@@ -758,15 +758,11 @@ def forward(self, batch_size=None, speaker=None, reference_spec=None, reference_
             embs = self.lookup_module(speaker)
 
         # Get GST based speaker embedding
-        if self.gst_module is not None:
-            if reference_spec is None or reference_spec_lens is None:
-                raise ValueError(
-                    "You should add `reference_audio` in sup_data_types or remove `speaker_encoder`in config."
-                )
-            out = self.gst_module(reference_spec, reference_spec_lens)
-            embs = out if embs is None else embs + out
-
-        elif self.gst_module is None and reference_spec is not None and reference_spec_lens is not None:
-            logging.warning("You may add `gst_module` in speaker_encoder to use reference_audio.")
+        if reference_spec is not None and reference_spec_lens is not None:
+            if self.gst_module is not None:
+                out = self.gst_module(reference_spec, reference_spec_lens)
+                embs = out if embs is None else embs + out
+            else:
+                logging.warning("You may add `gst_module` in speaker_encoder to use reference_audio.")
 
         return embs
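The rewritten branch gates on the reference spectrogram first, so a configured GST module with no reference_spec now simply falls through instead of raising the old ValueError, and the warning fires only when a reference is supplied without a GST module. A toy, framework-free rendering of the new control flow, with stand-in names rather than the NeMo attributes:

import logging

def combine_speaker_embedding(embs, gst_module, reference_spec, reference_spec_lens):
    # Only consult the GST path when a reference spectrogram (and its lengths) exist.
    if reference_spec is not None and reference_spec_lens is not None:
        if gst_module is not None:
            out = gst_module(reference_spec, reference_spec_lens)
            embs = out if embs is None else embs + out
        else:
            logging.warning("You may add `gst_module` in speaker_encoder to use reference_audio.")
    return embs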