Commit: merge
Edenzzzz committed Aug 14, 2024
2 parents 84aaa53 + 83df9c4 commit 2c61804
Showing 13 changed files with 136 additions and 92 deletions.
9 changes: 9 additions & 0 deletions colossalai/shardformer/layer/attn.py
@@ -659,6 +659,11 @@ def forward(
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]

# Non-contiguous indexing creates a new contiguous tensor,
# so only do it once
if sp_rank != sp_size - 1:
q1 = q[half_idx_back]

# Pre-allocate double buffer for overlapping and receiving next step's inputs
kv_buffers = [torch.stack((k, v))] # (2, B, Sq, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
@@ -914,6 +919,10 @@ def backward(ctx, dout, _):
softmax_lse1 = softmax_lse[:, half_idx_back]
dout = dout.contiguous()

if sp_rank != sp_size - 1:
softmax_lse1 = softmax_lse[:, half_idx_back]
dout = dout.contiguous()

# Double comm buffers for sending and receiving kv
kv_buffers = [torch.stack((k, v))] # (2, T, H, D)
kv_buffers.append(torch.empty_like(kv_buffers[0]))
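For readers following the attn.py hunks above: the pre-allocated `kv_buffers` double buffer exists so that attention can be computed on the KV block a rank already holds while the next rank's block is received into the spare buffer. The sketch below illustrates that overlap pattern in isolation; it is not ColossalAI's RingAttention implementation, the helper name and `compute_fn` are hypothetical, and it assumes an already-initialized `torch.distributed` process group.

```python
import torch
import torch.distributed as dist

def ring_kv_exchange(k, v, compute_fn, sp_group=None):
    """Double-buffered ring send/recv (sketch): overlap attention compute on
    the current KV block with receiving the next rank's block."""
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    send_dst = (sp_rank + 1) % sp_size
    recv_src = (sp_rank - 1) % sp_size

    # Pre-allocate double buffer: [block we hold, block being received]
    kv_buffers = [torch.stack((k, v))]
    kv_buffers.append(torch.empty_like(kv_buffers[0]))

    outputs = []
    for step in range(sp_size):
        reqs = []
        if step + 1 < sp_size:
            # Post async send of the current block and recv of the next one
            ops = [
                dist.P2POp(dist.isend, kv_buffers[step % 2], send_dst, group=sp_group),
                dist.P2POp(dist.irecv, kv_buffers[(step + 1) % 2], recv_src, group=sp_group),
            ]
            reqs = dist.batch_isend_irecv(ops)
        # Compute on the block already in hand while communication is in flight
        cur_k, cur_v = kv_buffers[step % 2]
        outputs.append(compute_fn(cur_k, cur_v))
        for req in reqs:
            req.wait()
    return outputs
```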
5 changes: 2 additions & 3 deletions colossalai/shardformer/layer/loss.py
@@ -153,7 +153,6 @@ def dist_cross_entropy(
labels: torch.Tensor, # [B, S] or [B, S, Vocab_size]
logits: torch.Tensor, # [B, S, Vocab_size]
shard_config: ShardConfig,
out_features: int,
vocab_size: int,
dtype: torch.dtype,
seq_dim: int = 1,
@@ -226,13 +225,13 @@ def dist_cross_entropy(
logits,
labels,
process_group=shard_config.tensor_parallel_process_group,
vocab_size=out_features,
vocab_size=vocab_size,
dtype=dtype,
mode="sum",
)
else:
# NOTE: if using TP without parallel_output, the output is gathered in VocabParallelLMHead1D
logits = logits.view(-1, vocab_size)
logits = logits.view(-1, logits.size(-1))
loss = loss_fct(logits, labels)

# Reduce loss instead of gathering logits over seq dim for savings
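The `cross_entropy_1d` call above computes the loss directly on vocab-sharded logits instead of gathering them. As background, here is a minimal sketch of that general vocab-parallel cross-entropy pattern; it is not ColossalAI's `cross_entropy_1d`, the function name and `vocab_start` argument are illustrative, and `-100` is assumed as the ignore index.

```python
import torch
import torch.distributed as dist

def vocab_parallel_cross_entropy(shard_logits, labels, vocab_start, process_group=None):
    """Cross entropy over vocab-sharded logits (sketch): softmax statistics
    are all-reduced across ranks so the full logits are never gathered.

    shard_logits: [N, V_local], this rank's slice of the vocab dimension
    labels:       [N], global vocab ids, -100 marks ignored positions
    vocab_start:  first global vocab id owned by this rank
    """
    valid = labels != -100

    # Numerically stable log-sum-exp over the full (sharded) vocabulary
    local_max = shard_logits.max(dim=-1).values
    global_max = local_max.clone()
    dist.all_reduce(global_max, op=dist.ReduceOp.MAX, group=process_group)
    sum_exp = torch.exp(shard_logits - global_max[:, None]).sum(dim=-1)
    dist.all_reduce(sum_exp, op=dist.ReduceOp.SUM, group=process_group)

    # The target logit is contributed only by the rank that owns that vocab id
    vocab_end = vocab_start + shard_logits.size(-1)
    in_shard = valid & (labels >= vocab_start) & (labels < vocab_end)
    local_ids = (labels - vocab_start).clamp(0, shard_logits.size(-1) - 1)
    target_logit = torch.where(
        in_shard,
        shard_logits.gather(-1, local_ids[:, None]).squeeze(-1),
        torch.zeros_like(local_max),
    )
    dist.all_reduce(target_logit, op=dist.ReduceOp.SUM, group=process_group)

    # loss_i = log-sum-exp_i - logit_i[label_i], averaged over non-ignored tokens
    loss = (global_max + sum_exp.log() - target_logit)[valid]
    return loss.sum() / valid.sum().clamp(min=1)
```

The call site above uses `mode="sum"` and the surrounding `dist_cross_entropy` handles causal-LM label shifting; the sketch simply averages over valid tokens.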
25 changes: 14 additions & 11 deletions colossalai/shardformer/modeling/bloom.py
@@ -359,14 +359,15 @@ def bloom_for_causal_lm_forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states).contiguous()

loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.transformer.dtype,
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.lm_head.out_features,
self.transformer.dtype,
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
@@ -1024,9 +1025,11 @@ def forward(
hidden_states = transformer_outputs[0]
lm_logits = self.lm_head(hidden_states)

loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
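Both Bloom hunks apply the same pattern repeated across the modeling files in this commit: compute the loss only when `labels` is provided, rather than passing `labels=None` into the distributed loss helper. A plain, non-distributed sketch of that guard, with standard causal-LM label shifting standing in for `dist_cross_entropy`:

```python
from typing import Optional

import torch
import torch.nn.functional as F

def causal_lm_loss(lm_logits: torch.Tensor, labels: Optional[torch.Tensor]) -> Optional[torch.Tensor]:
    """Only compute a loss when labels are given (training); return None
    during generation instead of calling the loss helper with labels=None."""
    if labels is None:
        return None
    # Standard causal-LM shift: position t predicts token t+1
    shift_logits = lm_logits[..., :-1, :].contiguous()
    shift_labels = labels[..., 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=-100,
    )
```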
39 changes: 32 additions & 7 deletions colossalai/shardformer/modeling/chatglm2.py
@@ -183,6 +183,15 @@ def chatglm_model_forward(
if full_attention_mask is None:
if (attention_mask is not None and not attention_mask.all()) or (past_key_values and seq_length != 1):
full_attention_mask = self.get_masks(input_ids, past_key_values, padding_mask=attention_mask)

# Support SP + PP
sp_size = shard_config.sequence_parallel_size
sp_mode = shard_config.sequence_parallelism_mode
sp_group = shard_config.sequence_parallel_process_group
# For generating full positions ids (the states will be gathered along the seq dim before attention fwd).
if sp_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= sp_size

# Rotary positional embeddings
rotary_pos_emb = self.rotary_pos_emb(self.seq_length)
if position_ids is not None:
@@ -206,11 +215,11 @@
# Keep the input split across all PP stages
if stage_manager.is_first_stage():
if shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode == "split_gather":
if sp_mode == "split_gather":
hidden_states = split_forward_gather_backward(
hidden_states,
dim=0,
process_group=shard_config.tensor_parallel_process_group,
process_group=sp_group,
)
elif shard_config.sequence_parallelism_mode == "all_to_all":
hidden_states = split_forward_gather_backward(
@@ -255,7 +264,9 @@ def chatglm_model_forward(
# Gather seq-wise in the final output stage
sp_mode = shard_config.sequence_parallelism_mode
if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode)
hidden_states = gather_sp_output(
hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0
)

if not return_dict:
return tuple(
@@ -321,9 +332,21 @@ def chatglm_for_conditional_generation_forward(
hidden_states = hidden_states[-1:]
lm_logits = self.transformer.output_layer(hidden_states)
lm_logits = lm_logits.transpose(0, 1).contiguous()
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, lm_logits.dtype
)

loss = None
if labels is not None:
# ChatGLM doesn't have lm_head split
enable_tp = shard_config.enable_tensor_parallelism
shard_config.enable_tensor_parallelism = False
loss = dist_cross_entropy(
labels,
lm_logits,
shard_config,
self.transformer.output_layer.out_features,
lm_logits.dtype,
)
shard_config.enable_tensor_parallelism = enable_tp

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
return ((loss,) + output) if loss is not None else output
@@ -424,7 +447,9 @@ def forward(
)

if (not shard_config.parallel_output) or force_sp_output_gather or is_share_sp_tp(sp_mode):
hidden_states = gather_sp_output(hidden_states, shard_config.sequence_parallel_process_group, sp_mode)
hidden_states = gather_sp_output(
hidden_states, shard_config.sequence_parallel_process_group, sp_mode, sp_dim=0
)

if not return_dict:
return tuple(
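Two details in the ChatGLM hunks are easy to miss: on non-first pipeline stages the hidden states arrive already split along the sequence dimension, so `seq_length` is multiplied by `sp_size` to build full-length rotary position ids, and the final gather now passes `sp_dim=0` because ChatGLM keeps activations in `[seq, batch, hidden]` layout. Below is a forward-only sketch of that split/gather along dim 0; the helper names are illustrative, an initialized process group is assumed, and the real `split_forward_gather_backward` / `gather_sp_output` also route gradients through the communication.

```python
import torch
import torch.distributed as dist

def split_seq(hidden_states: torch.Tensor, sp_group=None, sp_dim: int = 0) -> torch.Tensor:
    """Each sequence-parallel rank keeps only its chunk of the sequence
    dimension (dim 0 for ChatGLM's [seq, batch, hidden] layout)."""
    sp_size = dist.get_world_size(sp_group)
    sp_rank = dist.get_rank(sp_group)
    return hidden_states.chunk(sp_size, dim=sp_dim)[sp_rank].contiguous()

def gather_seq(hidden_states: torch.Tensor, sp_group=None, sp_dim: int = 0) -> torch.Tensor:
    """Reassemble the full sequence before the output layer / loss,
    mirroring gather_sp_output(..., sp_dim=0) above (forward pass only)."""
    sp_size = dist.get_world_size(sp_group)
    chunks = [torch.empty_like(hidden_states) for _ in range(sp_size)]
    dist.all_gather(chunks, hidden_states.contiguous(), group=sp_group)
    return torch.cat(chunks, dim=sp_dim)
```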
35 changes: 22 additions & 13 deletions colossalai/shardformer/modeling/command.py
@@ -93,10 +93,16 @@ def command_model_forward(
if not isinstance(past_key_values, StaticCache):
past_key_values = DynamicCache.from_legacy_cache(past_key_values)
past_seen_tokens = past_key_values.get_seq_length()

# NOTE: For generating full positions ids
# (the states will be gathered along the seq dim before attention fwd).
if shard_config.sequence_parallelism_mode != "ring_attn" and not stage_manager.is_first_stage():
seq_length *= shard_config.sequence_parallel_size

if cache_position is None:
if isinstance(past_key_values, StaticCache):
raise ValueError("cache_position is a required argument when using StaticCache.")
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + hidden_states.shape[1], device=device)
cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)

seq_length_with_past = seq_length + past_seen_tokens

@@ -136,7 +142,7 @@
)
use_cache = False

if shard_config.enable_sequence_parallelism:
if stage_manager.is_first_stage() and shard_config.enable_sequence_parallelism:
if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
hidden_states = split_forward_gather_backward(
hidden_states,
@@ -320,9 +326,10 @@ def command_for_causal_lm_forward(
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)

loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -659,14 +666,16 @@ def forward(
logits = self.lm_head(hidden_states)
logits = logits * self.logit_scale
logits = logits.float()
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.model.dtype,
)

loss = None
if labels is not None:
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.model.dtype,
)

if not return_dict:
output = (logits,) + outputs[1:]
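The `cache_position` fix in command.py follows from the same sequence-parallel bookkeeping: with SP enabled, `hidden_states.shape[1]` on non-first pipeline stages is only the local chunk, so positions must be derived from the scaled `seq_length`. A small sketch of that computation (hypothetical helper name):

```python
import torch

def full_cache_position(past_seen_tokens: int, local_seq_len: int, sp_size: int,
                        sp_mode: str, is_first_stage: bool, device) -> torch.Tensor:
    """On non-first pipeline stages with sequence parallelism, hidden states
    hold only a 1/sp_size chunk of the sequence, so cache positions must span
    the full length rather than hidden_states.shape[1] (sketch)."""
    seq_length = local_seq_len
    if sp_mode != "ring_attn" and not is_first_stage:
        seq_length *= sp_size
    return torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=device)
```

For example, a non-first stage holding a 512-token chunk with `sp_size=2` would produce positions covering 1024 tokens.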
16 changes: 10 additions & 6 deletions colossalai/shardformer/modeling/gpt2.py
@@ -372,9 +372,11 @@ def gpt2_lmhead_model_forward(

hidden_states = outputs[0]
lm_logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + outputs[1:]
@@ -1264,9 +1266,11 @@ def forward(
hidden_states = transformer_outputs[0]

lm_logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.transformer.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels, lm_logits, shard_config, self.lm_head.out_features, self.transformer.dtype
)

if not return_dict:
output = (lm_logits,) + transformer_outputs[1:]
12 changes: 6 additions & 6 deletions colossalai/shardformer/modeling/llama.py
@@ -346,9 +346,9 @@ def llama_for_causal_lm_forward(
if stage_manager.is_last_stage():
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -865,9 +865,9 @@ def forward(
else:
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)
if not return_dict:
output = (logits,) + outputs[1:]
return (loss,) + output if loss is not None else output
14 changes: 6 additions & 8 deletions colossalai/shardformer/modeling/mistral.py
@@ -274,10 +274,9 @@ def mistral_for_causal_lm_forward(
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -687,10 +686,9 @@ def forward(
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()

loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
23 changes: 12 additions & 11 deletions colossalai/shardformer/modeling/opt.py
@@ -330,14 +330,15 @@ def opt_for_causal_lm_forward(
)
if stage_manager.is_last_stage():
logits = self.lm_head(outputs[0]).contiguous()
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.config.vocab_size,
self.model.decoder.dtype,
)
loss = None
if labels is not None:
loss = dist_cross_entropy(
labels,
logits,
shard_config,
self.lm_head.out_features,
self.model.decoder.dtype,
)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -955,9 +956,9 @@ def forward(
)

logits = self.lm_head(outputs[0]).contiguous()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, self.model.decoder.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, self.model.decoder.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
12 changes: 6 additions & 6 deletions colossalai/shardformer/modeling/qwen2.py
@@ -350,9 +350,9 @@ def qwen2_for_causal_lm_forward(
if hidden_states.shape[1] == 2:
pass
logits = self.lm_head(hidden_states)
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
@@ -824,9 +824,9 @@ def forward(
hidden_states = outputs[0]
logits = self.lm_head(hidden_states)
logits = logits.float()
loss = dist_cross_entropy(
labels, logits, shard_config, self.lm_head.out_features, self.config.vocab_size, logits.dtype
)
loss = None
if labels is not None:
loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)

if not return_dict:
output = (logits,) + outputs[1:]
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/chatglm2.py
@@ -64,7 +64,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:

if sp_mode == "ring":
warnings.warn(
f"For ChatGLM2, sequence parallelism is currently not support mode {sp_mode}, will set to be split_gather"
f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, will set to be split_gather"
)
sp_mode = "split_gather"
overlap = self.shard_config.enable_sequence_overlap
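The reworded warning above reflects a simple configuration fallback: ChatGLM2 has no "ring" sequence-parallel path yet, so the policy warns and downgrades to "split_gather". A sketch of that fallback in isolation (hypothetical helper name, warning text taken from the diff):

```python
import warnings

def resolve_sp_mode(sp_mode: str) -> str:
    """If an unsupported sequence-parallel mode is requested for ChatGLM2,
    warn and fall back to the supported 'split_gather' mode (sketch)."""
    if sp_mode == "ring":
        warnings.warn(
            f"For ChatGLM2, sequence parallelism doesn't support mode {sp_mode} yet, "
            "will set to be split_gather"
        )
        return "split_gather"
    return sp_mode
```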