[shardformer] adapted T5 and LLaMa test to use kit (hpcaitech#4049)
* [shardformer] adapted T5 and LLaMa test to use kit

* polish code
FrankLeeeee authored and FoolPlayer committed Jun 21, 2023
1 parent 32caa9d commit 96ea281
Showing 24 changed files with 242 additions and 171 deletions.
22 changes: 15 additions & 7 deletions colossalai/shardformer/layer/embedding1d.py
@@ -65,21 +65,22 @@ def __init__(self,
dtype: torch.dtype = None,
device: torch.device = None,
process_group: ProcessGroup = None,
gather_output: bool = True,
weight_initializer: Callable = init.normal_(),
*args,
**kwargs):
super().__init__()

self.num_embeddings = num_embeddings
self.embed_dim = embedding_dim
self.embedding_dim = embedding_dim
self.process_group = process_group
self.num_partitions = dist.get_world_size(process_group)
self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)

self.padding_idx = padding_idx
self.embed_args = args
self.embed_kwargs = kwargs
# self.gather_output = gather_output
self.gather_output = gather_output

if device is None:
device = get_current_device()
@@ -95,7 +96,9 @@ def __init__(self,

@staticmethod
def from_native_module(module: nn.Embedding,
process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
*args,
**kwargs) -> "Embedding1D":
r"""
Build a 1D parallelized Embedding from a native nn.Embedding module.
"""
@@ -123,7 +126,9 @@ def from_native_module(module: nn.Embedding,
max_norm=max_norm,
norm_type=norm_type,
scale_grad_by_freq=scale_grad_by_freq,
sparse=sparse)
sparse=sparse,
*args,
**kwargs)

# copy the weight
with torch.no_grad():
@@ -133,7 +138,7 @@ def from_native_module(module: nn.Embedding,
return embedding

def reset_parameters(self, weight_initializer) -> None:
fan_in, fan_out = self.num_embeddings, self.embed_dim
fan_in, fan_out = self.num_embeddings, self.embedding_dim
weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
self._fill_padding_idx_with_zero()

@@ -144,6 +149,9 @@ def _fill_padding_idx_with_zero(self) -> None:

def forward(self, input_: Tensor) -> Tensor:
output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)

return output
if self.gather_output:
output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
return output
else:
return output_parallel
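
For context, the new `gather_output` flag mirrors the flag of the same name on `Linear1D_Col`: when `True` (the default) the shard-local embedding is all-gathered along the last dimension so every rank sees the full `embedding_dim`; when `False` each rank keeps only its own partition. A minimal shape sketch, assuming an illustrative tensor-parallel world size of 2 (all sizes below are made-up example values):

```python
# Shape sketch only -- not Embedding1D itself; world size and dims are illustrative.
world_size = 2
num_embeddings, embedding_dim = 1000, 128
embed_dim_per_partition = embedding_dim // world_size    # 64, as divide() computes

batch, seq_len = 2, 16
shard_shape = (batch, seq_len, embed_dim_per_partition)  # gather_output=False
full_shape = (batch, seq_len, embedding_dim)             # gather_output=True
print(shard_shape, full_shape)                           # (2, 16, 64) (2, 16, 128)
```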
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/llama.py
@@ -4,7 +4,7 @@
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer.layers import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D
from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

3 changes: 1 addition & 2 deletions colossalai/shardformer/policies/t5.py
@@ -11,8 +11,7 @@
T5Stack,
)

from colossalai.shardformer.layer.dropout import Dropout1D
from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import Dropout1D, Embedding1D, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

11 changes: 9 additions & 2 deletions colossalai/shardformer/shard/sharder.py
@@ -185,7 +185,14 @@ def _replace_sub_module(
if description.ignore_if_not_exist and native_sub_module is None:
continue

replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
f" with {target_module.__qualname__} with the exception: {e}. "
"Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
)

setattr_(org_layer, suffix, replace_layer)
2 changes: 1 addition & 1 deletion colossalai/testing/comparison.py
@@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
assert torch.allclose(
out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
else:
assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
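
For reference, a minimal usage sketch of `assert_hf_output_close`; only the `atol`/`rtol` keywords visible in this hunk are used, and the dicts stand in for HuggingFace-style outputs (the helper's recursion into dict values is assumed from the `track_name` bookkeeping above):

```python
import torch

from colossalai.testing.comparison import assert_hf_output_close

# Two HF-style output dicts whose tensors differ well within tolerance.
out1 = {'logits': torch.ones(2, 4)}
out2 = {'logits': torch.ones(2, 4) + 1e-7}
assert_hf_output_close(out1, out2, atol=1e-5, rtol=1e-5)  # passes silently
```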
30 changes: 19 additions & 11 deletions tests/kit/model_zoo/registry.py
@@ -28,27 +28,35 @@ def register(self,
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
loss_fn: Callable = None,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.
Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
        >>> data = resnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
```python
# normal forward workflow
model = resnet18()
data = resnet18_data_gen()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)
# Register
model_zoo = ModelZooRegistry()
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
```
Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
        loss_fn (Callable): A function that computes the loss from the given output. Defaults to None.
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

def get_sub_registry(self, keyword: str):
"""
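
With `loss_fn` added, every registry entry is now a 5-tuple; tests that do not need the loss simply discard the extra slot, as the updated loops below do. A minimal sketch of the full workflow described in the docstring (import path as used by the tests):

```python
from tests.kit.model_zoo import model_zoo

# Run every registered model end to end: build, forward, transform, loss.
for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, attribute) in model_zoo.items():
    model = model_fn()                   # factories take no arguments
    data = data_gen_fn()
    output = model(**data)
    transformed_output = output_transform_fn(output)
    if loss_fn is not None:              # loss_fn defaults to None
        loss = loss_fn(transformed_output)
```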
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
76 changes: 76 additions & 0 deletions tests/kit/model_zoo/transformers/llama.py
@@ -0,0 +1,76 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
HAS_LLAMA = True
except ImportError:
HAS_LLAMA = False

if HAS_LLAMA:
# ===============================
# Register LLaMA
# ===============================

def data_gen():
        # the input ids correspond to the sentence
# 'Hello, my dog is cute'
#
        # the code is given below:
# -----------------------------------
# from transformers import LlamaTokenizerFast
# tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------

input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)

    # labels are needed for causal lm
    def data_gen_for_causal_lm():
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data

# transform the output to a dict
output_transform_fn = lambda x: x

# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
    loss_fn_for_causal_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)

# register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaForSequenceClassification,
model_zoo.register(name='transformers_llama',
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
    model_zoo.register(name='transformers_llama_for_causal_lm',
                       model_fn=lambda: transformers.LlamaForCausalLM(config),
                       data_gen_fn=data_gen_for_causal_lm,
                       output_transform_fn=output_transform_fn,
                       loss_fn=loss_fn_for_causal_lm,
                       model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_sequence_classification',
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True))
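
As a quick smoke test, one of the entries registered above can be consumed directly from the zoo; a sketch assuming a `transformers` version with LLaMA support installed:

```python
from tests.kit.model_zoo import model_zoo

# Forward the tiny causal-LM config on CPU and backprop its loss.
model_fn, data_gen_fn, output_transform_fn, loss_fn, _ = model_zoo['transformers_llama_for_causal_lm']
model = model_fn()
output = model(**data_gen_fn())
loss = loss_fn(output_transform_fn(output))  # output.loss for the causal LM
loss.backward()
```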
53 changes: 41 additions & 12 deletions tests/kit/model_zoo/transformers/t5.py
@@ -6,41 +6,70 @@
# ===============================
# Register single-sentence T5
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


# define data gen function
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# Generated from following code snippet
#
# from transformers import T5Config, T5Tokenizer
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
return dict(input_ids=input_ids)


def data_gen_for_conditional_generation():
    # labels are generated with the following code
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data


def data_gen_for_t5_model():
    # decoder_input_ids is obtained with the following code
#
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data


# output transform function
output_transform_fn = lambda x: x

config = transformers.T5Config(d_model=128, num_layers=2)
# define loss function
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss

# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

# register the following models
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True))
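
Note that `decoder_input_ids` in `data_gen_for_t5_model` is exactly `input_ids` shifted right, with `decoder_start_token_id` (0) prepended and the trailing EOS dropped; a tiny self-check reproducing T5's `_shift_right` by hand on the tensors above:

```python
import torch

input_ids = torch.tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]])
# Prepend decoder_start_token_id=0 and drop the last token.
shifted = torch.cat([torch.zeros(1, 1, dtype=torch.long), input_ids[:, :-1]], dim=-1)
assert shifted.tolist() == [[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]
```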
2 changes: 1 addition & 1 deletion tests/test_booster/test_mixed_precision/test_fp16_torch.py
@@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
        # dlrm_interactionarch has no parameters, so skip
if name == 'dlrm_interactionarch':
continue
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
@@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
@@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# FIXME(ver217): fix these models
if name in ignore_models:
skipped_models.append(name)
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_ddp_plugin.py
@@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_ddp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
@@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'
4 changes: 2 additions & 2 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
@@ -47,7 +47,7 @@ def test_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()
@@ -60,7 +60,7 @@ def test_torch_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
model = model_fn()
output = model(**data)
@@ -56,7 +56,7 @@ def test_timm_models():

sub_model_zoo = model_zoo.get_sub_registry('timm')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}