Merge branch 'xiny/fix_peak_mem' into 'main'
Fix peak memory consumption for NeMo

See merge request ADLR/megatron-lm!2388
jaredcasper committed Dec 8, 2024
2 parents 9665f2d + 7da20af commit 44fd429
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion megatron/core/extensions/transformer_engine.py
@@ -685,6 +685,11 @@ def forward(
         packed_seq_kwargs = (
             dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {}
         )
+        # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set
+        # after init
+        if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False):
+            self.qkv_format = 'bshd'
+
         qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format)

         if get_te_version() < PkgVersion("1.3.0"):
@@ -701,6 +706,19 @@ def forward(
             packed_seq_kwargs.pop("cu_seqlens_q_padded", None)
             packed_seq_kwargs.pop("cu_seqlens_kv_padded", None)

+        # WAR for peak memory usage.
+        # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388
+        if self.config.apply_rope_fusion and qkv_format == 'bshd':
+            query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)]
+            # In PyTorch, the following two tensors are in fact the same:
+            #   Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1)
+            #   Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1)
+            # The stride of a dimension of size 1 has no meaning, so tensors created in two
+            # different ways can have the same shape but different strides.
+            # We unify them to the first one to pass the stride check in TE.
+            if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride():
+                value = value.as_strided(value.shape, key.stride())
+
         attention_bias_kwargs = {}
         if attention_bias is not None:
             assert is_te_min_version("1.2.0"), (
@@ -734,7 +752,10 @@ def forward(
                 query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs
             )

-        return core_attn_out
+        if self.config.apply_rope_fusion and qkv_format == 'bshd':
+            return core_attn_out.transpose(0, 1)
+        else:
+            return core_attn_out


 if is_te_min_version("1.9.0.dev0"):
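
Note (not part of the commit): the WAR converts query/key/value from a sequence-first 'sbhd' layout to a batch-first 'bshd' layout before the fused attention call, and transposes the attention output back on return. A minimal standalone sketch of that round trip; the tensor sizes and the output shape are illustrative assumptions, not values taken from the change:

import torch

# Illustrative sizes: sequence length S, batch B, attention heads H, head dim D.
S, B, H, D = 128, 1, 8, 64

# Tensors start in 'sbhd' layout (sequence-first).
query = torch.randn(S, B, H, D)
key = torch.randn(S, B, H, D)
value = torch.randn(S, B, H, D)

# The WAR: make each tensor contiguous in 'sbhd', then swap the sequence and
# batch dimensions to obtain 'bshd' (batch-first) views.
query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)]
assert query.shape == (B, S, H, D)

# ... fused attention runs here; assume it returns a 'bshd' output ...
core_attn_out = torch.randn(B, S, H * D)

# Transpose back to 'sbhd' before returning, mirroring the final if/else above.
core_attn_out = core_attn_out.transpose(0, 1)
assert core_attn_out.shape == (S, B, H * D)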

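The inline comment about strides refers to a PyTorch subtlety: for a dimension of size 1, the stride value never affects which element is addressed, so two tensors can describe the same memory while reporting different stride tuples. A small sketch of that behaviour and of the as_strided unification used above (sizes are arbitrary):

import torch

S, H, D = 128, 8, 64

# Two routes to a (1, S, H, D) tensor: allocate it directly, or build it as
# (S, 1, H, D), make it contiguous, and transpose the first two dimensions
# (which is what the WAR's transpose does when the batch size is 1).
key = torch.randn(1, S, H, D)
value = torch.randn(S, 1, H, D).contiguous().transpose(0, 1)

assert key.shape == value.shape          # both (1, S, H, D)
print(key.stride())                      # (S*H*D, H*D, D, 1)
print(value.stride())                    # (H*D, H*D, D, 1)

# Same memory layout, different stride metadata on the size-1 batch dimension.
# Rewrite value's strides to match key's so a strict stride comparison passes.
value = value.as_strided(value.shape, key.stride())
assert value.stride() == key.stride()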