diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py
index e08f7d7101834..c00de27c55bd0 100644
--- a/nemo/collections/common/parts/utils.py
+++ b/nemo/collections/common/parts/utils.py
@@ -112,14 +112,9 @@ def extend_instance(obj, mixin):
     )  # mixin needs to go first for our forward() logic to work
 
 
-def apply_rope_scaling(freqs):
+def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
     # Apply scaling for RoPE frequencies
     logger.info("apply rope scaling ...")
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length
 
     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 165a6d650843c..02ef522dde1f8 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -199,7 +199,13 @@ def mcore_model_customize(cfg, model):
     if cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
         extend_instance(model.embedding, EmbeddingScalingMixin)
     if cfg.get("scale_positional_embedding", False):
-        model.rotary_pos_emb.inv_freq = apply_rope_scaling(model.rotary_pos_emb.inv_freq)
+        model.rotary_pos_emb.inv_freq = apply_rope_scaling(
+            model.rotary_pos_emb.inv_freq,
+            scale_factor=cfg.get('scale_factor', 8),
+            low_freq_factor=cfg.get('low_freq_factor', 1),
+            high_freq_factor=cfg.get('high_freq_factor', 4),
+            old_context_len=cfg.get('old_context_len', 8192),
+        )
     if cfg.get("mcore_customization_config", {}).get("final_logit_softcapping", 0):
         from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_modules import Gemma2OutputLayer
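
For context, here is a minimal standalone sketch of what `apply_rope_scaling` computes with the new keyword arguments. Only the signature, the log call, and the `low_freq_wavelen`/`high_freq_wavelen` lines appear in the hunk above; the loop body below is an illustration assuming the usual Llama 3.1-style wavelength interpolation, not code copied from this PR.

```python
import math

import torch


def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
    """Rescale RoPE inverse frequencies for long-context use (sketch, Llama 3.1-style scheme assumed)."""
    low_freq_wavelen = old_context_len / low_freq_factor
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: leave unchanged.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: compress by scale_factor.
            new_freqs.append(freq / scale_factor)
        else:
            # Transition band: interpolate smoothly between the two regimes.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


# Example: inverse frequencies for a 128-dim rotary embedding, then rescale them.
dim = 128
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
scaled = apply_rope_scaling(inv_freq, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192)
```

With the second hunk, the previously hard-coded values become overridable per model: `mcore_model_customize` reads `scale_factor`, `low_freq_factor`, `high_freq_factor`, and `old_context_len` from the model config, falling back to the old defaults (8, 1, 4, 8192) when the keys are absent.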