[shardformer] made tensor parallelism configurable (hpcaitech#4144)
* [shardformer] made tensor parallelism configurable

* polish code
FrankLeeeee authored and ver217 committed Jul 13, 2023
1 parent b5819ab commit cb10853
Showing 15 changed files with 814 additions and 668 deletions.
25 changes: 25 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
@@ -126,3 +126,28 @@ def postprocess(self) -> nn.Module:
the classifier layer
"""
pass

+    def append_or_create_submodule_replacement(
+            self, description: Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]],
+            policy: Dict[Union[str, nn.Module], ModulePolicyDescription],
+            target_key: Union[str, nn.Module]) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
+        r"""
+        Append a submodule replacement description to the policy entry for the given key, creating
+        the entry first if it does not exist yet.
+
+        Args:
+            description (Union[SubModuleReplacementDescription, List[SubModuleReplacementDescription]]): the submodule replacement description(s) to be appended
+            policy (Dict[Union[str, nn.Module], ModulePolicyDescription]): the policy to be updated
+            target_key (Union[str, nn.Module]): the key of the policy to be updated
+
+        Returns:
+            Dict[Union[str, nn.Module], ModulePolicyDescription]: the updated policy
+        """
+        # convert a single description to a list for uniform handling
+        if isinstance(description, SubModuleReplacementDescription):
+            description = [description]
+
+        # append to the existing description or create a new one
+        if target_key in policy:
+            policy[target_key].sub_module_replacement.extend(description)
+        else:
+            policy[target_key] = ModulePolicyDescription(sub_module_replacement=description)
+
+        return policy
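
For orientation, the snippet below sketches how a concrete policy might call this new helper twice with the same key: the first call creates the ModulePolicyDescription, the second appends to it. MyBlockPolicy, the "MyBlock" key and the submodule suffixes are hypothetical placeholders, and the import paths are assumptions about the repository layout rather than part of this diff.

# Illustrative sketch only: MyBlockPolicy, the "MyBlock" key and the suffixes are made up;
# the helper and SubModuleReplacementDescription come from the basepolicy.py change above.
import colossalai.shardformer.layer as col_nn  # import path assumed
from colossalai.shardformer.policies.basepolicy import Policy, SubModuleReplacementDescription


class MyBlockPolicy(Policy):

    def module_policy(self):
        policy = {}

        # No entry for "MyBlock" yet, so the helper creates a fresh ModulePolicyDescription.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="attention.dense",
                                                        target_module=col_nn.Linear1D_Row),
            policy=policy,
            target_key="MyBlock")

        # Same key again: the new description is appended to the existing entry, not overwritten.
        self.append_or_create_submodule_replacement(
            description=SubModuleReplacementDescription(suffix="mlp.dense_h_to_4h",
                                                        target_module=col_nn.Linear1D_Col),
            policy=policy,
            target_key="MyBlock")

        return policy
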
299 changes: 136 additions & 163 deletions colossalai/shardformer/policies/bert.py

Large diffs are not rendered by default, so the changes to colossalai/shardformer/policies/bert.py are not shown here.

166 changes: 84 additions & 82 deletions colossalai/shardformer/policies/bloom.py
@@ -85,57 +85,53 @@ def preprocess(self):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

-        base_policy = {
-            BloomBlock:
-                ModulePolicyDescription(
-                    attribute_replacement={
-                        # 1. shard hidden size
-                        "self_attention.hidden_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        "self_attention.split_size":
-                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
-                        # 2. shard number of heads
-                        "self_attention.num_heads":
-                            self.model.config.n_head // self.shard_config.tensor_parallel_size,
-                    },
-                    sub_module_replacement=[
-                        SubModuleReplacementDescription(
-                            suffix="self_attention.query_key_value",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="self_attention.dense",
-                            target_module=col_nn.Linear1D_Row,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="self_attention.attention_dropout",
-                            target_module=col_nn.DropoutForParallelInput,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="mlp.dense_h_to_4h",
-                            target_module=col_nn.Linear1D_Col,
-                        ),
-                        SubModuleReplacementDescription(
-                            suffix="mlp.dense_4h_to_h",
-                            target_module=col_nn.Linear1D_Row,
-                        ),
-                    ]),
-            BloomModel:
-                ModulePolicyDescription(attribute_replacement={
+        policy = {}
+
+        if self.shard_config.enable_tensor_parallelism:
+            policy[BloomBlock] = ModulePolicyDescription(attribute_replacement={
+                "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
+                "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
+            },
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.query_key_value",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.dense",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="self_attention.attention_dropout",
+                        target_module=col_nn.DropoutForParallelInput,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_h_to_4h",
+                        target_module=col_nn.Linear1D_Col,
+                    ),
+                    SubModuleReplacementDescription(
+                        suffix="mlp.dense_4h_to_h",
+                        target_module=col_nn.Linear1D_Row,
+                    ),
+                ])
+
+            policy[BloomModel] = ModulePolicyDescription(
+                attribute_replacement={
"num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
},
-                                        method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
-                                        sub_module_replacement=[
-                                            SubModuleReplacementDescription(
-                                                suffix="word_embeddings",
-                                                target_module=col_nn.VocabParallelEmbedding1D,
-                                            )
-                                        ])
-        }
+                method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
+                sub_module_replacement=[
+                    SubModuleReplacementDescription(
+                        suffix="word_embeddings",
+                        target_module=col_nn.VocabParallelEmbedding1D,
+                    )
+                ])

# optimization configuration
if self.shard_config.enable_fused_normalization:
-            base_policy[BloomModel].sub_module_replacement.extend([
+            # handle bloom model
+            self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="ln_f",
target_module=col_nn.FusedLayerNorm,
@@ -144,8 +140,12 @@ def module_policy(self):
suffix="word_embeddings_layernorm",
target_module=col_nn.FusedLayerNorm,
)
-            ])
-            base_policy[BloomBlock].sub_module_replacement.extend([
+            ],
+                policy=policy,
+                target_key=BloomModel)
+
+            # handle bloom block
+            self.append_or_create_submodule_replacement(description=[
SubModuleReplacementDescription(
suffix="input_layernorm",
target_module=col_nn.FusedLayerNorm,
@@ -154,9 +154,11 @@ def module_policy(self):
suffix="post_attention_layernorm",
target_module=col_nn.FusedLayerNorm,
)
-            ])
+            ],
+                policy=policy,
+                target_key=BloomBlock)

-        return base_policy
+        return policy

def postprocess(self):
return self.model
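
Because the old and new bodies of BloomPolicy.module_policy are interleaved across the hunks above, here is a condensed reconstruction of the method as it reads after this commit. It is pieced together from the diff itself; several SubModuleReplacementDescription entries are elided for brevity, and the surrounding imports are the ones already present in bloom.py.

def module_policy(self):
    from transformers.models.bloom.modeling_bloom import BloomBlock, BloomModel

    policy = {}

    # Tensor-parallel sharding is now opt-in.
    if self.shard_config.enable_tensor_parallelism:
        policy[BloomBlock] = ModulePolicyDescription(
            attribute_replacement={
                "self_attention.hidden_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attention.split_size": self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                "self_attention.num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
            },
            sub_module_replacement=[
                SubModuleReplacementDescription(suffix="self_attention.query_key_value",
                                                target_module=col_nn.Linear1D_Col),
                SubModuleReplacementDescription(suffix="self_attention.dense",
                                                target_module=col_nn.Linear1D_Row),
                # ... the dropout and MLP replacements are elided here; see the diff above ...
            ])

        policy[BloomModel] = ModulePolicyDescription(
            attribute_replacement={
                "num_heads": self.model.config.n_head // self.shard_config.tensor_parallel_size,
            },
            method_replacement={"build_alibi_tensor": build_bloom_alibi_tensor},
            sub_module_replacement=[
                SubModuleReplacementDescription(suffix="word_embeddings",
                                                target_module=col_nn.VocabParallelEmbedding1D),
            ])

    # Fused LayerNorm is an independent switch; the new helper appends to an existing entry or
    # creates one, so it also works when the tensor-parallel entries above were never created.
    if self.shard_config.enable_fused_normalization:
        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(suffix="ln_f", target_module=col_nn.FusedLayerNorm),
                SubModuleReplacementDescription(suffix="word_embeddings_layernorm",
                                                target_module=col_nn.FusedLayerNorm),
            ],
            policy=policy,
            target_key=BloomModel)

        self.append_or_create_submodule_replacement(
            description=[
                SubModuleReplacementDescription(suffix="input_layernorm", target_module=col_nn.FusedLayerNorm),
                SubModuleReplacementDescription(suffix="post_attention_layernorm",
                                                target_module=col_nn.FusedLayerNorm),
            ],
            policy=policy,
            target_key=BloomBlock)

    return policy
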
@@ -171,27 +173,26 @@ class BloomForCausalLMPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForCausalLM
policy = super().module_policy()
-        # add a new item for casual lm
-        new_item = {
-            BloomForCausalLM:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
-                ])
-        }
-        policy.update(new_item)
+
+        # handle tensor parallelism
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="lm_head", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
+                policy=policy,
+                target_key=BloomForCausalLM)
+
return policy

def postprocess(self):
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}

for k, v in binding_map.items():
param = getattr_(self.model, k)

if not isinstance(param, nn.Parameter):
param = nn.Parameter(param)

# tie weights
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model

@@ -201,15 +202,14 @@ class BloomForSequenceClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForSequenceClassification
policy = super().module_policy()
-        # add a new item for casual lm
-        new_item = {
-            BloomForSequenceClassification:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True))
-                ])
-        }
-        policy.update(new_item)
+
+        # handle tensor parallelism
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=SubModuleReplacementDescription(
+                suffix="score", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
+                policy=policy,
+                target_key=BloomForSequenceClassification)
+
return policy


@@ -218,19 +218,21 @@ class BloomForTokenClassificationPolicy(BloomPolicy):
def module_policy(self):
from transformers.models.bloom.modeling_bloom import BloomForTokenClassification
policy = super().module_policy()
-        # add a new item for casual lm
-        new_item = {
-            BloomForTokenClassification:
-                ModulePolicyDescription(sub_module_replacement=[
-                    SubModuleReplacementDescription(
-                        suffix="classifier", target_module=col_nn.Linear1D_Col, kwargs=dict(gather_output=True)),
-                    SubModuleReplacementDescription(
-                        suffix="dropout",
-                        target_module=col_nn.DropoutForReplicatedInput,
-                    ),
-                ])
-        }
-        policy.update(new_item)
+
+        # handle tensor parallelism
+        if self.shard_config.enable_tensor_parallelism:
+            self.append_or_create_submodule_replacement(description=[
+                SubModuleReplacementDescription(suffix="classifier",
+                                                target_module=col_nn.Linear1D_Col,
+                                                kwargs=dict(gather_output=True)),
+                SubModuleReplacementDescription(
+                    suffix="dropout",
+                    target_module=col_nn.DropoutForReplicatedInput,
+                ),
+            ],
+                policy=policy,
+                target_key=BloomForTokenClassification)
+
return policy
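
The three head policies above share one pattern: reuse the backbone policy from super().module_policy() and, only when tensor parallelism is enabled, replace the task head (lm_head, score, or classifier) with a column-parallel linear whose gather_output=True re-assembles the sharded output so downstream code still receives full-size logits. Below is a minimal sketch of that pattern; the policy class, the "MyHeadModel" key, the "head" suffix, and the import paths are assumptions for illustration, not part of this commit.

# Sketch of the shared head-policy pattern; "MyHeadModel" and the "head" suffix are hypothetical.
import colossalai.shardformer.layer as col_nn  # import path assumed
from colossalai.shardformer.policies.basepolicy import SubModuleReplacementDescription
from colossalai.shardformer.policies.bloom import BloomPolicy  # import path assumed


class MyHeadModelPolicy(BloomPolicy):

    def module_policy(self):
        # Start from the backbone replacements defined by the parent policy.
        policy = super().module_policy()

        # Shard the head only when tensor parallelism is switched on; gather_output=True makes
        # Linear1D_Col all-gather its column shards so the caller receives the full output.
        if self.shard_config.enable_tensor_parallelism:
            self.append_or_create_submodule_replacement(
                description=SubModuleReplacementDescription(suffix="head",
                                                            target_module=col_nn.Linear1D_Col,
                                                            kwargs=dict(gather_output=True)),
                policy=policy,
                target_key="MyHeadModel")  # usually the transformers model class; a string for brevity

        return policy
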


