From d3f34ee8cc48b089c8b7dbc55697f77719f33079 Mon Sep 17 00:00:00 2001
From: Wang Binluo <32676639+wangbluo@users.noreply.github.com>
Date: Mon, 29 Apr 2024 05:47:47 -0500
Subject: [PATCH] [Shardformer] add assert for num of attention heads divisible
 by tp_size (#5670)

* add assert for num of attention heads divisible by tp_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 colossalai/shardformer/policies/bert.py    | 3 +++
 colossalai/shardformer/policies/blip2.py   | 3 +++
 colossalai/shardformer/policies/bloom.py   | 3 +++
 colossalai/shardformer/policies/falcon.py  | 6 ++++++
 colossalai/shardformer/policies/gpt2.py    | 3 +++
 colossalai/shardformer/policies/gptj.py    | 3 +++
 colossalai/shardformer/policies/llama.py   | 6 ++++++
 colossalai/shardformer/policies/mistral.py | 6 ++++++
 colossalai/shardformer/policies/opt.py     | 3 +++
 colossalai/shardformer/policies/sam.py     | 3 +++
 colossalai/shardformer/policies/t5.py      | 3 +++
 colossalai/shardformer/policies/vit.py     | 3 +++
 colossalai/shardformer/policies/whisper.py | 3 +++
 13 files changed, 48 insertions(+)

diff --git a/colossalai/shardformer/policies/bert.py b/colossalai/shardformer/policies/bert.py
index ad40e0e56228..0c04f7d38ca0 100644
--- a/colossalai/shardformer/policies/bert.py
+++ b/colossalai/shardformer/policies/bert.py
@@ -79,6 +79,9 @@ def module_policy(self):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BertLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attention.self.all_head_size": self.model.config.hidden_size
diff --git a/colossalai/shardformer/policies/blip2.py b/colossalai/shardformer/policies/blip2.py
index 9d1f6a306a3d..32d4edadb3e4 100644
--- a/colossalai/shardformer/policies/blip2.py
+++ b/colossalai/shardformer/policies/blip2.py
@@ -52,6 +52,9 @@ def module_policy(self):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[Blip2EncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attn.num_heads": self.model.config.vision_config.num_attention_heads
diff --git a/colossalai/shardformer/policies/bloom.py b/colossalai/shardformer/policies/bloom.py
index 4894bda35bfc..4f076d23368b 100644
--- a/colossalai/shardformer/policies/bloom.py
+++ b/colossalai/shardformer/policies/bloom.py
@@ -61,6 +61,9 @@ def module_policy(self):
         sp_partial_derived = sp_mode == "split_gather"
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[BloomBlock] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attention.hidden_size": self.model.config.hidden_size
diff --git a/colossalai/shardformer/policies/falcon.py b/colossalai/shardformer/policies/falcon.py
index e72a97e4bfc0..23d6efbeb27a 100644
--- a/colossalai/shardformer/policies/falcon.py
+++ b/colossalai/shardformer/policies/falcon.py
@@ -47,6 +47,12 @@ def module_policy(self):
         embedding_cls = col_nn.PaddingEmbedding
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             attn_attribute_replacement = {
                 "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
diff --git a/colossalai/shardformer/policies/gpt2.py b/colossalai/shardformer/policies/gpt2.py
index 531c2153b665..281ea88c2162 100644
--- a/colossalai/shardformer/policies/gpt2.py
+++ b/colossalai/shardformer/policies/gpt2.py
@@ -84,6 +84,9 @@ def module_policy(self):
             self.shard_config.enable_flash_attention = False
             use_flash_attention = False
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPT2Model] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/gptj.py b/colossalai/shardformer/policies/gptj.py
index 25e5b66dcc75..3315eb1e9256 100644
--- a/colossalai/shardformer/policies/gptj.py
+++ b/colossalai/shardformer/policies/gptj.py
@@ -57,6 +57,9 @@ def module_policy(self):
         overlap = self.shard_config.enable_sequence_overlap
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[GPTJModel] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py
index 0a95284bcfdf..6e541f792248 100644
--- a/colossalai/shardformer/policies/llama.py
+++ b/colossalai/shardformer/policies/llama.py
@@ -138,6 +138,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
diff --git a/colossalai/shardformer/policies/mistral.py b/colossalai/shardformer/policies/mistral.py
index b5018e47d65d..984b71646318 100644
--- a/colossalai/shardformer/policies/mistral.py
+++ b/colossalai/shardformer/policies/mistral.py
@@ -66,6 +66,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             )
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
+            assert (
+                self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of key_value heads must be divisible by tensor parallel size."
             decoder_attribute_replacement = {
                 "self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                 "self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
diff --git a/colossalai/shardformer/policies/opt.py b/colossalai/shardformer/policies/opt.py
index 2f6eabd5fef9..9619b3d41b8a 100644
--- a/colossalai/shardformer/policies/opt.py
+++ b/colossalai/shardformer/policies/opt.py
@@ -76,6 +76,9 @@ def module_policy(self):
             warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[OPTDecoderLayer] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/sam.py b/colossalai/shardformer/policies/sam.py
index ce33925ff82e..c224d776957a 100644
--- a/colossalai/shardformer/policies/sam.py
+++ b/colossalai/shardformer/policies/sam.py
@@ -31,6 +31,9 @@ def module_policy(self):
         norm_cls = col_nn.LayerNorm
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[SamVisionLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "attn.num_attention_heads": self.model.config.vision_config.num_attention_heads
diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py
index 3c7e92b47db0..1298f0af3e61 100644
--- a/colossalai/shardformer/policies/t5.py
+++ b/colossalai/shardformer/policies/t5.py
@@ -72,6 +72,9 @@ def module_policy(self):
             warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[T5Stack] = ModulePolicyDescription(
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
diff --git a/colossalai/shardformer/policies/vit.py b/colossalai/shardformer/policies/vit.py
index b7883af9f7c8..069ad0c2690c 100644
--- a/colossalai/shardformer/policies/vit.py
+++ b/colossalai/shardformer/policies/vit.py
@@ -44,6 +44,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
             warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[ViTEmbeddings] = ModulePolicyDescription(
                 attribute_replacement={},
                 param_replacement=[],
diff --git a/colossalai/shardformer/policies/whisper.py b/colossalai/shardformer/policies/whisper.py
index aeb6687971e5..441e512bbb28 100644
--- a/colossalai/shardformer/policies/whisper.py
+++ b/colossalai/shardformer/policies/whisper.py
@@ -78,6 +78,9 @@ def module_policy(self):
            warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")
 
         if self.shard_config.enable_tensor_parallelism:
+            assert (
+                self.model.config.encoder_attention_heads % self.shard_config.tensor_parallel_size == 0
+            ), f"The number of attention heads must be divisible by tensor parallel size."
             policy[WhisperEncoderLayer] = ModulePolicyDescription(
                 attribute_replacement={
                     "self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
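
Why the new asserts are needed: each policy's tensor-parallel branch rewrites attributes such as self_attn.num_heads to num_attention_heads // tensor_parallel_size, so every rank must own an equal share of the heads. If the head count is not divisible by the tensor parallel size, the integer division silently drops heads and the sharded projection shapes no longer match, surfacing as an obscure runtime error instead of the clear message added by this patch. The sketch below is illustrative only and is not part of the patch; the heads_per_rank helper and the example numbers are hypothetical, assuming the same divisibility rule the patch enforces.

# Minimal sketch (hypothetical helper, not ColossalAI code): the failure mode
# that the divisibility asserts in this patch guard against.

def heads_per_rank(num_attention_heads: int, tensor_parallel_size: int) -> int:
    """Number of attention heads owned by each tensor-parallel rank."""
    # Same guard the patch adds at the top of every tensor-parallel branch.
    assert (
        num_attention_heads % tensor_parallel_size == 0
    ), "The number of attention heads must be divisible by tensor parallel size."
    return num_attention_heads // tensor_parallel_size


if __name__ == "__main__":
    # 32 heads across 4 ranks: each rank gets exactly 8 heads, and the
    # sharded q/k/v projections keep consistent shapes.
    print(heads_per_rank(32, 4))  # -> 8

    # 30 heads across 4 ranks: without the assert, integer division would
    # silently give 7 heads per rank (28 in total), dropping 2 heads and
    # producing mismatched projection shapes only later at runtime.
    try:
        heads_per_rank(30, 4)
    except AssertionError as exc:
        print(f"rejected: {exc}")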