From 490964a0124abc48ad4e88af0c6e81c2d17268c3 Mon Sep 17 00:00:00 2001
From: BoxiangW <45734921+BoxiangW@users.noreply.github.com>
Date: Fri, 10 Jan 2025 11:39:36 -0800
Subject: [PATCH] Add rope scaling configs for NeMo 1 (#11807)

* Add rope scaling configs

Signed-off-by: Boxiang Wang

* Apply isort and black reformatting

Signed-off-by: BoxiangW

* Fix bug

Signed-off-by: Boxiang Wang

---------

Signed-off-by: Boxiang Wang
Signed-off-by: BoxiangW
Co-authored-by: BoxiangW
---
 nemo/collections/common/parts/utils.py                     | 7 +------
 .../nlp/models/language_modeling/megatron_gpt_model.py     | 8 +++++++-
 2 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/nemo/collections/common/parts/utils.py b/nemo/collections/common/parts/utils.py
index e08f7d710183..c00de27c55bd 100644
--- a/nemo/collections/common/parts/utils.py
+++ b/nemo/collections/common/parts/utils.py
@@ -112,14 +112,9 @@ def extend_instance(obj, mixin):
     )  # mixin needs to go first for our forward() logic to work


-def apply_rope_scaling(freqs):
+def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
     # Apply scaling for RoPE frequencies
     logger.info("apply rope scaling ...")
-    # Values obtained from grid search
-    scale_factor = 8
-    low_freq_factor = 1
-    high_freq_factor = 4
-    old_context_len = 8192  # original llama3 length

     low_freq_wavelen = old_context_len / low_freq_factor
     high_freq_wavelen = old_context_len / high_freq_factor
diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 165a6d650843..02ef522dde1f 100644
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -199,7 +199,13 @@ def mcore_model_customize(cfg, model):
     if cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
         extend_instance(model.embedding, EmbeddingScalingMixin)
     if cfg.get("scale_positional_embedding", False):
-        model.rotary_pos_emb.inv_freq = apply_rope_scaling(model.rotary_pos_emb.inv_freq)
+        model.rotary_pos_emb.inv_freq = apply_rope_scaling(
+            model.rotary_pos_emb.inv_freq,
+            scale_factor=cfg.get('scale_factor', 8),
+            low_freq_factor=cfg.get('low_freq_factor', 1),
+            high_freq_factor=cfg.get('high_freq_factor', 4),
+            old_context_len=cfg.get('old_context_len', 8192),
+        )
     if cfg.get("mcore_customization_config", {}).get("final_logit_softcapping", 0):
         from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_modules import Gemma2OutputLayer
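
For context, apply_rope_scaling implements Llama 3.1-style RoPE frequency scaling; the diff above only exposes its new signature and the two wavelength thresholds. Below is a minimal, self-contained sketch (not the NeMo source) of how the four new parameters typically interact, with the mid-band interpolation assumed from the Llama 3.1 reference recipe rather than taken from this patch. Because the defaults in both hunks match the previously hard-coded values, configs that do not set scale_factor, low_freq_factor, high_freq_factor, or old_context_len keep the old behavior.

# Standalone sketch of Llama 3.1-style RoPE frequency scaling (illustrative only;
# the body of NeMo's apply_rope_scaling may differ from this).
import math

import torch


def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
    # Wavelengths longer than this are treated as "low frequency" and fully scaled.
    low_freq_wavelen = old_context_len / low_freq_factor
    # Wavelengths shorter than this are treated as "high frequency" and left untouched.
    high_freq_wavelen = old_context_len / high_freq_factor

    new_freqs = []
    for freq in freqs.tolist():
        wavelen = 2 * math.pi / freq
        if wavelen < high_freq_wavelen:
            # High-frequency band: keep the original frequency.
            new_freqs.append(freq)
        elif wavelen > low_freq_wavelen:
            # Low-frequency band: shrink by scale_factor to cover the longer context.
            new_freqs.append(freq / scale_factor)
        else:
            # Mid band: interpolate smoothly between scaled and unscaled frequency.
            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
            new_freqs.append((1 - smooth) * freq / scale_factor + smooth * freq)
    return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


# Example usage with an inv_freq tensor shaped like a rotary embedding's (head dim 128, base 500000).
inv_freq = 1.0 / (500000 ** (torch.arange(0, 128, 2).float() / 128))
scaled_inv_freq = apply_rope_scaling(inv_freq, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192)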