Merge branch 'MLA_TP' into 'main'
Add support for TP for MLA

See merge request ADLR/megatron-lm!2328
ko3n1g committed Feb 2, 2025
2 parents 04f9344 + 4d4676e commit ea94163
Showing 5 changed files with 143 additions and 37 deletions.
67 changes: 48 additions & 19 deletions megatron/core/extensions/transformer_engine.py
@@ -35,6 +35,7 @@
_initialize_affine_weight_cpu,
set_tensor_model_parallel_attributes,
)
from megatron.core.tensor_parallel.random import get_data_parallel_rng_tracker_name
from megatron.core.tensor_parallel.utils import divide
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.transformer_config import TransformerConfig
@@ -98,6 +99,13 @@ class TELinear(te.pytorch.Linear):
Note that if Megatron's parallel_state has not been initialized
yet, the tp_group passed to TE will be None and must be set later
via set_tensor_parallel_group().
parallel_mode currently supports 3 different values:
- "column": Split the weight matrix along output dimension (used in TEColumnParallelLinear)
- "row": Split the weight matrix along input dimension (used in TERowParallelLinear)
- "duplicated": No tensor parallelism and weight is duplicated across TP ranks
- Note: For expert linear layers, we will disable communication logic here
as TP communication is handled in token_dispatcher.
"""

def __init__(
@@ -170,27 +178,39 @@ def __init__(
if is_expert:
rng_tracker_name = get_expert_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if parallel_mode == "duplicated":
rng_tracker_name = get_data_parallel_rng_tracker_name()
else:
rng_tracker_name = None
if is_te_min_version("1.7.0"):
extra_kwargs["rng_tracker_name"] = rng_tracker_name

# Disable communications in TE when using TP or EP by making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
parallel_mode = None
tp_size = 1
te_parallel_mode = parallel_mode
if parallel_mode == "duplicated":
# Handle non-parallel case
tp_group = None
tp_size = 1
explicit_expert_comm = False
te_parallel_mode = None
else:
# Disable communications in TE when using TP or EP by
# making TE agnostic of model parallel.
if is_expert:
tp_group = get_expert_tensor_parallel_group(check_initialized=False)
tp_size = get_expert_tensor_parallel_world_size()
else:
tp_group = get_tensor_model_parallel_group(check_initialized=False)
tp_size = get_tensor_model_parallel_world_size()
explicit_expert_comm = is_expert and (tp_size > 1 or self.expert_parallel)

if explicit_expert_comm:
if parallel_mode == "column":
output_size = divide(output_size, tp_size)
elif parallel_mode == "row":
input_size = divide(input_size, tp_size)
te_parallel_mode = None
tp_size = 1
tp_group = None

super().__init__(
in_features=input_size,
@@ -205,12 +225,21 @@ def __init__(
init_method=condition_init_method(config, init_method),
bias=bias,
return_bias=self.te_return_bias,
parallel_mode=parallel_mode,
parallel_mode=te_parallel_mode,
**extra_kwargs,
)

for param in self.parameters():
setattr(param, 'allreduce', not (is_expert and self.expert_parallel))
if is_expert:
# Reduce the gradient on the expert_data_parallel group for expert linear layers
setattr(param, 'allreduce', not self.expert_parallel)
else:
# Reduce the gradient on DP group
setattr(param, 'allreduce', True)
if parallel_mode == "duplicated":
# Reduce the gradient further on the TP group since the weight is
# duplicated across TP ranks
setattr(param, 'sequence_parallel', self.config.sequence_parallel)

def forward(self, x):
"""Forward."""
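The hunk above adds a "duplicated" parallel_mode to TELinear: the weight is kept in full on every TP rank, initialized from the data-parallel RNG tracker, and its gradient is additionally reduced over the TP group via the 'sequence_parallel' attribute. Below is a minimal construction sketch, not from this commit: it assumes Transformer Engine is installed and Megatron's model-parallel state and RNG trackers are already initialized, and the config values and output size are illustrative placeholders.

```python
# Sketch only: mirrors how multi_latent_attention.py builds the q/kv down-projections
# with the new "duplicated" mode. Sizes below are illustrative placeholders.
from megatron.core.extensions.transformer_engine import TELinear
from megatron.core.transformer.transformer_config import TransformerConfig

config = TransformerConfig(
    num_layers=2,
    hidden_size=2048,
    num_attention_heads=16,
    tensor_model_parallel_size=2,
    sequence_parallel=True,
)

# Weight is replicated on every TP rank; TE itself sees no tensor parallelism
# (tp_group=None, tp_size=1), and the gradient is reduced over DP and,
# through the 'sequence_parallel' attribute, over the TP group as well.
q_down_proj = TELinear(
    input_size=config.hidden_size,
    output_size=512,  # e.g. q_lora_rank
    parallel_mode="duplicated",
    config=config,
    init_method=config.init_method,
    bias=False,
    skip_bias_add=False,
    is_expert=False,
    skip_weight_param_allocation=False,
)
```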
21 changes: 15 additions & 6 deletions megatron/core/models/gpt/gpt_layer_specs.py
@@ -33,6 +33,7 @@
TEColumnParallelLinear,
TEDotProductAttention,
TELayerNormColumnParallelLinear,
TELinear,
TENorm,
TERowParallelLinear,
)
@@ -100,14 +101,22 @@ def get_gpt_layer_with_transformer_engine_spec(
params={"attn_mask_type": AttnMaskType.causal},
submodules=MLASelfAttentionSubmodules(
linear_q_proj=TEColumnParallelLinear,
linear_q_down_proj=TEColumnParallelLinear,
linear_q_up_proj=TEColumnParallelLinear,
linear_kv_down_proj=TEColumnParallelLinear,
linear_kv_up_proj=TEColumnParallelLinear,
linear_q_down_proj=TELinear,
linear_q_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
linear_kv_down_proj=TELinear,
linear_kv_up_proj=(
TELayerNormColumnParallelLinear
if qk_layernorm
else TEColumnParallelLinear
),
core_attention=TEDotProductAttention,
linear_proj=TERowParallelLinear,
q_layernorm=TENorm if qk_layernorm else IdentityOp,
kv_layernorm=TENorm if qk_layernorm else IdentityOp,
q_layernorm=IdentityOp,
kv_layernorm=IdentityOp,
),
),
self_attn_bda=get_bias_dropout_add,
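With this spec change the MLA q/kv down-projections go through plain TELinear (so their low-rank weights stay duplicated across TP ranks), and the up-projections absorb the q/kv layernorms into TELayerNormColumnParallelLinear when qk_layernorm is enabled, which is why q_layernorm and kv_layernorm fall back to IdentityOp. A short sketch of inspecting the resulting spec follows; it assumes Transformer Engine is available, and the qk_layernorm keyword is taken from the conditional shown in the diff rather than verified here.

```python
# Sketch: inspect which submodules the updated MLA spec selects.
from megatron.core.models.gpt.gpt_layer_specs import get_gpt_layer_with_transformer_engine_spec

layer_spec = get_gpt_layer_with_transformer_engine_spec(
    multi_latent_attention=True,
    qk_layernorm=True,
)
attn = layer_spec.submodules.self_attention.submodules
print(attn.linear_q_down_proj)  # TELinear: weights duplicated across TP ranks
print(attn.linear_q_up_proj)    # TELayerNormColumnParallelLinear (layernorm fused in)
print(attn.q_layernorm)         # IdentityOp: the layernorm now lives in the up-projection
```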
25 changes: 14 additions & 11 deletions megatron/core/transformer/multi_latent_attention.py
@@ -13,6 +13,7 @@
_yarn_get_mscale,
apply_rotary_pos_emb,
)
from megatron.core.tensor_parallel.mappings import gather_from_sequence_parallel_region
from megatron.core.transformer.attention import Attention
from megatron.core.transformer.enums import AttnMaskType
from megatron.core.transformer.spec_utils import ModuleSpec, build_module
@@ -50,11 +51,6 @@ def __init__(
attention_type: str,
cp_comm_type: str = None,
) -> None:
world_size = parallel_state.get_tensor_model_parallel_world_size()
assert (
world_size == 1
), "MLA is not supported with Tensor Parallelism yet, \
use Expert Parallelism and Pipeline Parallelism for better performance."

super().__init__(
config=config,
@@ -228,12 +224,12 @@ def __init__(
submodules.linear_q_down_proj,
self.config.hidden_size,
self.config.q_lora_rank,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
skip_weight_param_allocation=False,
)

self.linear_q_up_proj = build_module(
@@ -252,12 +248,12 @@ def __init__(
submodules.linear_kv_down_proj,
self.config.hidden_size,
self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim,
parallel_mode="duplicated",
config=self.config,
init_method=self.config.init_method,
gather_output=False,
bias=False,
skip_bias_add=False,
is_expert=False,
skip_weight_param_allocation=False,
)

self.linear_kv_up_proj = build_module(
@@ -303,7 +299,6 @@ def get_query_key_value_tensors(
assert (
hidden_states.ndim == 3
), f"hidden_states should be 3D, [s, b, n*h], got {hidden_states.ndim}D"
q_len, bsz, _ = hidden_states.size()

if self.config.q_lora_rank is not None:
q_compressed, _ = self.linear_q_down_proj(hidden_states)
@@ -313,6 +308,8 @@
# hidden_states:[s, b, 2048], q: [s, b, n * 192]
q, _ = self.linear_q_proj(hidden_states)

q_len, bsz, _ = q.size()

# q: [s, b, n, 192]
q = q.view(q_len, bsz, self.num_attention_heads_per_partition, self.q_head_dim)

@@ -329,6 +326,10 @@
kv_combined, [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], dim=-1
)

# Gather the input from sequence parallel region
if parallel_state.get_tensor_model_parallel_world_size() > 1:
k_pos_emb = gather_from_sequence_parallel_region(k_pos_emb)

# kv: [s, b, 2048]
kv, _ = self.linear_kv_up_proj(self.kv_layernorm(kv_compressed))

@@ -349,6 +350,8 @@
if len(rotary_pos_emb) == 2:
mscale = rotary_pos_emb[1]
rotary_pos_emb = rotary_pos_emb[0]
else:
mscale = 1.0

if inference_params is not None:
# add offset to the sequence start for inference
@@ -377,7 +380,7 @@ def get_query_key_value_tensors(
query = torch.cat([q_no_pe, q_pos_emb], dim=-1)

# key: [s, b, n, 192]
k_pos_emb = k_pos_emb.expand(-1, -1, self.config.num_attention_heads, -1)
k_pos_emb = k_pos_emb.expand(-1, -1, self.num_attention_heads_per_partition, -1)
key = torch.cat([k_no_pe, k_pos_emb], dim=-1)

query = query.contiguous()
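Under sequence parallelism the duplicated down-projections produce activations that are still sharded along the sequence dimension, while the column-parallel up-projection for q gathers the sequence internally. That is why q_len and bsz are now read from q instead of hidden_states, why k_pos_emb needs an explicit gather_from_sequence_parallel_region, and why it is expanded to num_attention_heads_per_partition rather than all heads. The shape-only sketch below uses made-up sizes and performs no distributed communication.

```python
# Shape-only illustration with made-up sizes; no TP/SP communication happens here.
import torch

tp = 2                    # hypothetical tensor-parallel size
s, b = 8, 2               # full sequence length, micro-batch size
heads_per_rank = 2        # num_attention_heads_per_partition
q_head_dim = 128 + 64     # qk_head_dim + qk_pos_emb_head_dim

# With sequence parallelism, the down-projection input is sharded over s.
hidden_states = torch.randn(s // tp, b, 2048)        # [s/TP, b, hidden]

# The column-parallel up-projection gathers the sequence, so q carries the
# full sequence length but only this rank's share of the heads.
q = torch.randn(s, b, heads_per_rank * q_head_dim)   # [s, b, n/TP * 192]

# Reading q_len from q (not from hidden_states) keeps this reshape correct under SP.
q_len, bsz, _ = q.size()
q = q.view(q_len, bsz, heads_per_rank, q_head_dim)

# k_pos_emb comes out of the duplicated kv down-projection still sharded over s,
# so the forward pass all-gathers it before applying RoPE; afterwards it is
# broadcast over this rank's heads only.
k_pos_emb = torch.randn(s, b, 1, 64)                 # shape after the gather
k_pos_emb = k_pos_emb.expand(-1, -1, heads_per_rank, -1)
assert k_pos_emb.shape == (s, b, heads_per_rank, 64)
```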
4 changes: 4 additions & 0 deletions megatron/training/arguments.py
@@ -616,6 +616,10 @@ def validate_args(args, defaults={}):

if args.tp_comm_overlap:
assert args.sequence_parallel == True, 'Tensor parallel communication/GEMM overlap can happen only when sequence parallelism is enabled'

if args.multi_latent_attention:
if args.tensor_model_parallel_size > 1:
assert args.sequence_parallel == True, 'Sequence parallelism should be enabled when MLA is used with tensor parallel'

# disable async_tensor_model_parallel_allreduce when
# model parallel memory optimization is enabled
63 changes: 62 additions & 1 deletion tests/unit_tests/transformer/test_multi_latent_attention.py
@@ -59,7 +59,6 @@ def test_cpu_forward(self):

def test_gpu_forward(self):
if is_te_min_version("1.10.0"):

# use flash attention for hopper, future may support fused attention for ampere
os.environ['NVTE_FUSED_ATTN'] = "0"
os.environ['NVTE_FLASH_ATTN'] = "1"
@@ -128,3 +127,65 @@ def test_checkpointed_gpu_forward(self):
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size


class TestTensorParallelMLAAttention:

def setup_method(self, method):
self.tensor_parallel_size = 2
Utils.initialize_model_parallel(self.tensor_parallel_size, 1)
model_parallel_cuda_manual_seed(123)
self.transformer_config = MLATransformerConfig(
num_layers=2,
hidden_size=12,
num_attention_heads=4,
q_lora_rank=32,
kv_lora_rank=32,
qk_head_dim=128,
v_head_dim=128,
qk_pos_emb_head_dim=64,
rotary_base=10000,
max_position_embeddings=64,
tensor_model_parallel_size=self.tensor_parallel_size,
sequence_parallel=True,
)
self.parallel_attention = MLASelfAttention(
self.transformer_config,
get_gpt_layer_with_transformer_engine_spec(
multi_latent_attention=True
).submodules.self_attention.submodules,
layer_number=1,
attn_mask_type=AttnMaskType.causal,
)

def teardown_method(self, method):
Utils.destroy_model_parallel()

def test_gpu_forward(self):
if is_te_min_version("1.10.0"):
# use flash attention for hopper, future may support fused attention for ampere
os.environ['NVTE_FUSED_ATTN'] = "0"
os.environ['NVTE_FLASH_ATTN'] = "1"

config = self.parallel_attention.config
sequence_length = 64
sub_sequence_length = sequence_length // self.tensor_parallel_size
micro_batch_size = 2

self.parallel_attention.cuda()

# [sequence length, batch size, hidden size]
hidden_states = torch.ones(
(sub_sequence_length, micro_batch_size, self.parallel_attention.config.hidden_size)
)
hidden_states = hidden_states.cuda()

attention_mask = torch.ones((1, 1, sequence_length, sequence_length), dtype=bool).cuda()

output, bias = self.parallel_attention(hidden_states, attention_mask)

assert config.recompute_granularity is None
assert output.shape[0] == sub_sequence_length
assert output.shape[1] == micro_batch_size
assert output.shape[2] == config.hidden_size
assert bias.shape[0] == config.hidden_size
