Skip to content

Commit

Permalink
[shardformer] adapted T5 and LLaMa test to use kit
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee committed Jun 20, 2023
1 parent 436975c commit 78a6686
Show file tree
Hide file tree
Showing 21 changed files with 227 additions and 161 deletions.
13 changes: 11 additions & 2 deletions colossalai/shardformer/shard/sharder.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,16 @@ def _replace_sub_module(
if description.ignore_if_not_exist and native_sub_module is None:
continue

replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
try:
replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'],
**kwargs)
except Exception as e:
raise RuntimeError(
f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
"========== Error =========="
f" with {target_module.__qualname__} with the following exception:\n{e}"
"==========================="
"Please check your model configuration or sharding policy, you can set up an issue for us to help you as well."
)

setattr_(org_layer, suffix, replace_layer)
2 changes: 1 addition & 1 deletion colossalai/testing/comparison.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any,
raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}")
assert torch.allclose(
out1, out2, atol=atol, rtol=rtol
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}"
), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}"
else:
assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}"
30 changes: 19 additions & 11 deletions tests/kit/model_zoo/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,27 +28,35 @@ def register(self,
model_fn: Callable,
data_gen_fn: Callable,
output_transform_fn: Callable,
loss_fn: Callable = None,
model_attribute: ModelAttribute = None):
"""
Register a model and data generation function.
Examples:
>>> # Register
>>> model_zoo = ModelZooRegistry()
>>> model_zoo.register('resnet18', resnet18, resnet18_data_gen)
>>> # Run the model
>>> data = resnresnet18_data_gen() # do not input any argument
>>> model = resnet18() # do not input any argument
>>> out = model(**data)
```python
# normal forward workflow
model = resnet18()
data = resnet18_data_gen()
output = model(**data)
transformed_output = output_transform_fn(output)
loss = loss_fn(transformed_output)
# Register
model_zoo = ModelZooRegistry()
model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn)
```
Args:
name (str): Name of the model.
model_fn (callable): A function that returns a model. **It must not contain any arguments.**
output_transform_fn (callable): A function that transforms the output of the model into Dict.
data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
model_fn (Callable): A function that returns a model. **It must not contain any arguments.**
data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.**
output_transform_fn (Callable): A function that transforms the output of the model into Dict.
loss_fn (Callable): a function to compute the loss from the given output. Defaults to None
model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
"""
self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)

