
Commit
ring attn + TP/PP tests passed; fix typos such as "casual" -> "causal"
Edenzzzz committed Jul 19, 2024
1 parent 6d2906a commit 501205d
Showing 32 changed files with 105 additions and 85 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
@@ -12,6 +12,7 @@ repos:
hooks:
- id: isort
name: sort all imports (python)
args: ["--profile", "black"] # avoid conflict with black

- repo: https://github.com/psf/black-pre-commit-mirror
rev: 24.4.2
2 changes: 1 addition & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -187,7 +187,7 @@ def sync_sp_grads(self, grads: Optional[List[torch.Tensor]] = None):
"""

if self.shard_config.enable_sequence_parallelism:
if self.shard_config.sequence_parallelism_mode == "all_to_all":
if self.shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
return

if self.shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
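For context, a minimal sketch of the dispatch this one-line change extends: `ring_attn` now takes the same early-return path as `all_to_all`, while `split_gather` and `ring` still reduce gradients over the sequence-parallel group (the usual reason being that those modes leave some parameter gradients as partial sums). The `shard_config` stub and loop below are illustrative only, not the plugin's actual implementation.

from typing import List, Optional

import torch
import torch.distributed as dist


def sync_sp_grads_sketch(shard_config, sp_group, grads: Optional[List[torch.Tensor]] = None) -> None:
    """Illustrative control flow only; mirrors the hunk above, not the full plugin."""
    if not shard_config.enable_sequence_parallelism:
        return
    # ring_attn joins all_to_all: no extra synchronization over the SP group here.
    if shard_config.sequence_parallelism_mode in ["all_to_all", "ring_attn"]:
        return
    if shard_config.sequence_parallelism_mode in ["split_gather", "ring"]:
        # These modes produce partial gradients for some parameters,
        # so they are summed across the sequence-parallel ranks.
        for grad in grads or []:
            dist.all_reduce(grad, group=sp_group)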
1 change: 0 additions & 1 deletion colossalai/pipeline/schedule/interleaved_pp.py
@@ -286,7 +286,6 @@ def forward_step(
# for the first stage, input_obj is None
# for other stages, input_obj is the output of the previous stage containing hidden_states etc.
# Only attention_mask from micro_batch is used

with self.stage_manager.switch_model_chunk_id(model_chunk_id):
if isinstance(model_chunk, ModuleList):
output_obj = model_forward(model_chunk[model_chunk_id], micro_batch, input_obj)
27 changes: 15 additions & 12 deletions colossalai/shardformer/layer/attn.py
@@ -415,13 +415,12 @@ def _rescale_out_lse(out, block_out, lse, block_lse):
# new_lse = max_scale + torch.log(1 + torch.exp(min_scale - max_scale))
new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
new_block_lse = torch.exp(block_lse - new_lse)
assert _not_nan(new_lse), new_lse
# dist.barrier()
assert _not_nan(new_block_lse), new_block_lse

out.copy_(torch.exp(lse - new_lse) * out + new_block_lse * block_out)
lse.copy_(new_lse)

assert _not_nan(new_lse), new_lse
assert _not_nan(new_block_lse), new_block_lse
assert _not_nan(out), out
# block_out = block_out.float()
# out.copy_(out - F.sigmoid(block_lse - lse) * (out - block_out))
# lse.copy_(lse - F.logsigmoid(lse - block_lse))
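As a standalone illustration of `_rescale_out_lse` above: folding a new block into the running attention output is a log-sum-exp merge. The updated denominator is `new_lse = lse + log(1 + exp(block_lse - lse))`, after which the old output is re-weighted by `exp(lse - new_lse)` and the block output by `exp(block_lse - new_lse)`. A minimal sketch with the in-place semantics of the hunk (shapes are assumed broadcast-compatible; the commented-out sigmoid form above is an equivalent alternative):

import torch


def rescale_out_lse(out: torch.Tensor, block_out: torch.Tensor,
                    lse: torch.Tensor, block_lse: torch.Tensor) -> None:
    """Fold one ring-attention block into the running (out, lse) accumulators in place."""
    # Numerically: new_lse = log(exp(lse) + exp(block_lse)), written relative to lse.
    new_lse = lse + torch.log(1 + torch.exp(block_lse - lse))
    # Re-weight both partial outputs so they share the new softmax denominator.
    out.copy_(torch.exp(lse - new_lse) * out + torch.exp(block_lse - new_lse) * block_out)
    lse.copy_(new_lse)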
@@ -600,7 +599,8 @@ def forward(
b, h, sq, d = q.shape
# (B, H, Sq, D) -> (B, Sq, H, D)
q, k, v = [x.transpose(1, 2) for x in (q, k, v)]

assert _not_nan(q), q
assert _not_nan(k), k
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
sp_global_ranks = dist.get_process_group_ranks(sp_group)
@@ -626,7 +626,9 @@ def forward(
with torch.cuda.stream(sp_streams[i % 2]):
for req in p2p_reqs[(i + 1) % 2]:
req.wait()
assert _not_nan(kv_buffers[i % 2]), kv_buffers[i % 2]
assert _not_nan(
kv_buffers[i % 2]
), f"rank {dist.get_rank()} iter {i} kv buffer is nan: {kv_buffers[i % 2]}"

if i < sp_size - 1:
p2p_reqs[i % 2] = ring_attn_p2p_comm(
@@ -674,7 +676,7 @@ def forward(
kv_block = kv_buffers[i % 2]
# (2, B * Sq // 2, H, D)
kv_block = kv_block.view(2, b * sq, h, d)[:, : b * sq // 2].clone()
assert _not_nan(kv_block), f"rank {sp_rank} step {i} kv_block {kv_block}"
assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}"
# actual_lse = (q_block.flatten(start_dim=1) @ kv_block[0].movedim(0, -1).flatten(end_dim=-2)).exp().sum(dim=-1).log()
(
_,
@@ -702,7 +704,7 @@ def forward(
# Drop the first half of q
q_block = q.view(b * sq, h, d)[b * sq // 2 :]
kv_block = kv_buffers[i % 2].view(2, b * sq, h, d).clone()
assert _not_nan(kv_block), f"rank {sp_rank} step {i} kv_block {kv_block}"
assert _not_nan(kv_block), f"rank {dist.get_rank()} step {i} kv_block {kv_block}"
# actual_lse = (q_block.flatten(start_dim=1) @ kv_block[0].movedim(0, -1).flatten(end_dim=-2)).exp().sum(dim=-1).log()

(
@@ -919,9 +921,9 @@ def backward(ctx, dout):
# Accumulate grads
if i == 0:
# TODO: use float() if precision goes wrong
dq = dq_block
dk_recv = dkv_buffers[(i + 1) % 2][0] = dk_block.clone()
dv_recv = dkv_buffers[(i + 1) % 2][1] = dv_block.clone()
dq = dq_block.float()
dk_recv = dkv_buffers[(i + 1) % 2][0] = dk_block.float()
dv_recv = dkv_buffers[(i + 1) % 2][1] = dv_block.float()
else:
# Accumulate local dq
if i <= sp_rank:
@@ -933,7 +935,8 @@ def backward(ctx, dout):
# Wait for mobile kv grad accumulators
for req in dkv_reqs:
req.wait()

assert _not_nan(dkv_buffers[(i + 1) % 2]), f"rank {dist.get_rank()} step {i} dkv_buffers is nan"
assert _not_nan(dq_block), f"rank {dist.get_rank()} step {i} dq_block is nan"
if i <= sp_rank:
# q blocks "surrounded" by kv blocks
dk_recv = dkv_buffers[(i + 1) % 2][0]
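One intent of the backward changes above is worth spelling out: the first ring step now upcasts the gradient accumulators with `.float()`, so every subsequent per-step partial gradient is summed in fp32 instead of the fp16/bf16 compute dtype. A simplified sketch of that accumulation pattern (`compute_block_grads` is a hypothetical stand-in for the per-step flash-attention backward; the real code also overlaps this loop with P2P transfers of the kv-gradient buffers):

import torch


def accumulate_ring_grads(num_steps: int, compute_block_grads):
    """Sum per-step partial gradients in fp32 (sketch only)."""
    dq = dk = dv = None
    for i in range(num_steps):
        dq_block, dk_block, dv_block = compute_block_grads(i)  # typically fp16/bf16
        if i == 0:
            # Upcast once; all later additions then happen in full precision.
            dq, dk, dv = dq_block.float(), dk_block.float(), dv_block.float()
        else:
            dq += dq_block
            dk += dk_block
            dv += dv_block
    return dq, dk, dv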
5 changes: 3 additions & 2 deletions colossalai/shardformer/layer/utils.py
@@ -291,7 +291,7 @@ def create_randomizer_with_offset(
return Randomizer(seed=base_seed)


def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen: bool = False):
def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, seq_dim=1, varlen: bool = False):
"""
Split the input along the sequence dimension for Ring Attention. As naively splitting the sequence
in the causal setting will result in the first ranks having much less workload than the last ranks,
@@ -301,18 +301,19 @@ def zigzag_split_batch(batch: List[torch.Tensor], sp_group: ProcessGroup, varlen
Args:
batch (List[torch.Tensor]): The input tensors to split.
sp_group (ProcessGroup): The process group for sequence parallelism.
seq_dim (int): The sequence dimension to split.
varlen (bool): If the input is padded (aka "packing" mode), such that
sequences in a batch have different lengths, and we need to unpad and
split each sequence evenly by sp_size.
"""
sp_size = dist.get_world_size(sp_group)
sp_rank = dist.get_rank(sp_group)
seq_dim = 1
if sp_size > 1:
for idx, tensor in enumerate(batch):
assert (
tensor.numel() // (sp_size * 2) > 1
), f"Bro, the seq length for tensor {idx} in batch is too short to split!"

tensor = tensor.view(
*tensor.shape[:seq_dim],
2 * sp_size,
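For intuition, the zigzag scheme cuts each sequence into `2 * sp_size` chunks and hands rank `r` the pair `(r, 2 * sp_size - 1 - r)`, so every rank gets one cheap "early" chunk and one expensive "late" chunk under causal attention. A standalone sketch of that indexing for a single dense tensor; the chunk assignment is the conventional zigzag pattern implied by the docstring, and the real function also covers the varlen/packed case:

import torch


def zigzag_split_one(tensor: torch.Tensor, sp_rank: int, sp_size: int, seq_dim: int = 1) -> torch.Tensor:
    """Return the zigzag shard of `tensor` for `sp_rank` along `seq_dim` (sketch)."""
    seq_len = tensor.shape[seq_dim]
    assert seq_len % (2 * sp_size) == 0, "sequence length must be divisible by 2 * sp_size"
    # (..., S, ...) -> (..., 2 * sp_size, S // (2 * sp_size), ...)
    chunks = tensor.reshape(
        *tensor.shape[:seq_dim], 2 * sp_size, seq_len // (2 * sp_size), *tensor.shape[seq_dim + 1:]
    )
    # Rank r keeps chunk r and its mirror chunk 2 * sp_size - 1 - r.
    index = torch.tensor([sp_rank, 2 * sp_size - 1 - sp_rank], device=tensor.device)
    picked = chunks.index_select(seq_dim, index)
    return picked.reshape(*tensor.shape[:seq_dim], seq_len // sp_size, *tensor.shape[seq_dim + 1:]).contiguous()

With `sp_size = 2`, for example, rank 0 holds chunks (0, 3) and rank 1 holds chunks (1, 2).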
42 changes: 22 additions & 20 deletions colossalai/shardformer/modeling/llama.py
@@ -135,30 +135,32 @@ def llama_model_forward(
if shard_config.enable_flash_attention:
# in this case, attention_mask is a dict rather than a tensor
mask_shape = (batch_size, 1, seq_length_with_past, seq_length_with_past)
attention_mask = ColoAttention.prepare_attn_kwargs(
attn_mask = ColoAttention.prepare_attn_kwargs(
mask_shape,
hidden_states.dtype,
hidden_states.device,
q_padding_mask=attention_mask,
is_causal=True,
)
else:
attention_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)
attn_mask = self._update_causal_mask(attention_mask, hidden_states, cache_position)

# Support SP + PP
if stage_manager.is_first_stage():
# Ring Attention zigzag batch processing
if sp_mode == "ring_attn":
# NOTE: This will throw an error in KV Cache inference without replicating q in all ranks.
# Also, I don't see get_llama_flash_attention_forward supporting
# query_states and key_states with different seq_len.
batch = {
"input": inputs_embeds,
"attention_mask": attention_mask["attention_mask"],
"position": position_ids,
}
batch = zigzag_split_batch(batch, sp_group)
inputs_embeds, attention_mask["attention_mask"], position_ids = batch.values()
elif sp_mode in ["ring", "split_gather"]:
assert shard_config.enable_flash_attention, "Ring Attention inherently requires Flash Attention."
if attn_mask["attention_mask_type"] == AttnMaskType.PADDED_CAUSAL:
attn_mask["cu_seqlens"], attn_mask["max_seqlen"], attn_mask["indices"] = get_pad_info(
attn_mask["attention_mask"].squeeze(1).any(dim=-1)
) # [B, 1, Sq, Skv] -> [B, Sq]
else:
attn_mask["cu_seqlens"] = attn_mask["max_seqlen"] = attn_mask["indices"] = None
batch = [hidden_states, position_ids]
# inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group)
hidden_states, position_ids = zigzag_split_batch(batch, sp_group)

elif is_share_sp_tp(sp_mode):
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group)
elif sp_mode == "all_to_all":
hidden_states = split_forward_gather_backward(hidden_states, 1, sp_group, 1 / sp_size)
@@ -193,12 +195,11 @@ def llama_model_forward(
for idx, decoder_layer in enumerate(self.layers[start_idx:end_idx], start=start_idx):
if output_hidden_states:
all_hidden_states += (hidden_states,)

if idx - start_idx < num_ckpt_layers:
layer_outputs = self._gradient_checkpointing_func(
decoder_layer.__call__,
hidden_states,
attention_mask,
attn_mask,
position_ids,
past_key_values,
output_attentions,
Expand All @@ -208,14 +209,13 @@ def llama_model_forward(
else:
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask,
attention_mask=attn_mask,
position_ids=position_ids,
past_key_value=past_key_values,
output_attentions=output_attentions,
use_cache=use_cache,
cache_position=cache_position,
)

hidden_states = layer_outputs[0]

if use_cache:
@@ -314,7 +314,7 @@ def llama_for_causal_lm_forward(

if stage_manager.is_first_stage():
if shard_config.sequence_parallelism_mode == "ring_attn":
labels = zigzag_split_batch(labels, shard_config.sequence_parallel_process_group)
labels = zigzag_split_batch([labels], shard_config.sequence_parallel_process_group)[0]

# decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
outputs = LlamaPipelineForwards.llama_model_forward(
@@ -500,7 +500,7 @@ def forward(

bsz, q_len, _ = hidden_states.size()
# sp: modify sp_len when sequence parallel mode is ring
if sp_mode in ["split_gather", "ring"]:
if is_share_sp_tp(sp_mode):
q_len *= sp_size

if self.config.pretraining_tp > 1:
@@ -555,7 +555,9 @@ def forward(
# repeat k/v heads if n_kv_heads < n_heads
key_states = repeat_kv(key_states, self.num_key_value_groups)
value_states = repeat_kv(value_states, self.num_key_value_groups)
assert not self.q_proj.weight.isnan().any(), self.q_proj.weight

assert not query_states.isnan().any(), query_states
if sp_mode == "ring_attn":
attn_output = RingAttention.attention(
query_states,
@@ -701,7 +703,7 @@ def forward(
# inputs_embeds, attention_mask["attention_mask"], position_ids = zigzag_split_batch(batch, sp_group)
inputs_embeds, position_ids = zigzag_split_batch(batch, sp_group)

elif sp_mode in ["ring", "split_gather"]:
elif is_share_sp_tp(sp_mode):
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group)
elif sp_mode == "all_to_all":
inputs_embeds = split_forward_gather_backward(inputs_embeds, 1, sp_group, 1 / sp_size)
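A detail from the new first-stage `ring_attn` branch that is easy to miss: for padded-causal inputs the 4-D mask of shape [B, 1, Sq, Skv] is collapsed to a 2-D [B, Sq] padding mask (a query position counts as valid if it attends to anything) before `get_pad_info` derives `cu_seqlens`, `max_seqlen`, and `indices`. A small, self-contained illustration of that reduction; the toy mask and the `cu_seqlens` arithmetic below follow the conventional flash-attention layout and are not taken from `get_pad_info` itself:

import torch

# Toy batch: B = 2, Sq = Skv = 4; the last position of sample 1 is padding.
mask_4d = torch.ones(2, 1, 4, 4).tril().bool()   # causal mask per sample
mask_4d[1, :, 3, :] = False                      # padded query row
mask_4d[1, :, :, 3] = False                      # padded key column

# [B, 1, Sq, Skv] -> [B, Sq]: a query position is valid if it attends to any key.
padding_mask = mask_4d.squeeze(1).any(dim=-1)
# tensor([[ True,  True,  True,  True],
#         [ True,  True,  True, False]])

seqlens = padding_mask.sum(dim=-1)                                               # tensor([4, 3])
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.int64), seqlens.cumsum(0)])   # tensor([0, 4, 7])
max_seqlen = int(seqlens.max())                                                  # 4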
4 changes: 2 additions & 2 deletions colossalai/shardformer/policies/command.py
@@ -292,11 +292,11 @@ class CommandForCausalLMPolicy(CommandPolicy):
def module_policy(self):
from transformers import CohereForCausalLM

self.is_casual = True
self.is_causal = True
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
CohereForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/deepseek.py
@@ -163,7 +163,7 @@ def module_policy(self):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
"DeepseekForCausalLM": ModulePolicyDescription(
sub_module_replacement=[
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
@@ -305,7 +305,7 @@ def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
LlamaForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/mistral.py
@@ -271,7 +271,7 @@ def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
MistralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/mixtral.py
@@ -159,7 +159,7 @@ def module_policy(self):
policy = super().module_policy()
# TODO: assign pg mesh from plugin to all modules
if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
MixtralForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/qwen2.py
@@ -313,7 +313,7 @@ def module_policy(self):
setattr(self.shard_config, "causal_lm", True)

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
new_item = {
Qwen2ForCausalLM: ModulePolicyDescription(
sub_module_replacement=[
3 changes: 1 addition & 2 deletions examples/language/llama/benchmark.py
@@ -193,7 +193,7 @@ def empty_init():
num_model_chunks=args.n_chunks,
zero_stage=args.zero,
sp_size=args.sp,
sp_mode=args.sp_mode,
sequence_parallelism_mode=args.sp_mode,
enable_sequence_parallelism=args.sp > 1,
enable_fused_normalization=torch.cuda.is_available(),
enable_flash_attention=args.xformers,
@@ -324,7 +324,6 @@ def empty_init():

performance_evaluator.on_step_end(**batch)
prof.step()
booster.save_model(model, "model.pt")
performance_evaluator.on_fit_end()
coordinator.print_on_master(f"Max CUDA memory usage: {get_accelerator().max_memory_allocated()/1024**2:.2f} MB")

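The keyword fix above matters because `HybridParallelPlugin` takes `sequence_parallelism_mode`, not `sp_mode`; with the old name the benchmark would not actually select ring attention. A hedged configuration sketch follows. Only the keywords visible in this diff are taken from the source; `tp_size`, `pp_size`, and the `Booster` wiring follow the library's usual pattern, and the concrete values are illustrative.

import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# Assumes the distributed environment has already been initialized,
# e.g. via colossalai.launch_from_torch in the benchmark's setup.
plugin = HybridParallelPlugin(
    tp_size=2,
    pp_size=2,
    sp_size=2,
    enable_sequence_parallelism=True,
    sequence_parallelism_mode="ring_attn",   # the corrected keyword
    enable_flash_attention=True,             # ring attention builds on flash attention
    enable_fused_normalization=torch.cuda.is_available(),
    zero_stage=1,
)
booster = Booster(plugin=plugin)
# model, optimizer, _, dataloader, _ = booster.boost(model, optimizer, dataloader=dataloader)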
2 changes: 1 addition & 1 deletion examples/language/openmoe/model/openmoe_policy.py
@@ -171,7 +171,7 @@ def module_policy(self):
policy = super().module_policy()

if self.shard_config.enable_tensor_parallelism:
# add a new item for casual lm
# add a new item for causal lm
# TODO: recursively assign ep group for all modules
new_item = {
OpenMoeForCausalLM: ModulePolicyDescription(
2 changes: 1 addition & 1 deletion examples/language/opt/README.md
@@ -17,7 +17,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.


## Our Modifications
2 changes: 1 addition & 1 deletion examples/tutorial/opt/opt/README.md
@@ -19,7 +19,7 @@ limitations under the License.
## OPT
Meta recently released [Open Pretrained Transformer (OPT)](https://github.com/facebookresearch/metaseq), a 175-Billion parameter AI language model, which stimulates AI programmers to perform various downstream tasks and application deployments.

The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Casual Language Modelling at low cost.
The following example of [Colossal-AI](https://github.com/hpcaitech/ColossalAI) demonstrates fine-tuning Causal Language Modelling at low cost.

We are using the pre-training weights of the OPT model provided by Hugging Face Hub on the raw WikiText-2 (no tokens were replaced before
the tokenization). This training script is adapted from the [HuggingFace Language Modelling examples](https://github.com/huggingface/transformers/tree/main/examples/pytorch/language-modeling).
4 changes: 2 additions & 2 deletions tests/kit/model_zoo/__init__.py
@@ -22,9 +22,9 @@
"transformers_bloom_for_causal_lm",
"transformers_falcon_for_causal_lm",
"transformers_chatglm_for_conditional_generation",
"transformers_llama_for_casual_lm",
"transformers_llama_for_causal_lm",
"transformers_vit_for_masked_image_modeling",
"transformers_mistral_for_casual_lm",
"transformers_mistral_for_causal_lm",
]

IS_FAST_TEST = os.environ.get("FAST_TEST", "0") == "1"
12 changes: 6 additions & 6 deletions tests/kit/model_zoo/transformers/command.py
@@ -32,8 +32,8 @@ def data_gen():

return dict(input_ids=input_ids, attention_mask=attention_mask)

# label is needed for casual lm
def data_gen_for_casual_lm():
# label is needed for causal lm
def data_gen_for_causal_lm():
data = data_gen()
labels = data["input_ids"].clone()
data["labels"] = labels
@@ -44,7 +44,7 @@ def data_gen_for_casual_lm():

# function to get the loss
loss_fn = lambda output: output["last_hidden_state"].mean()
loss_fn_for_casual_lm = lambda output: output["loss"]
loss_fn_for_causal_lm = lambda output: output["loss"]
loss_fn_for_seq_classification = lambda output: output["logits"].mean()

config = CohereConfig(
@@ -70,10 +70,10 @@ def data_gen_for_casual_lm():
model_attribute=ModelAttribute(has_control_flow=True),
)
model_zoo.register(
name="transformers_command_for_casual_lm",
name="transformers_command_for_causal_lm",
model_fn=lambda: transformers.CohereForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
data_gen_fn=data_gen_for_causal_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
loss_fn=loss_fn_for_causal_lm,
model_attribute=ModelAttribute(has_control_flow=True),
)