diff --git a/colossalai/shardformer/layer/attn.py b/colossalai/shardformer/layer/attn.py
index f255aa3ac7d0..ff90076275a7 100644
--- a/colossalai/shardformer/layer/attn.py
+++ b/colossalai/shardformer/layer/attn.py
@@ -659,6 +659,11 @@ def forward(
         if sp_rank != sp_size - 1:
             q1 = q[half_idx_back]
 
+        # Non-contiguous indexing creates a new contiguous tensor,
+        # so only do it once
+        if sp_rank != sp_size - 1:
+            q1 = q[half_idx_back]
+
         # Pre-allocate double buffer for overlapping and receiving next step's inputs
         kv_buffers = [torch.stack((k, v))]  # (2, B, Sq, H, D)
         kv_buffers.append(torch.empty_like(kv_buffers[0]))
@@ -914,6 +919,10 @@ def backward(ctx, dout, _):
             softmax_lse1 = softmax_lse[:, half_idx_back]
         dout = dout.contiguous()
 
+        if sp_rank != sp_size - 1:
+            softmax_lse1 = softmax_lse[:, half_idx_back]
+        dout = dout.contiguous()
+
         # Double comm buffers for sending and receiving kv
         kv_buffers = [torch.stack((k, v))]  # (2, T, H, D)
         kv_buffers.append(torch.empty_like(kv_buffers[0]))
diff --git a/colossalai/shardformer/layer/loss.py b/colossalai/shardformer/layer/loss.py
index 12df824d1c0c..0e2241af9fc9 100644
--- a/colossalai/shardformer/layer/loss.py
+++ b/colossalai/shardformer/layer/loss.py
@@ -153,7 +153,6 @@ def dist_cross_entropy(
     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
     logits: torch.Tensor,  # [B, S, Vocab_size]
     shard_config: ShardConfig,
-    out_features: int,
     vocab_size: int,
     dtype: torch.dtype,
     seq_dim: int = 1,
@@ -226,13 +225,13 @@ def dist_cross_entropy(
             logits,
             labels,
             process_group=shard_config.tensor_parallel_process_group,
-            vocab_size=out_features,
+            vocab_size=vocab_size,
             dtype=dtype,
             mode="sum",
         )
     else:
         # NOTE if use TP and not parallel_output, the output is gathered in VocabParallelLMHead1D
-        logits = logits.view(-1, vocab_size)
+        logits = logits.view(-1, logits.size(-1))
         loss = loss_fct(logits, labels)
 
     # Reduce loss instead of gathering logits over seq dim for savings
diff --git a/colossalai/shardformer/modeling/bloom.py b/colossalai/shardformer/modeling/bloom.py
index 26ffef6c5ee0..daa2296dd338 100644
--- a/colossalai/shardformer/modeling/bloom.py
+++ b/colossalai/shardformer/modeling/bloom.py
@@ -359,14 +359,15 @@ def bloom_for_causal_lm_forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states).contiguous()
 
-        loss = dist_cross_entropy(
-            labels,
-            lm_logits,
-            shard_config,
-            self.lm_head.out_features,
-            self.config.vocab_size,
-            self.transformer.dtype,
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels,
+                lm_logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.transformer.dtype,
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
@@ -1024,9 +1025,11 @@ def forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
 
-        loss = dist_cross_entropy(
-            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
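A minimal standalone sketch (not the RingAttention code; the shapes, index, and loop below are illustrative) of why the attn.py hunks hoist `q[half_idx_back]` out of the per-step ring loop: advanced indexing always materializes a new contiguous copy, so doing it once saves one copy per communication step.

```python
# Standalone sketch: advanced indexing copies, so hoist it out of hot loops.
import torch

q = torch.randn(8, 4, 16)               # stand-in for (T, H, D) query shards
half_idx_back = torch.arange(4, 8)      # indices of the latter half of the sequence

q1 = q[half_idx_back]                   # advanced indexing -> new contiguous tensor
assert q1.is_contiguous()
assert q1.data_ptr() != q.data_ptr()    # a copy, not a view into q

num_ring_steps = 4
for _ in range(num_ring_steps):
    pass  # reuse q1 here; re-indexing q each step would repeat the copy
```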
diff --git a/colossalai/shardformer/modeling/chatglm2.py b/colossalai/shardformer/modeling/chatglm2.py
index 5be4b9d78e11..2e12e78378ef 100644
--- a/colossalai/shardformer/modeling/chatglm2.py
+++ b/colossalai/shardformer/modeling/chatglm2.py
@@ -183,6 +183,15 @@ def chatglm_model_forward(
         if full_attention_mask is None:
             if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
                 full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)
+
+        # Support SP + PP
+        sp_size = shard_config.sequence_parallel_size
+        sp_mode = shard_config.sequence_parallelism_mode
+        sp_group = shard_config.sequence_parallel_process_group
+        # For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
+        if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= sp_size
+
         # Rotary positional embeddings
         rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
         if position_ids is not None:
@@ -206,11 +215,11 @@ def chatglm_model_forward(
         # Keep the input split across all PP stages
         if stage_manager.is_first_stage():
             if shard_config.enable_sequence_parallelism:
-                if shard_config.sequence_parallelism_mode == "split_gather":
+                if sp_mode == "split_gather":
                     hidden_states = split_forward_gather_backward(
                         hidden_states,
                         dim=0,
-                        process_group=shard_config.tensor_parallel_process_group,
+                        process_group=sp_group,
                     )
                 elif shard_config.sequence_parallelism_mode == "all_to_all":
                     hidden_states = split_forward_gather_backward(
@@ -255,7 +264,9 @@ def chatglm_model_forward(
         # Gather seq-wise in the final output stage
         sp_mode = shard_config.sequence_parallelism_mode
         if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-            hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode)
+            hidden_states = gather_sp_output(
+                hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0
+            )
 
         if not return_dict:
             return tuple(
@@ -321,9 +332,21 @@ def chatglm_for_conditional_generation_forward(
             hidden_states = hidden_states[-1:]
         lm_logits = self.transformer.output_layer(hidden_states)
         lm_logits = lm_logits.transpose(0, 1).contiguous()
-        loss = dist_cross_entropy(
-            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, lm_logits.dtype
-        )
+
+        loss = None
+        if labels is not None:
+            # ChatGLM doesn't have lm_head split
+            enable_tp = shard_config.enable_tensor_parallelism
+            shard_config.enable_tensor_parallelism = False
+            loss = dist_cross_entropy(
+                labels,
+                lm_logits,
+                shard_config,
+                self.transformer.output_layer.out_features,
+                lm_logits.dtype,
+            )
+            shard_config.enable_tensor_parallelism = enable_tp
+
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
             return ((loss,) + output) if loss is not None else output
@@ -424,7 +447,9 @@ def forward(
         )
 
         if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
-            hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode)
+            hidden_states = gather_sp_output(
+                hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0
+            )
 
         if not return_dict:
             return tuple(
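ChatGLM keeps hidden states in [S, B, H] layout, which is why the split and gather above use `dim=0` / `sp_dim=0`, and why non-first pipeline stages rescale `seq_length` by `sp_size` before building rotary position ids. A minimal single-process sketch of that bookkeeping, assuming a hypothetical rank and group size (no real communication is performed):

```python
import torch

sp_size, sp_rank = 2, 1                 # assumed sequence-parallel group size / rank
S, B, H = 8, 2, 4                       # ChatGLM hidden states are [S, B, H]
hidden_states = torch.randn(S, B, H)

# Each rank keeps one contiguous chunk of the sequence dimension (dim=0).
local = hidden_states.chunk(sp_size, dim=0)[sp_rank]
assert local.shape == (S // sp_size, B, H)

# Rotary position ids must still cover the *full* sequence, hence
# seq_length *= sp_size on stages that only ever see the split tensor.
full_seq_length = local.shape[0] * sp_size
assert full_seq_length == S
```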
diff --git a/colossalai/shardformer/modeling/command.py b/colossalai/shardformer/modeling/command.py
index cac325dcbea6..bdcf6f0a2f69 100644
--- a/colossalai/shardformer/modeling/command.py
+++ b/colossalai/shardformer/modeling/command.py
@@ -93,10 +93,16 @@ def command_model_forward(
         if not isinstance(past_key_values, StaticCache):
             past_key_values = DynamicCache.from_legacy_cache(past_key_values)
             past_seen_tokens = past_key_values.get_seq_length()
+
+        # NOTE: For generating full positions ids
+        # (the states will be gathered along the seq dim before attention fwd).
+        if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
+            seq_length *= shard_config.sequence_parallel_size
+
         if cache_position is None:
             if isinstance(past_key_values, StaticCache):
                 raise ValueError("cache_position is a required argument when using StaticCache.")
-            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
+            cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
 
         seq_length_with_past = seq_length + past_seen_tokens
 
@@ -136,7 +142,7 @@ def command_model_forward(
             )
         use_cache = False
 
-    if shard_config.enable_sequence_parallelism:
+    if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
         if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
             hidden_states = split_forward_gather_backward(
                 hidden_states,
@@ -320,9 +326,10 @@ def command_for_causal_lm_forward(
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -659,14 +666,16 @@ def forward(
         logits = self.lm_head(hidden_states)
         logits = logits * self.logit_scale
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels,
-            logits,
-            shard_config,
-            self.lm_head.out_features,
-            self.config.vocab_size,
-            self.model.dtype,
-        )
+
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels,
+                logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.model.dtype,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/gpt2.py b/colossalai/shardformer/modeling/gpt2.py
index 6ecda91c4d35..db38c9a0ec33 100644
--- a/colossalai/shardformer/modeling/gpt2.py
+++ b/colossalai/shardformer/modeling/gpt2.py
@@ -372,9 +372,11 @@ def gpt2_lmhead_model_forward(
         hidden_states = outputs[0]
         lm_logits = self.lm_head(hidden_states)
 
-        loss = dist_cross_entropy(
-            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
+            )
 
         if not return_dict:
             output = (lm_logits,) + outputs[1:]
@@ -1264,9 +1266,11 @@ def forward(
         hidden_states = transformer_outputs[0]
         lm_logits = self.lm_head(hidden_states)
 
-        loss = dist_cross_entropy(
-            labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
+            )
 
         if not return_dict:
             output = (lm_logits,) + transformer_outputs[1:]
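The `cache_position` change in command.py above follows the same reasoning as the chatglm fix: under sequence parallelism the local `hidden_states` only hold `S / sp_size` tokens, so positions must be derived from the full `seq_length` rather than `hidden_states.shape[1]`. A small sketch with made-up shapes:

```python
import torch

sp_size = 2
past_seen_tokens = 0
seq_length = 8                                              # full (unsplit) sequence length
hidden_states = torch.randn(1, seq_length // sp_size, 16)   # local shard, [B, S/sp, H]

wrong = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1])
fixed = torch.arange(past_seen_tokens, past_seen_tokens + seq_length)

assert wrong.numel() == seq_length // sp_size   # too short under SP
assert fixed.numel() == seq_length              # matches the gathered sequence
```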
diff --git a/colossalai/shardformer/modeling/llama.py b/colossalai/shardformer/modeling/llama.py
index 1c75bc5c240d..36394628b749 100644
--- a/colossalai/shardformer/modeling/llama.py
+++ b/colossalai/shardformer/modeling/llama.py
@@ -346,9 +346,9 @@ def llama_for_causal_lm_forward(
     if stage_manager.is_last_stage():
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -865,9 +865,9 @@ def forward(
     else:
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
     if not return_dict:
         output = (logits,) + outputs[1:]
         return (loss,) + output if loss is not None else output
diff --git a/colossalai/shardformer/modeling/mistral.py b/colossalai/shardformer/modeling/mistral.py
index ec1a8a00a58a..7fc6a1062037 100644
--- a/colossalai/shardformer/modeling/mistral.py
+++ b/colossalai/shardformer/modeling/mistral.py
@@ -274,10 +274,9 @@ def mistral_for_causal_lm_forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -687,10 +686,9 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/modeling/opt.py b/colossalai/shardformer/modeling/opt.py
index 636b46cc461d..3ea4db9e2f70 100644
--- a/colossalai/shardformer/modeling/opt.py
+++ b/colossalai/shardformer/modeling/opt.py
@@ -330,14 +330,15 @@ def opt_for_causal_lm_forward(
         )
     if stage_manager.is_last_stage():
         logits = self.lm_head(outputs[0]).contiguous()
-        loss = dist_cross_entropy(
-            labels,
-            logits,
-            shard_config,
-            self.lm_head.out_features,
-            self.config.vocab_size,
-            self.model.decoder.dtype,
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(
+                labels,
+                logits,
+                shard_config,
+                self.lm_head.out_features,
+                self.model.decoder.dtype,
+            )
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -955,9 +956,9 @@ def forward(
         )
 
         logits = self.lm_head(outputs[0]).contiguous()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
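All of the modeling files above switch to the same guard: skip loss computation when `labels is None` (e.g. during generation). A sketch of the pattern with plain `F.cross_entropy` standing in for `dist_cross_entropy` (the helper name and signature here are illustrative, not the library's):

```python
from typing import Optional

import torch
import torch.nn.functional as F


def compute_causal_lm_loss(logits: torch.Tensor, labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Skip the loss entirely when labels is None (e.g. generation)."""
    loss = None
    if labels is not None:
        shift_logits = logits[:, :-1, :].contiguous()
        shift_labels = labels[:, 1:].contiguous()
        loss = F.cross_entropy(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
    return loss


logits = torch.randn(2, 8, 32)          # [B, S, vocab]
labels = torch.randint(0, 32, (2, 8))   # [B, S]
assert compute_causal_lm_loss(logits, labels) is not None
assert compute_causal_lm_loss(logits, None) is None
```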
diff --git a/colossalai/shardformer/modeling/qwen2.py b/colossalai/shardformer/modeling/qwen2.py
index d44c7382fdf6..353ccc2f5947 100644
--- a/colossalai/shardformer/modeling/qwen2.py
+++ b/colossalai/shardformer/modeling/qwen2.py
@@ -350,9 +350,9 @@ def qwen2_for_causal_lm_forward(
         if hidden_states.shape[1] == 2:
             pass
         logits = self.lm_head(hidden_states)
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
@@ -824,9 +824,9 @@ def forward(
         hidden_states = outputs[0]
         logits = self.lm_head(hidden_states)
         logits = logits.float()
-        loss = dist_cross_entropy(
-            labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
-        )
+        loss = None
+        if labels is not None:
+            loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
diff --git a/colossalai/shardformer/policies/chatglm2.py b/colossalai/shardformer/policies/chatglm2.py
index 3877bdac3ae2..f99d1ef819b7 100644
--- a/colossalai/shardformer/policies/chatglm2.py
+++ b/colossalai/shardformer/policies/chatglm2.py
@@ -64,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
 
         if sp_mode == "ring":
             warnings.warn(
-                f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
+                f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
             )
             sp_mode = "split_gather"
             overlap = self.shard_config.enable_sequence_overlap
diff --git a/tests/test_shardformer/test_model/test_shard_chatglm2.py b/tests/test_shardformer/test_model/test_shard_chatglm2.py
index 92c077950ecc..17a8bf318976 100644
--- a/tests/test_shardformer/test_model/test_shard_chatglm2.py
+++ b/tests/test_shardformer/test_model/test_shard_chatglm2.py
@@ -136,26 +136,25 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
 @parameterize(
     "test_config",
     [
-        {  # Ulysess + Flash attention
-            "tp_size": 1,
+        {
+            "tp_size": 2,
             "pp_size": 2,
-            "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "all_to_all",
+            "sequence_parallelism_mode": "split_gather",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 2,
+        {  # Ulysess + Flash attention
+            "tp_size": 1,
             "pp_size": 2,
+            "sp_size": 2,
             "num_microbatches": 2,
             "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "split_gather",
+            "sequence_parallelism_mode": "all_to_all",
             "enable_flash_attention": True,
             "use_lazy_init": True,
             "zero_stage": 1,
@@ -174,17 +173,6 @@ def check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn,
             "precision": "fp16",
             "initial_scale": 1,
         },
-        {
-            "tp_size": 4,
-            "pp_size": 1,
-            "num_microbatches": 1,
-            "enable_sequence_parallelism": True,
-            "sequence_parallelism_mode": "ring",
-            "enable_flash_attention": False,
-            "use_lazy_init": True,
-            "precision": "fp32",
-            "initial_scale": 1,
-        },
         {
             "tp_size": 4,
             "pp_size": 1,
@@ -248,7 +236,11 @@ def run_chatglm_test(test_config):
         loss_fn,
         _,
     ) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Test config failed for model {name}: {test_config}")
+            raise e
 
     clear_layout_converter()
     torch.cuda.empty_cache()
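The test change above wraps each parameterized run so the failing `test_config` is printed before the exception propagates. A generic, framework-free sketch of the same pattern (`run_case`, the model name, and the config dict are stand-ins):

```python
def run_with_config_report(run_case, name, test_config):
    try:
        run_case(test_config)
    except Exception:
        # Surface which parameterized config failed, then re-raise unchanged.
        print(f"Test config failed for model {name}: {test_config}")
        raise


def run_case(cfg):
    if cfg.get("tp_size", 1) > 8:
        raise ValueError("unsupported tp_size")


run_with_config_report(run_case, "demo_model", {"tp_size": 2, "pp_size": 2, "precision": "fp16"})
```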
diff --git a/tests/test_shardformer/test_model/test_shard_command.py b/tests/test_shardformer/test_model/test_shard_command.py
index 2e6997597928..9435ef84bfa8 100644
--- a/tests/test_shardformer/test_model/test_shard_command.py
+++ b/tests/test_shardformer/test_model/test_shard_command.py
@@ -281,7 +281,11 @@ def run_command_test(test_config):
     sub_model_zoo = model_zoo.get_sub_registry("transformers_command", "transformers_command_for_causal_lm")
 
     for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
-        check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        try:
+            check_forward_backward(model_fn, data_gen_fn, output_transform_fn, loss_fn, test_config)
+        except Exception as e:
+            print(f"Failed test config: {test_config}")
+            raise e
 
     clear_layout_converter()
     Randomizer.reset_index()
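For reference, a rough sanity-check helper (my own, not part of the test suite) for how many devices a `test_config` like the ones above needs, assuming `split_gather`/`ring` reuse the tensor-parallel group while `all_to_all` uses a separate sequence-parallel group:

```python
def required_world_size(cfg: dict) -> int:
    # Assumption: only "all_to_all" contributes an independent sp_size factor.
    tp = cfg.get("tp_size", 1)
    pp = cfg.get("pp_size", 1)
    sp = cfg.get("sp_size", 1) if cfg.get("sequence_parallelism_mode") == "all_to_all" else 1
    return tp * pp * sp


assert required_world_size({"tp_size": 2, "pp_size": 2, "sequence_parallelism_mode": "split_gather"}) == 4
assert required_world_size({"tp_size": 1, "pp_size": 2, "sp_size": 2, "sequence_parallelism_mode": "all_to_all"}) == 4
```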