disable overlap for qkv #9079

Merged
@@ -80,7 +80,7 @@ def mcore_register_adapters(self):
         if (
             self.config.sequence_parallel
             and hasattr(self.linear_qkv, "return_layernorm_output_gathered")
-            and not self.config.tp_comm_overlap
+            and not self.linear_qkv.ub_overlap_ag
         ):
             # for LoRA SP, TE v1.5 can return layernorm output gathered so there is no need
             # to perform the redundant gather in the adapter module, unless TP communication
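The effect of the hunk above: reuse of the gathered layernorm output is now gated on the per-module `ub_overlap_ag` attribute of `linear_qkv` rather than the global `tp_comm_overlap` config flag. A minimal sketch of that guard, assuming `linear_qkv` is a TE `LayerNormLinear`-style module that may or may not expose `ub_overlap_ag` (the `getattr` fallback is added for illustration and is not part of the PR):

```python
def can_reuse_gathered_layernorm_output(config, linear_qkv):
    # Reuse the gathered layernorm output only when sequence parallelism is on,
    # the module can return it, and this module's own all-gather overlap is off.
    return (
        config.sequence_parallel
        and hasattr(linear_qkv, "return_layernorm_output_gathered")
        and not getattr(linear_qkv, "ub_overlap_ag", False)
    )
```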
@@ -142,11 +142,19 @@ def get_query_key_value_tensors(self, hidden_states, key_value_states=None):
         if SplitAlongDim is not None:

             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = SplitAlongDim(mixed_qkv, 3, split_arg_list,)
+            (query, key, value) = SplitAlongDim(
+                mixed_qkv,
+                3,
+                split_arg_list,
+            )
         else:

             # [sq, b, ng, (np/ng + 2) * hn] --> [sq, b, ng, np/ng * hn], [sq, b, ng, hn], [sq, b, ng, hn]
-            (query, key, value) = torch.split(mixed_qkv, split_arg_list, dim=3,)
+            (query, key, value) = torch.split(
+                mixed_qkv,
+                split_arg_list,
+                dim=3,
+            )

         # [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
         query = query.reshape(query.size(0), query.size(1), -1, self.hidden_size_per_attention_head)
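For reference, a self-contained sketch of the fused-QKV split performed in this hunk, using only `torch.split` (the `SplitAlongDim` branch is a Transformer Engine fast path that produces the same result). The tensor sizes below are illustrative assumptions, not values from the PR:

```python
import torch

sq, b = 4, 2    # sequence length, micro-batch size (assumed)
ng, hn = 8, 64  # query groups, head dim (assumed)
np_ = 32        # attention heads (assumed)

# Fused projection output: [sq, b, ng, (np/ng + 2) * hn]
mixed_qkv = torch.randn(sq, b, ng, (np_ // ng + 2) * hn)
split_arg_list = [(np_ // ng) * hn, hn, hn]

# Split along the last dim into query / key / value chunks
query, key, value = torch.split(mixed_qkv, split_arg_list, dim=3)

# [sq, b, ng, np/ng * hn] -> [sq, b, np, hn]
query = query.reshape(sq, b, -1, hn)
print(query.shape, key.shape, value.shape)
# torch.Size([4, 2, 32, 64]) torch.Size([4, 2, 8, 64]) torch.Size([4, 2, 8, 64])
```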
@@ -231,11 +239,21 @@ def forward(

         if self.checkpoint_core_attention:
             core_attn_out = self._checkpointed_attention_forward(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )
         else:
             core_attn_out = self.core_attention(
-                query, key, value, attention_mask, attn_mask_type=attn_mask_type, packed_seq_params=packed_seq_params,
+                query,
+                key,
+                value,
+                attention_mask,
+                attn_mask_type=attn_mask_type,
+                packed_seq_params=packed_seq_params,
             )

         if packed_seq_params is not None:
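The two branches above differ only in whether core attention runs under activation checkpointing. A minimal sketch of that recompute pattern with `torch.utils.checkpoint`, using a generic `core_attention_fn` stand-in rather than NeMo's `_checkpointed_attention_forward`:

```python
import torch
from torch.utils.checkpoint import checkpoint

def core_attention_fn(query, key, value):
    # Plain scaled-dot-product attention as a stand-in for the core attention module.
    scores = torch.matmul(query, key.transpose(-2, -1)) / query.size(-1) ** 0.5
    return torch.matmul(torch.softmax(scores, dim=-1), value)

q = torch.randn(2, 4, 8, 16, requires_grad=True)
k = torch.randn(2, 4, 8, 16, requires_grad=True)
v = torch.randn(2, 4, 8, 16, requires_grad=True)

# Forward does not store intermediate activations; they are recomputed in backward.
out = checkpoint(core_attention_fn, q, k, v, use_reentrant=False)
out.sum().backward()
```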
@@ -316,7 +334,9 @@ def forward(self, hidden_states):
             intermediate_parallel = bias_gelu_impl(intermediate_parallel, bias_parallel)
         elif self.activation_func == F.silu and self.config.gated_linear_unit:
             intermediate_parallel = bias_swiglu_impl(
-                intermediate_parallel, bias_parallel, self.config.activation_func_fp8_input_store,
+                intermediate_parallel,
+                bias_parallel,
+                self.config.activation_func_fp8_input_store,
             )

         else:
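As a reference for what the fused call above computes, here is a plain-PyTorch sketch of the bias-plus-SwiGLU activation: add the bias, split the gated projection in half, apply SiLU to one half, and gate the other. This is an illustrative equivalent, not Megatron's fused/FP8 kernel, and the tensor sizes are assumptions:

```python
import torch
import torch.nn.functional as F

def bias_swiglu_reference(intermediate_parallel, bias_parallel):
    # Add bias, then gated SiLU: silu(y1) * y2 over the two halves of the last dim.
    y = intermediate_parallel + bias_parallel
    y1, y2 = torch.chunk(y, 2, dim=-1)
    return F.silu(y1) * y2

x = torch.randn(4, 2, 256)   # [sq, b, 2 * ffn_hidden] (assumed sizes)
bias = torch.randn(256)
print(bias_swiglu_reference(x, bias).shape)  # torch.Size([4, 2, 128])
```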