Skip to content

Commit

Permalink
Add rope scaling configs for NeMo 1 (#11807)
Browse files Browse the repository at this point in the history
* Add rope scaling configs

Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: BoxiangW <BoxiangW@users.noreply.github.com>

* Fix bug

Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>

---------

Signed-off-by: Boxiang Wang <boxiangw@nvidia.com>
Signed-off-by: BoxiangW <BoxiangW@users.noreply.github.com>
Co-authored-by: BoxiangW <BoxiangW@users.noreply.github.com>
Signed-off-by: Abhinav Garg <abhgarg@nvidia.com>
  • Loading branch information
2 people authored and abhinavg4 committed Jan 30, 2025
1 parent 23a01c0 commit 3aa244d
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 7 deletions.
7 changes: 1 addition & 6 deletions nemo/collections/common/parts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,14 +112,9 @@ def extend_instance(obj, mixin):
) # mixin needs to go first for our forward() logic to work


def apply_rope_scaling(freqs):
def apply_rope_scaling(freqs, scale_factor=8, low_freq_factor=1, high_freq_factor=4, old_context_len=8192):
# Apply scaling for RoPE frequencies
logger.info("apply rope scaling ...")
# Values obtained from grid search
scale_factor = 8
low_freq_factor = 1
high_freq_factor = 4
old_context_len = 8192 # original llama3 length

low_freq_wavelen = old_context_len / low_freq_factor
high_freq_wavelen = old_context_len / high_freq_factor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,13 @@ def mcore_model_customize(cfg, model):
if cfg.get("apply_embedding_scaling", False) and parallel_state.is_pipeline_first_stage():
extend_instance(model.embedding, EmbeddingScalingMixin)
if cfg.get("scale_positional_embedding", False):
model.rotary_pos_emb.inv_freq = apply_rope_scaling(model.rotary_pos_emb.inv_freq)
model.rotary_pos_emb.inv_freq = apply_rope_scaling(
model.rotary_pos_emb.inv_freq,
scale_factor=cfg.get('scale_factor', 8),
low_freq_factor=cfg.get('low_freq_factor', 1),
high_freq_factor=cfg.get('high_freq_factor', 4),
old_context_len=cfg.get('old_context_len', 8192),
)
if cfg.get("mcore_customization_config", {}).get("final_logit_softcapping", 0):
from nemo.collections.nlp.models.language_modeling.megatron.gemma2.gemma2_modules import Gemma2OutputLayer

Expand Down

0 comments on commit 3aa244d

Please sign in to comment.