def get_sub_registry(self, keyword: str):
"""
Expand Down
1 change: 1 addition & 0 deletions tests/kit/model_zoo/transformers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from .albert import *
from .bert import *
from .gpt import *
from .llama import *
from .opt import *
from .t5 import *
76 changes: 76 additions & 0 deletions tests/kit/model_zoo/transformers/llama.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
import torch
import transformers

from ..registry import ModelAttribute, model_zoo

try:
from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
HAS_LLAMA = True
except ImportError:
HAS_LLAMA = False

if HAS_LLAMA:
# ===============================
# Register LLaMA
# ===============================

def data_gen():
# the input ids are corresponding to the sentence
# 'Hello, my dog is cute'
#
# the code is give below:
# -----------------------------------
# from transformers import LlamaTokenizerFast
# tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
# input = 'Hello, my dog is cute'
# tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
# -----------------------------------

input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
return dict(input_ids=input_ids, attention_mask=attention_mask)

# label is needed for casual lm
def data_gen_for_casual_lm():
data = data_gen()
labels = data['input_ids'].clone()
data['labels'] = labels
return data

# transform the output to a dict
output_transform_fn = lambda x: x

# function to get the loss
loss_fn = lambda output: output.last_hidden_state.mean()
loss_fn_for_casual_lm = lambda output: output.loss
loss_fn_for_seq_classification = lambda output: output.logits.mean()

config = LlamaConfig(num_hidden_layers=4,
hidden_size=128,
intermediate_size=256,
num_attention_heads=4,
max_position_embeddings=128,
num_labels=16)

# register the following models
# transformers.LlamaModel,
# transformers.LlamaForCausalLM,
# transformers.LlamaForSequenceClassification,
model_zoo.register(name='transformers_llama',
model_fn=lambda: transformers.LlamaModel(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_casual_lm',
model_fn=lambda: transformers.LlamaForCausalLM(config),
data_gen_fn=data_gen_for_casual_lm,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_casual_lm,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_llama_for_sequence_classification',
model_fn=lambda: transformers.LlamaForSequenceClassification(config),
data_gen_fn=data_gen,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_seq_classification,
model_attribute=ModelAttribute(has_control_flow=True))
53 changes: 41 additions & 12 deletions tests/kit/model_zoo/transformers/t5.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,41 +6,70 @@
# ===============================
# Register single-sentence T5
# ===============================
BATCH_SIZE = 2
SEQ_LENGTH = 16


def data_gen():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)


# define data gen function
def data_gen_for_encoder_only():
input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
# Generated from following code snippet
#
# from transformers import T5Config, T5Tokenizer
# config = T5Config(decoder_start_token_id=0)
# tokenizer = T5Tokenizer.from_pretrained("t5-small")
# input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
return dict(input_ids=input_ids)


def data_gen_for_conditional_generation():
# labels is generated with the following code
#
# labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
data = data_gen_for_encoder_only()
labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
data['labels'] = labels
return data


def data_gen_for_t5_model():
# decoder_inputs_ids is obtained with the following code
#
# decoder_input_ids = model._shift_right(input_ids)
data = data_gen_for_encoder_only()
decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
data['decoder_input_ids'] = decoder_input_ids
return data


# output transform function
output_transform_fn = lambda x: x

config = transformers.T5Config(d_model=128, num_layers=2)
# define loss funciton
loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
loss_fn_for_conditional_generation = lambda x: x.loss

# define model config
config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)

# register the following models
# transformers.T5Model,
# transformers.T5ForConditionalGeneration,
# transformers.T5EncoderModel,
model_zoo.register(name='transformers_t5',
model_fn=lambda: transformers.T5Model(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_t5_model,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_t5_model,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_for_conditional_generation',
model_fn=lambda: transformers.T5ForConditionalGeneration(config),
data_gen_fn=data_gen,
data_gen_fn=data_gen_for_conditional_generation,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_conditional_generation,
model_attribute=ModelAttribute(has_control_flow=True))
model_zoo.register(name='transformers_t5_encoder_model',
model_fn=lambda: transformers.T5EncoderModel(config),
data_gen_fn=data_gen_for_encoder_only,
output_transform_fn=output_transform_fn,
loss_fn=loss_fn_for_encoder_only,
model_attribute=ModelAttribute(has_control_flow=True))
2 changes: 1 addition & 1 deletion tests/test_booster/test_mixed_precision/test_fp16_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port):
# init dist env
colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost')
sub_model_zoo = model_zoo.get_sub_registry('timm')
for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items():
# dlrm_interactionarch has not parameters, so skip
if name == 'dlrm_interactionarch':
continue
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_gemini_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True):
passed_models = []
failed_info = {} # (model_name, error) pair

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# These models lead to CUDA error
if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp',
'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True):
ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS
skipped_models = []

for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
# FIXME(ver217): fix these models
if name in ignore_models:
skipped_models.append(name)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_ddp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_ddp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if name == 'dlrm_interactionarch':
continue
run_fn(model_fn, data_gen_fn, output_transform_fn)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_booster/test_plugin/test_torch_fsdp_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn):


def check_torch_fsdp_plugin():
for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items():
if any(element in name for element in [
'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet',
'torchvision_inception_v3'
Expand Down
4 changes: 2 additions & 2 deletions tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def test_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
trace_and_compare(model_fn, data, output_transform_fn)
torch.cuda.synchronize()
Expand All @@ -60,7 +60,7 @@ def test_torch_diffusers():

sub_model_zoo = model_zoo.get_sub_registry('diffusers')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
model = model_fn()
output = model(**data)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ def test_timm_models():

sub_model_zoo = model_zoo.get_sub_registry('timm')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
data = data_gen_fn()
if attribute is not None and attribute.has_control_flow:
meta_args = {k: v.to('meta') for k, v in data.items()}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def test_torchaudio_models():

sub_model_zoo = model_zoo.get_sub_registry('torchaudio')

for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items():
for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items():
model = model_fn()
trace_and_compare(model,
data_gen_fn,
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lazy/lazy_init_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn:


def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None:
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
_MyTensor._pre_op_fn = lambda *args: set_seed(seed)
LazyTensor._pre_op_fn = lambda *args: set_seed(seed)
ctx = LazyInitContext(tensor_cls=_MyTensor)
Expand Down
2 changes: 1 addition & 1 deletion tests/test_lazy/test_distribute.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ def run_dist_lazy_init(subset, seed: int = 42):
if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'):
continue
print_rank_0(name)
model_fn, data_gen_fn, output_transform_fn, model_attr = entry
model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry
ctx = LazyInitContext(tensor_cls=_MyTensor)
with ctx:
model = model_fn()
Expand Down
Empty file.
Empty file.
38 changes: 38 additions & 0 deletions tests/test_shardformer/test_model/_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import copy

from colossalai.shardformer import ShardConfig, ShardFormer


def build_model(world_size, model_fn):
# create new model
org_model = model_fn().cuda()

# shard model
shard_config = ShardConfig(tensor_parallel_size=world_size)
model_copy = copy.deepcopy(org_model)
shard_former = ShardFormer(shard_config=shard_config)
shard_former.init_distributed()
sharded_model = shard_former.shard_model(model_copy)

return org_model, sharded_model


def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
# prepare input
data = data_gen_fn()
data = {k: v.cuda() for k, v in data.items()}

# switch to train mode
original_model.train()
sharded_model.train()

# run forward
org_output = original_model(**data)
org_output = output_transform_fn(org_output)
org_loss = loss_fn(org_output)

shard_output = sharded_model(**data)
shard_output = output_transform_fn(shard_output)
shard_loss = loss_fn(shard_output)

return org_output, org_loss, shard_output, shard_loss
Loading

0 comments on commit 78a6686

Please sign in to comment.