diff --git a/megatron/core/extensions/transformer_engine.py b/megatron/core/extensions/transformer_engine.py index f64862c3cb..62336cdb03 100644 --- a/megatron/core/extensions/transformer_engine.py +++ b/megatron/core/extensions/transformer_engine.py @@ -685,6 +685,11 @@ def forward( packed_seq_kwargs = ( dataclasses.asdict(packed_seq_params) if packed_seq_params is not None else {} ) + # overwrite self.qkv_format depending on self.config.apply_rope_fusion, which can be set + # after init + if self.config.apply_rope_fusion and is_te_min_version("0.13.0", check_equality=False): + self.qkv_format = 'bshd' + qkv_format = packed_seq_kwargs.get('qkv_format', self.qkv_format) if get_te_version() < PkgVersion("1.3.0"): @@ -701,6 +706,19 @@ def forward( packed_seq_kwargs.pop("cu_seqlens_q_padded", None) packed_seq_kwargs.pop("cu_seqlens_kv_padded", None) + # WAR for peak memory usage. + # See https://gitlab-master.nvidia.com/ADLR/megatron-lm/-/merge_requests/2388 + if self.config.apply_rope_fusion and qkv_format == 'bshd': + query, key, value = [x.contiguous().transpose(0, 1) for x in (query, key, value)] + # In PyTorch, the following two tensors are in fact the same: + # Tensor with shape (1, S, H, D) and stride (S*H*D, H*D, D, 1) + # Tensor with shape (1, S, H, D) and stride (H*D, H*D, D, 1) + # Stride for a dimension that is 1 has no meaning, so tensors created two different ways + # can have same shape but different strides. + # We unify them to the first one to pass the stride check in TE + if value.shape == key.shape and value.shape[0] == 1 and value.stride() != key.stride(): + value = value.as_strided(value.shape, key.stride()) + attention_bias_kwargs = {} if attention_bias is not None: assert is_te_min_version("1.2.0"), ( @@ -734,7 +752,10 @@ def forward( query, key, value, attention_mask, **attention_bias_kwargs, **packed_seq_kwargs ) - return core_attn_out + if self.config.apply_rope_fusion and qkv_format == 'bshd': + return core_attn_out.transpose(0, 1) + else: + return core_attn_out if is_te_min_version("1.9.0.dev0"):