Skip to content

Commit

Permalink
[Shardformer] add assert for num of attention heads divisible by tp_size (#5670)
Browse files Browse the repository at this point in the history

* add assert for num of attention heads divisible by tp_size

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
wangbluo and pre-commit-ci[bot] authored Apr 29, 2024
1 parent 6af6d6f commit d3f34ee
Show file tree
Hide file tree
Showing 13 changed files with 48 additions and 0 deletions.
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,9 @@ def module_policy(self):
sp_partial_derived = sp_mode == "split_gather"

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[BertLayer] = ModulePolicyDescription(
attribute_replacement={
"attention.self.all_head_size": self.model.config.hidden_size
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/blip2.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def module_policy(self):
norm_cls = col_nn.LayerNorm

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[Blip2EncoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.num_heads": self.model.config.vision_config.num_attention_heads
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/bloom.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,9 @@ def module_policy(self):
sp_partial_derived = sp_mode == "split_gather"

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.n_head % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[BloomBlock] = ModulePolicyDescription(
attribute_replacement={
"self_attention.hidden_size": self.model.config.hidden_size
Expand Down
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,12 @@ def module_policy(self):
embedding_cls = col_nn.PaddingEmbedding

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
assert (
self.model.config.num_kv_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
attn_attribute_replacement = {
"self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/gpt2.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,9 @@ def module_policy(self):
self.shard_config.enable_flash_attention = False
use_flash_attention = False
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[GPT2Model] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/gptj.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ def module_policy(self):

overlap = self.shard_config.enable_sequence_overlap
if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[GPTJModel] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
Expand Down
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
Expand Down
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,12 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
)

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
assert (
self.model.config.num_key_value_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of key_value heads must be divisible by tensor parallel size."
decoder_attribute_replacement = {
"self_attn.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
"self_attn.num_heads": self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,9 @@ def module_policy(self):
warnings.warn("OPT doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[OPTDecoderLayer] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/sam.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,9 @@ def module_policy(self):
norm_cls = col_nn.LayerNorm

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.vision_config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[SamVisionLayer] = ModulePolicyDescription(
attribute_replacement={
"attn.num_attention_heads": self.model.config.vision_config.num_attention_heads
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ def module_policy(self):
warnings.warn("T5 doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[T5Stack] = ModulePolicyDescription(
sub_module_replacement=[
SubModuleReplacementDescription(
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/vit.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,9 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
warnings.warn("Vit doesn't support sequence parallelism now, will ignore the sequence parallelism flag.")

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.num_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[ViTEmbeddings] = ModulePolicyDescription(
attribute_replacement={},
param_replacement=[],
Expand Down
3 changes: 3 additions & 0 deletions colossalai/shardformer/policies/whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ def module_policy(self):
warnings.warn("Whisper doesn't support jit fused operator now, will ignore the jit fused operator flag.")

if self.shard_config.enable_tensor_parallelism:
assert (
self.model.config.encoder_attention_heads % self.shard_config.tensor_parallel_size == 0
), f"The number of attention heads must be divisible by tensor parallel size."
policy[WhisperEncoderLayer] = ModulePolicyDescription(
attribute_replacement={
"self_attn.embed_dim": self.model.config.d_model // self.shard_config.tensor_parallel_size,
Expand Down

0 comments on commit d3f34ee

Please sign in to comment.