[shardformer] added embedding gradient check #4124

Merged
4 changes: 2 additions & 2 deletions colossalai/shardformer/_utils.py
@@ -55,7 +55,7 @@ def setattr_(obj, attr: str, value, ignore: bool = False):
except AttributeError:
if ignore:
return
raise AttributeError(f"Object {obj} has no attribute {attr}")
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
setattr(obj, attrs[-1], value)


@@ -76,5 +76,5 @@ def getattr_(obj, attr: str, ignore: bool = False):
except AttributeError:
if ignore:
return None
raise AttributeError(f"Object {obj} has no attribute {attr}")
raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
return obj
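
For context, both helpers walk a dotted attribute path, and the change reports `obj.__class__.__name__` instead of the object itself so the error stays short rather than embedding the module's full repr. A simplified sketch of the lookup side, as an illustration rather than the library source:

```python
# Simplified sketch of a dotted-path lookup helper, mirroring the
# error-message change above; illustrative only, not the library code.
def getattr_(obj, attr: str, ignore: bool = False):
    for name in attr.split('.'):
        try:
            obj = getattr(obj, name)
        except AttributeError:
            if ignore:
                return None
            # The class name is enough to identify the offending module;
            # printing the whole object would dump its repr into the traceback.
            raise AttributeError(f"Object {obj.__class__.__name__} has no attribute {attr}")
    return obj
```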
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/bert.py
@@ -97,7 +97,7 @@ def module_policy(self):
),
SubModuleReplacementDescription(
suffix="dropout",
target_module=col_nn.DropoutForParallelInput,
target_module=col_nn.DropoutForReplicatedInput,
)
])
}
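
The dropout swap matters because the BERT embedding output is replicated across the tensor-parallel ranks, so every rank must apply the same dropout mask or the replicas drift apart. A toy sketch of the idea, assuming the usual Megatron-style convention; this is not ColossalAI's `DropoutForReplicatedInput` implementation, and the names are made up:

```python
import torch

# Toy illustration: dropout on a replicated activation must draw its mask from
# an RNG state that is identical on every tensor-parallel rank. `shared_seed`
# stands in for a seed agreed on by the whole TP group.
def dropout_for_replicated_input(x: torch.Tensor, p: float, shared_seed: int) -> torch.Tensor:
    gen = torch.Generator().manual_seed(shared_seed)            # same seed on every rank
    keep = (torch.rand(x.shape, generator=gen) >= p).to(x.dtype).to(x.device)
    return x * keep / (1.0 - p)                                  # inverted dropout scaling
```

A dropout meant for already-partitioned inputs can instead use per-rank randomness, which is why the two layer types exist.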
19 changes: 16 additions & 3 deletions colossalai/shardformer/policies/bloom.py
@@ -1,8 +1,10 @@
import torch
import torch.distributed as dist
import torch.nn as nn

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


@@ -73,7 +75,6 @@ def preprocess(self):
r"""
Reshape the Embedding layer to make the vocabulary size divisible by world_size
"""
# TODO:
vocab_size = self.model.config.vocab_size
world_size = self.shard_config.tensor_parallel_size
if vocab_size % world_size != 0:
@@ -161,13 +162,12 @@ def module_policy(self):

def new_model_class(self):
# do nothing
return self.model
return None

def postprocess(self):
return self.model


# BertModel
class BloomModelPolicy(BloomPolicy):
pass

@@ -191,6 +191,19 @@ def module_policy(self):
policy.update(new_item)
return policy

def postprocess(self):
binding_map = {"transformer.word_embeddings.weight": "lm_head.weight"}
for k, v in binding_map.items():
param = getattr_(self.model, k)

if not isinstance(param, nn.Parameter):
param = nn.Parameter(param)

# tie weights
setattr_(self.model, k, param)
setattr_(self.model, v, param)
return self.model


class BloomForSequenceClassificationPolicy(BloomPolicy):

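
The new `postprocess` re-ties the input embedding and the LM head after sharding by pointing both attributes at a single `nn.Parameter`. A standalone sketch of that pattern on a toy module; the class and sizes are hypothetical, not the Bloom model:

```python
import torch.nn as nn

# Toy model with the same attribute shape as the tied pair above.
class TinyLM(nn.Module):
    def __init__(self, vocab: int = 8, hidden: int = 4):
        super().__init__()
        self.word_embeddings = nn.Embedding(vocab, hidden)
        self.lm_head = nn.Linear(hidden, vocab, bias=False)

model = TinyLM()
param = model.word_embeddings.weight
if not isinstance(param, nn.Parameter):      # defensive cast, mirroring the policy code
    param = nn.Parameter(param)

# Assigning the same Parameter to both modules ties the weights: gradients from
# the LM head and from the embedding now accumulate into one tensor, which is
# what the embedding gradient check in the tests relies on.
model.word_embeddings.weight = param
model.lm_head.weight = param
assert model.lm_head.weight is model.word_embeddings.weight
```

The OPT policy below reaches the same end state by assigning `src_mod.weight` to `dst_mod.weight` directly on the unwrapped modules.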
17 changes: 15 additions & 2 deletions colossalai/shardformer/policies/opt.py
@@ -1,5 +1,6 @@
from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import FusedLayerNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
@@ -35,7 +36,7 @@ def module_policy(self):
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
target_module=VocabParallelEmbedding1D,
)
]),
OPTDecoderLayer:
@@ -127,6 +128,18 @@ def module_policy(self):
policy.update(new_item)
return policy

def postprocess(self):
binding_map = {
'model.decoder.embed_tokens': 'lm_head',
}

for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight

return self.model


class OPTForSequenceClassificationPolicy(OPTPolicy):

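
Switching `embed_tokens` from `Embedding1D` to `VocabParallelEmbedding1D` changes which dimension is sharded: roughly, instead of splitting along the embedding dimension, each rank keeps a contiguous slice of the vocabulary, zeroes out token ids it does not own, and the partial lookups are summed across the group. A single-process sketch of that lookup with made-up sizes; the real layer also performs the `all_reduce`:

```python
import torch

# Vocabulary-parallel lookup, simulated on one process.
full_weight = torch.randn(10, 4)                    # vocab_size=10, hidden=4
world_size, rank = 2, 1
shard = full_weight.shape[0] // world_size
start, end = rank * shard, (rank + 1) * shard
local_weight = full_weight[start:end]               # this rank's vocab slice

ids = torch.tensor([1, 7, 9])
owned = (ids >= start) & (ids < end)                # which ids this rank owns
local_ids = (ids - start).clamp(0, shard - 1)
partial = local_weight[local_ids] * owned.unsqueeze(-1)   # zeros for foreign ids
# In the distributed layer, dist.all_reduce(partial) would sum the per-rank
# partial lookups so every rank ends up with the full embedding output.
```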
105 changes: 87 additions & 18 deletions colossalai/shardformer/policies/t5.py
@@ -1,11 +1,20 @@
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import (
DropoutForParallelInput,
Embedding1D,
FusedRMSNorm,
Linear1D_Col,
Linear1D_Row,
VocabParallelEmbedding1D,
)
from colossalai.shardformer.policies.basepolicy import ModulePolicyDescription

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"]


class T5ModelPolicy(Policy):
class T5BasePolicy(Policy):

def config_sanity_check(self):
pass
@@ -33,14 +42,18 @@ def module_policy(self):
T5Stack,
)

return {
base_policy = {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="dropout",
target_module=DropoutForParallelInput,
),
SubModuleReplacementDescription(
suffix="embed_tokens",
target_module=Embedding1D,
)
]),
T5LayerSelfAttention:
@@ -158,30 +171,86 @@ def new_model_class(self):
return None

def postprocess(self):
binding_map = [["shared", "encoder.embed_tokens"], ["shared", "decoder.embed_tokens"]]

for k, v in binding_map:
mod = getattr_(self.model, k)
setattr_(self.model, v, mod)
return self.model


class T5ForConditionalGenerationPolicy(T5ModelPolicy):
class T5ModelPolicy(T5BasePolicy):

def module_policy(self):
from transformers import T5Model

base_policy = super().module_policy()
base_policy[T5Model] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
return base_policy


class T5ForConditionalGenerationPolicy(T5BasePolicy):

def module_policy(self):
from transformers import T5ForConditionalGeneration

policy = super().module_policy()
policy[T5ForConditionalGeneration] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
),
SubModuleReplacementDescription(
suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
return policy

new_item = {
T5ForConditionalGeneration:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(suffix="lm_head",
target_module=Linear1D_Col,
kwargs=dict(gather_output=True))
])
}
def postprocess(self):
super().postprocess()

binding_map = {"shared": "lm_head"}

for k, v in binding_map.items():
src_mod = getattr_(self.model, k)
dst_mod = getattr_(self.model, v)
dst_mod.weight = src_mod.weight

return self.model

policy.update(new_item)
return policy

class T5EncoderPolicy(T5BasePolicy):

class T5EncoderPolicy(T5ModelPolicy):
pass
def module_policy(self):
from transformers import T5EncoderModel

base_policy = super().module_policy()
base_policy[T5EncoderModel] = ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
sub_module_replacement=[
SubModuleReplacementDescription(
suffix="shared",
target_module=VocabParallelEmbedding1D,
)
])
return base_policy

def postprocess(self):
binding_map = [
["shared", "encoder.embed_tokens"],
]

for k, v in binding_map:
mod = getattr_(self.model, k)
setattr_(self.model, v, mod)
return self.model
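
The refactor pulls the replacements shared by every T5 variant into `T5BasePolicy` and lets each subclass layer its model-specific entries and weight bindings on top through a plain dict update plus a chained `postprocess`. Stripped to its skeleton, with placeholder class names and values rather than the real descriptions, the pattern is:

```python
# Skeleton of the base/derived policy split above, with placeholder values.
class BasePolicy:
    def module_policy(self) -> dict:
        # replacements shared by every variant (stacks, attention, FFN, ...)
        return {"T5Stack": ["dropout", "embed_tokens"]}

class ForConditionalGenerationPolicy(BasePolicy):
    def module_policy(self) -> dict:
        policy = super().module_policy()
        # model-specific top-level replacements layered on the shared ones
        policy["T5ForConditionalGeneration"] = ["shared", "lm_head"]
        return policy

print(ForConditionalGenerationPolicy().module_policy())
# {'T5Stack': ['dropout', 'embed_tokens'], 'T5ForConditionalGeneration': ['shared', 'lm_head']}
```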
11 changes: 0 additions & 11 deletions colossalai/shardformer/shard/sharder.py
@@ -38,17 +38,6 @@ def shard(self) -> None:
self._replace_module()
self._postprocess()

def reshape_embedding(self) -> None:
r"""
Reshape the Embedding layer to make the embedding dimension divisible by world_size
"""
vocab_size = self.model_config.vocab_size
world_size = self.shard_config.world_size
if vocab_size % world_size != 0:
new_vocab_size = vocab_size + world_size - vocab_size % world_size
self.model.resize_token_embeddings(new_vocab_size)
self.model_config = self.model.config

def _preprocess(self) -> None:
self.model = self.policy.preprocess()

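
`reshape_embedding` is removed from the sharder because each policy's `preprocess` now performs the same vocabulary padding itself (see the Bloom hunk above): round the vocabulary up to the next multiple of the tensor-parallel size. The helper name below is just for illustration:

```python
# The padding rule that moved from the sharder into the policies' preprocess.
def pad_vocab(vocab_size: int, world_size: int) -> int:
    if vocab_size % world_size == 0:
        return vocab_size
    return vocab_size + world_size - vocab_size % world_size

assert pad_vocab(50257, 4) == 50260   # e.g. a GPT-2-sized vocab on 4 ranks
assert pad_vocab(32128, 2) == 32128   # already divisible, left unchanged
```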
2 changes: 2 additions & 0 deletions tests/kit/model_zoo/registry.py
@@ -70,6 +70,8 @@ def get_sub_registry(self, keyword: str):
for k, v in self.items():
if keyword in k:
new_dict[k] = v

assert len(new_dict) > 0, f'No model found with keyword {keyword}'
return new_dict


29 changes: 22 additions & 7 deletions tests/test_shardformer/test_model/test_shard_bert.py
@@ -18,20 +18,35 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
org_loss.backward()
shard_loss.backward()

# check grad equality
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"

# check grad

if org_model.__class__.__name__ == 'BertModel':
org_grad = org_model.encoder.layer[0].attention.self.query.weight.grad
shard_grad = sharded_model.encoder.layer[0].attention.self.query.weight.grad
bert = org_model
sharded_bert = sharded_model
else:
org_grad = org_model.bert.encoder.layer[0].attention.self.query.weight.grad
shard_grad = sharded_model.bert.encoder.layer[0].attention.self.query.weight.grad
bert = org_model.bert
sharded_bert = sharded_model.bert

# compare self attention grad
org_grad = bert.encoder.layer[0].attention.self.query.weight.grad
shard_grad = sharded_bert.encoder.layer[0].attention.self.query.weight.grad

shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
# compare embedding grad
org_grad = bert.embeddings.word_embeddings.weight.grad
shard_grad = sharded_bert.embeddings.word_embeddings.weight.grad

shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

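
Both new comparisons follow the same recipe: all-gather the sharded gradient from every rank, concatenate along the sharded dimension, and compare against the unsharded model's gradient. A generic form of that check, assuming an initialized process group; the helper is hypothetical and not part of the test suite:

```python
import torch
import torch.distributed as dist

# Generic sharded-gradient check: gather each rank's shard and compare the
# reassembled gradient with the unsharded reference. Assumes dist is initialized.
def assert_sharded_grad_matches(org_param, shard_param, world_size: int, dim: int = 0, atol: float = 1e-5):
    shard_grad = shard_param.grad
    gathered = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(gathered, shard_grad)          # collect every rank's shard
    full_grad = torch.cat(gathered, dim=dim)       # reassemble along the sharded dim
    assert torch.allclose(org_param.grad, full_grad, atol=atol), \
        f"sharded grad does not match original grad\n{org_param.grad}\n{full_grad}"
```

For the vocab-parallel embedding the gather dimension is 0, matching the `dim=0` concatenation used in the tests.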
30 changes: 23 additions & 7 deletions tests/test_shardformer/test_model/test_shard_bloom.py
@@ -18,20 +18,36 @@ def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transfo
org_loss.backward()
shard_loss.backward()

# check grad equality
assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"

# unwrap model
if org_model.__class__.__name__ == 'BloomModel':
org_grad = org_model.h[0].self_attention.query_key_value.weight.grad
shard_grad = sharded_model.h[0].self_attention.query_key_value.weight.grad
bloom = org_model
sharded_bloom = sharded_model
else:
org_grad = org_model.transformer.h[0].self_attention.query_key_value.weight.grad
shard_grad = sharded_model.transformer.h[0].self_attention.query_key_value.weight.grad
bloom = org_model.transformer
sharded_bloom = sharded_model.transformer

# check attention grad
org_grad = bloom.h[0].self_attention.query_key_value.weight.grad
shard_grad = sharded_bloom.h[0].self_attention.query_key_value.weight.grad

shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)

assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"

# check embedding grad
org_grad = bloom.word_embeddings.weight.grad
shard_grad = sharded_bloom.word_embeddings.weight.grad

shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)]
torch.distributed.all_gather(shard_grad_list, shard_grad)
all_shard_grad = torch.cat(shard_grad_list, dim=0)

assert torch.allclose(org_loss, shard_loss,
atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
assert torch.allclose(org_grad, all_shard_grad,
atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"
