diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index f6ade26b758a..ba660935926a 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -185,7 +185,16 @@ def _replace_sub_module( if description.ignore_if_not_exist and native_sub_module is None: continue - replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], - **kwargs) + try: + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], + **kwargs) + except Exception as e: + raise RuntimeError( + f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}" + "========== Error ==========" + f" with {target_module.__qualname__} with the following exception:\n{e}" + "===========================" + "Please check your model configuration or sharding policy, you can set up an issue for us to help you as well." + ) setattr_(org_layer, suffix, replace_layer) diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index aeecee7f11f5..5cbfb936b144 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -98,6 +98,6 @@ def assert_hf_output_close(out1: Any, raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") assert torch.allclose( out1, out2, atol=atol, rtol=rtol - ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}" + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, \nmean error: {torch.abs(out1 - out2).mean()}" else: assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/tests/kit/model_zoo/registry.py b/tests/kit/model_zoo/registry.py index 7470327a65b6..a83c71a90b1a 100644 --- a/tests/kit/model_zoo/registry.py +++ b/tests/kit/model_zoo/registry.py @@ -28,27 +28,35 @@ def register(self, model_fn: Callable, data_gen_fn: Callable, output_transform_fn: Callable, + loss_fn: Callable = None, model_attribute: ModelAttribute = None): """ Register a model and data generation function. Examples: - >>> # Register - >>> model_zoo = ModelZooRegistry() - >>> model_zoo.register('resnet18', resnet18, resnet18_data_gen) - >>> # Run the model - >>> data = resnresnet18_data_gen() # do not input any argument - >>> model = resnet18() # do not input any argument - >>> out = model(**data) + + ```python + # normal forward workflow + model = resnet18() + data = resnet18_data_gen() + output = model(**data) + transformed_output = output_transform_fn(output) + loss = loss_fn(transformed_output) + + # Register + model_zoo = ModelZooRegistry() + model_zoo.register('resnet18', resnet18, resnet18_data_gen, output_transform_fn, loss_fn) + ``` Args: name (str): Name of the model. - model_fn (callable): A function that returns a model. **It must not contain any arguments.** - output_transform_fn (callable): A function that transforms the output of the model into Dict. - data_gen_fn (callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + model_fn (Callable): A function that returns a model. **It must not contain any arguments.** + data_gen_fn (Callable): A function that returns a data sample in the form of Dict. **It must not contain any arguments.** + output_transform_fn (Callable): A function that transforms the output of the model into Dict. 
+            loss_fn (Callable): A function that computes the loss from the given output. Defaults to None.
             model_attribute (ModelAttribute): Attributes of the model. Defaults to None.
         """
-        self[name] = (model_fn, data_gen_fn, output_transform_fn, model_attribute)
+        self[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)
 
     def get_sub_registry(self, keyword: str):
         """
diff --git a/tests/kit/model_zoo/transformers/__init__.py b/tests/kit/model_zoo/transformers/__init__.py
index f56ff7ad84eb..ffaf4c566df9 100644
--- a/tests/kit/model_zoo/transformers/__init__.py
+++ b/tests/kit/model_zoo/transformers/__init__.py
@@ -1,5 +1,6 @@
 from .albert import *
 from .bert import *
 from .gpt import *
+from .llama import *
 from .opt import *
 from .t5 import *
diff --git a/tests/kit/model_zoo/transformers/llama.py b/tests/kit/model_zoo/transformers/llama.py
new file mode 100644
index 000000000000..705bbc7364ba
--- /dev/null
+++ b/tests/kit/model_zoo/transformers/llama.py
@@ -0,0 +1,76 @@
+import torch
+import transformers
+
+from ..registry import ModelAttribute, model_zoo
+
+try:
+    from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel
+    HAS_LLAMA = True
+except ImportError:
+    HAS_LLAMA = False
+
+if HAS_LLAMA:
+    # ===============================
+    # Register LLaMA
+    # ===============================
+
+    def data_gen():
+        # the input ids correspond to the sentence
+        # 'Hello, my dog is cute'
+        #
+        # the code is given below:
+        # -----------------------------------
+        # from transformers import LlamaTokenizerFast
+        # tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")
+        # input = 'Hello, my dog is cute'
+        # tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
+        # -----------------------------------
+
+        input_ids = torch.Tensor([[1, 15043, 29892, 590, 11203, 338, 274, 1082]]).long()
+        attention_mask = torch.Tensor([[1, 1, 1, 1, 1, 1, 1, 1]]).long()
+        return dict(input_ids=input_ids, attention_mask=attention_mask)
+
+    # labels are needed for causal lm
+    def data_gen_for_casual_lm():
+        data = data_gen()
+        labels = data['input_ids'].clone()
+        data['labels'] = labels
+        return data
+
+    # transform the output to a dict
+    output_transform_fn = lambda x: x
+
+    # function to get the loss
+    loss_fn = lambda output: output.last_hidden_state.mean()
+    loss_fn_for_casual_lm = lambda output: output.loss
+    loss_fn_for_seq_classification = lambda output: output.logits.mean()
+
+    config = LlamaConfig(num_hidden_layers=4,
+                         hidden_size=128,
+                         intermediate_size=256,
+                         num_attention_heads=4,
+                         max_position_embeddings=128,
+                         num_labels=16)
+
+    # register the following models
+    # transformers.LlamaModel,
+    # transformers.LlamaForCausalLM,
+    # transformers.LlamaForSequenceClassification,
+    model_zoo.register(name='transformers_llama',
+                       model_fn=lambda: transformers.LlamaModel(config),
+                       data_gen_fn=data_gen,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn,
+                       model_attribute=ModelAttribute(has_control_flow=True))
+    model_zoo.register(name='transformers_llama_for_casual_lm',
+                       model_fn=lambda: transformers.LlamaForCausalLM(config),
+                       data_gen_fn=data_gen_for_casual_lm,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn_for_casual_lm,
+                       model_attribute=ModelAttribute(has_control_flow=True))
+    model_zoo.register(name='transformers_llama_for_sequence_classification',
+                       model_fn=lambda: transformers.LlamaForSequenceClassification(config),
+                       data_gen_fn=data_gen,
+                       output_transform_fn=output_transform_fn,
+                       loss_fn=loss_fn_for_seq_classification,
+                       model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/kit/model_zoo/transformers/t5.py b/tests/kit/model_zoo/transformers/t5.py
index b81bcad90db8..689db2c40abb 100644
--- a/tests/kit/model_zoo/transformers/t5.py
+++ b/tests/kit/model_zoo/transformers/t5.py
@@ -6,24 +6,50 @@
 # ===============================
 # Register single-sentence T5
 # ===============================
-BATCH_SIZE = 2
-SEQ_LENGTH = 16
-
-
-def data_gen():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    decoder_input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
-    return dict(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
+# define data gen function
 
 
 def data_gen_for_encoder_only():
-    input_ids = torch.zeros((BATCH_SIZE, SEQ_LENGTH), dtype=torch.int64)
+    # Generated from the following code snippet
+    #
+    # from transformers import T5Config, T5Tokenizer
+    # config = T5Config(decoder_start_token_id=0)
+    # tokenizer = T5Tokenizer.from_pretrained("t5-small")
+    # input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids
+    input_ids = torch.Tensor([[13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5, 1]]).long()
     return dict(input_ids=input_ids)
 
 
+def data_gen_for_conditional_generation():
+    # labels are generated with the following code
+    #
+    # labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids
+    data = data_gen_for_encoder_only()
+    labels = torch.Tensor([[644, 4598, 229, 19250, 5, 1]]).long()
+    data['labels'] = labels
+    return data
+
+
+def data_gen_for_t5_model():
+    # decoder_input_ids is obtained with the following code
+    #
+    # decoder_input_ids = model._shift_right(input_ids)
+    data = data_gen_for_encoder_only()
+    decoder_input_ids = torch.Tensor([[0, 13959, 1566, 12, 2968, 10, 37, 629, 19, 1627, 5]]).long()
+    data['decoder_input_ids'] = decoder_input_ids
+    return data
+
+
+# output transform function
 output_transform_fn = lambda x: x
 
-config = transformers.T5Config(d_model=128, num_layers=2)
+# define loss function
+loss_fn_for_t5_model = lambda x: x.last_hidden_state.mean()
+loss_fn_for_encoder_only = lambda x: x.last_hidden_state.mean()
+loss_fn_for_conditional_generation = lambda x: x.loss
+
+# define model config
+config = transformers.T5Config(d_model=128, num_layers=2, dropout_rate=0, decoder_start_token_id=0)
 
 # register the following models
 # transformers.T5Model,
@@ -31,16 +57,19 @@ def data_gen_for_encoder_only():
 # transformers.T5EncoderModel,
 model_zoo.register(name='transformers_t5',
                    model_fn=lambda: transformers.T5Model(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_t5_model,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_t5_model,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_for_conditional_generation',
                    model_fn=lambda: transformers.T5ForConditionalGeneration(config),
-                   data_gen_fn=data_gen,
+                   data_gen_fn=data_gen_for_conditional_generation,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_conditional_generation,
                    model_attribute=ModelAttribute(has_control_flow=True))
 model_zoo.register(name='transformers_t5_encoder_model',
                    model_fn=lambda: transformers.T5EncoderModel(config),
                    data_gen_fn=data_gen_for_encoder_only,
                    output_transform_fn=output_transform_fn,
+                   loss_fn=loss_fn_for_encoder_only,
                    model_attribute=ModelAttribute(has_control_flow=True))
diff --git a/tests/test_booster/test_mixed_precision/test_fp16_torch.py
b/tests/test_booster/test_mixed_precision/test_fp16_torch.py index 963387da262b..26ce00e94869 100644 --- a/tests/test_booster/test_mixed_precision/test_fp16_torch.py +++ b/tests/test_booster/test_mixed_precision/test_fp16_torch.py @@ -11,7 +11,7 @@ def run_torch_amp(rank, world_size, port): # init dist env colossalai.launch(config=dict(), rank=rank, world_size=world_size, port=port, host='localhost') sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, _) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in sub_model_zoo.items(): # dlrm_interactionarch has not parameters, so skip if name == 'dlrm_interactionarch': continue diff --git a/tests/test_booster/test_plugin/test_gemini_plugin.py b/tests/test_booster/test_plugin/test_gemini_plugin.py index d606d6d89bd4..d29c92926066 100644 --- a/tests/test_booster/test_plugin/test_gemini_plugin.py +++ b/tests/test_booster/test_plugin/test_gemini_plugin.py @@ -71,7 +71,7 @@ def check_gemini_plugin(init_method: str = 'none', early_stop: bool = True): passed_models = [] failed_info = {} # (model_name, error) pair - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # These models lead to CUDA error if name in ('diffusers_auto_encoder_kl', 'diffusers_vq_model', 'diffusers_unet2d_model', 'timm_resmlp', 'timm_gmixer_12_224', 'timm_gmlp_b16_224', 'timm_mixer_b16_224', 'timm_convnext'): diff --git a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py index f70f27be2aa7..eedd8c59a3a8 100644 --- a/tests/test_booster/test_plugin/test_low_level_zero_plugin.py +++ b/tests/test_booster/test_plugin/test_low_level_zero_plugin.py @@ -61,7 +61,7 @@ def check_low_level_zero_plugin(stage: int, early_stop: bool = True): ignore_models = _AMP_ERR_MODELS + _LOW_LEVEL_ZERO_ERR_MODELS + _STUCK_MODELS skipped_models = [] - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): # FIXME(ver217): fix these models if name in ignore_models: skipped_models.append(name) diff --git a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py index fbe44e5ce6fb..1484273973ae 100644 --- a/tests/test_booster/test_plugin/test_torch_ddp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_ddp_plugin.py @@ -40,7 +40,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_ddp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if name == 'dlrm_interactionarch': continue run_fn(model_fn, data_gen_fn, output_transform_fn) diff --git a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py index 44767f051fdd..cbd5d57800db 100644 --- a/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py +++ b/tests/test_booster/test_plugin/test_torch_fsdp_plugin.py @@ -42,7 +42,7 @@ def run_fn(model_fn, data_gen_fn, output_transform_fn): def check_torch_fsdp_plugin(): - for name, (model_fn, data_gen_fn, output_transform_fn, _) in model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, _) in model_zoo.items(): if any(element in name for 
element in [ 'diffusers', 'deepfm_sparsearch', 'dlrm_interactionarch', 'torchvision_googlenet', 'torchvision_inception_v3' diff --git a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py index 0cbea82e083a..ccbe2da58bf2 100644 --- a/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py +++ b/tests/test_fx/test_tracer/test_hf_model/test_hf_diffuser.py @@ -47,7 +47,7 @@ def test_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() trace_and_compare(model_fn, data, output_transform_fn) torch.cuda.synchronize() @@ -60,7 +60,7 @@ def test_torch_diffusers(): sub_model_zoo = model_zoo.get_sub_registry('diffusers') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() model = model_fn() output = model(**data) diff --git a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py index 11302e8f36b0..117c70c84aa8 100644 --- a/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py +++ b/tests/test_fx/test_tracer/test_timm_model/test_timm_model.py @@ -56,7 +56,7 @@ def test_timm_models(): sub_model_zoo = model_zoo.get_sub_registry('timm') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): data = data_gen_fn() if attribute is not None and attribute.has_control_flow: meta_args = {k: v.to('meta') for k, v in data.items()} diff --git a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py index eafcaca10b1d..f73c5bb9a590 100644 --- a/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py +++ b/tests/test_fx/test_tracer/test_torchaudio_model/test_torchaudio_model.py @@ -16,7 +16,7 @@ def test_torchaudio_models(): sub_model_zoo = model_zoo.get_sub_registry('torchaudio') - for name, (model_fn, data_gen_fn, output_transform_fn, attribute) in sub_model_zoo.items(): + for name, (model_fn, data_gen_fn, output_transform_fn, _, attribute) in sub_model_zoo.items(): model = model_fn() trace_and_compare(model, data_gen_fn, diff --git a/tests/test_lazy/lazy_init_utils.py b/tests/test_lazy/lazy_init_utils.py index 2911012fafa8..9e0fd07828de 100644 --- a/tests/test_lazy/lazy_init_utils.py +++ b/tests/test_lazy/lazy_init_utils.py @@ -62,7 +62,7 @@ def assert_forward_equal(m1: torch.nn.Module, m2: torch.nn.Module, data_gen_fn: def check_lazy_init(entry: TestingEntry, seed: int = 42, verbose: bool = False, check_forward: bool = False) -> None: - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry _MyTensor._pre_op_fn = lambda *args: set_seed(seed) LazyTensor._pre_op_fn = lambda *args: set_seed(seed) ctx = LazyInitContext(tensor_cls=_MyTensor) diff --git a/tests/test_lazy/test_distribute.py b/tests/test_lazy/test_distribute.py index efa43eab5788..b3e0b22632a8 100644 --- a/tests/test_lazy/test_distribute.py +++ b/tests/test_lazy/test_distribute.py @@ -74,7 +74,7 @@ def 
run_dist_lazy_init(subset, seed: int = 42): if name in ('torchaudio_wav2vec2_base', 'torchaudio_hubert_base'): continue print_rank_0(name) - model_fn, data_gen_fn, output_transform_fn, model_attr = entry + model_fn, data_gen_fn, output_transform_fn, _, model_attr = entry ctx = LazyInitContext(tensor_cls=_MyTensor) with ctx: model = model_fn() diff --git a/tests/test_shardformer/__init__.py b/tests/test_shardformer/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_model/__init__.py b/tests/test_shardformer/test_model/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1 diff --git a/tests/test_shardformer/test_model/_utils.py b/tests/test_shardformer/test_model/_utils.py new file mode 100644 index 000000000000..52ca7fce895b --- /dev/null +++ b/tests/test_shardformer/test_model/_utils.py @@ -0,0 +1,38 @@ +import copy + +from colossalai.shardformer import ShardConfig, ShardFormer + + +def build_model(world_size, model_fn): + # create new model + org_model = model_fn().cuda() + + # shard model + shard_config = ShardConfig(tensor_parallel_size=world_size) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(model_copy) + + return org_model, sharded_model + + +def run_forward(original_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # prepare input + data = data_gen_fn() + data = {k: v.cuda() for k, v in data.items()} + + # switch to train mode + original_model.train() + sharded_model.train() + + # run forward + org_output = original_model(**data) + org_output = output_transform_fn(org_output) + org_loss = loss_fn(org_output) + + shard_output = sharded_model(**data) + shard_output = output_transform_fn(shard_output) + shard_loss = loss_fn(shard_output) + + return org_output, org_loss, shard_output, shard_loss diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index b15f81aba52e..8b672af500bd 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -1,64 +1,22 @@ -import copy import os -import random import pytest import torch -from transformers import LlamaConfig, LlamaForCausalLM, LlamaForSequenceClassification, LlamaModel, LlamaTokenizerFast import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer import ShardConfig, ShardFormer from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") - - -def build_model(world_size, model_fn): - # create new model - config = LlamaConfig(num_hidden_layers=4, - hidden_size=128, - intermediate_size=256, - num_attention_heads=4, - max_position_embeddings=128) - org_model = model_fn(config).cuda() - - # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(model_copy) - - return org_model, sharded_model - - -def check_forward_backward(org_model, sharded_model): - # 
prepare input - input = 'Hello, my dog is cute' - tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') - del tokenized_input["token_type_ids"] - del tokenized_input["attention_mask"] - - # switch to train mode - org_model.train() - sharded_model.train() - - if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)): - org_output = org_model(**tokenized_input) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(**tokenized_input) - shard_loss = shard_output.last_hidden_state.mean() - elif isinstance(org_model, LlamaForCausalLM): - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - org_output = org_model(**tokenized_input) - org_loss = org_output.loss - shard_output = sharded_model(**tokenized_input) - shard_loss = shard_output.loss + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) + + # forward check assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) # run backward @@ -66,12 +24,12 @@ def check_forward_backward(org_model, sharded_model): shard_loss.backward() # check grad - if isinstance(org_model, LlamaModel): - llama_model = org_model - shard_llama_model = sharded_model - else: + if hasattr(org_model, 'model'): llama_model = org_model.model shard_llama_model = sharded_model.model + else: + llama_model = org_model + shard_llama_model = sharded_model org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad @@ -89,17 +47,11 @@ def check_llama(rank, world_size, port): disable_existing_loggers() colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model_list = [ - LlamaModel, - # LlamaForCausalLM, - - # TODO: do not work yet - # LlamaForSequenceClassification - ] + sub_model_zoo = model_zoo.get_sub_registry('transformers_llama') - for model_fn in model_list: + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(world_size, model_fn) - check_forward_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) torch.cuda.empty_cache() diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 254649409c59..2698d7675c8e 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,64 +1,20 @@ -import copy import os import pytest import torch -from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast import colossalai from colossalai.logging import disable_existing_loggers -from colossalai.shardformer.shard import ShardConfig, ShardFormer from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn +from tests.kit.model_zoo import model_zoo +from tests.test_shardformer.test_model._utils import build_model, run_forward -os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' -CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) -tokenizer = T5Tokenizer.from_pretrained("t5-small") - - -def 
build_model(world_size, model_fn): - config = T5Config(decoder_start_token_id=0) - config.dropout_rate = 0 - org_model = model_fn(config=config).to('cuda') - shard_config = ShardConfig(tensor_parallel_size=world_size) - - # shard model - shard_config = ShardConfig(tensor_parallel_size=world_size) - model_copy = copy.deepcopy(org_model) - shard_former = ShardFormer(shard_config=shard_config) - shard_former.init_distributed() - sharded_model = shard_former.shard_model(model_copy) - - return org_model, sharded_model - - -def check_forward_backward(org_model, sharded_model): - # prepare input - input_ids = tokenizer("translate English to German: The house is wonderful.", - return_tensors="pt").input_ids.to('cuda') - labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') - - # switch to train mode - org_model.train() - sharded_model.train() - - if isinstance(org_model, T5ForConditionalGeneration): - org_output = org_model(input_ids=input_ids, labels=labels) - org_loss = org_output.loss - shard_output = sharded_model(input_ids=input_ids, labels=labels) - shard_loss = shard_output.loss - elif isinstance(org_model, T5Model): - decoder_input_ids = org_model._shift_right(input_ids) - org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) - shard_loss = shard_output.last_hidden_state.mean() - elif isinstance(org_model, T5EncoderModel): - org_output = org_model(input_ids=input_ids) - org_loss = org_output.last_hidden_state.mean() - shard_output = sharded_model(input_ids=input_ids) - shard_loss = shard_output.last_hidden_state.mean() - - # key is sharded, so we ignore + +def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn): + # check forward + # the value "past_key_values" is sharded, so we ignore + org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn, + output_transform_fn, loss_fn) assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) # do backward @@ -81,18 +37,15 @@ def check_forward_backward(org_model, sharded_model): def check_t5(rank, world_size, port): disable_existing_loggers() - colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - model_fn_list = [ - T5Model, - T5ForConditionalGeneration, - T5EncoderModel, - ] + sub_model_zoo = model_zoo.get_sub_registry('transformers_t5') - for model_fn in model_fn_list: + for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items(): org_model, sharded_model = build_model(world_size, model_fn) - check_forward_backward(org_model, sharded_model) - torch.cuda.empty_cache() + check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn) + + torch.cuda.empty_cache() @pytest.mark.dist
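Note on the new error handling in `sharder.py` above: the adjacent f-strings inside the `RuntimeError` are concatenated without separators, so the `========== Error ==========` banner runs straight into the surrounding text when the exception is printed. A minimal sketch of an equivalent message with explicit newlines is given below; the helper name `_build_replace_error_message` is hypothetical and only illustrates the formatting, it is not part of the patch.

```python
def _build_replace_error_message(suffix, native_sub_module, target_module, exception) -> str:
    # Hypothetical helper: carries the same information as the RuntimeError raised in
    # _replace_sub_module, but with explicit newlines so each banner renders on its own line.
    return (f"Failed to replace {suffix} of type {native_sub_module.__class__.__qualname__}"
            f" with {target_module.__qualname__} with the following exception:\n"
            "========== Error ==========\n"
            f"{exception}\n"
            "===========================\n"
            "Please check your model configuration or sharding policy; you can also open an issue so we can help.")
```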
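The registry change in `tests/kit/model_zoo/registry.py` is why every test file in this diff now unpacks a five-element tuple. Below is a minimal sketch of the new layout, assuming the registry behaves like a plain dict (the real `ModelZooRegistry` may differ); tests that do not need a loss simply discard the `loss_fn` slot with `_`.

```python
from typing import Callable, Dict, Optional, Tuple

# Illustrative stand-in for the model zoo registry after this change.
registry: Dict[str, Tuple] = {}


def register(name: str,
             model_fn: Callable,
             data_gen_fn: Callable,
             output_transform_fn: Callable,
             loss_fn: Optional[Callable] = None,
             model_attribute=None) -> None:
    # loss_fn now sits between output_transform_fn and model_attribute,
    # so every consumer unpacks five elements instead of four.
    registry[name] = (model_fn, data_gen_fn, output_transform_fn, loss_fn, model_attribute)


def run_all() -> None:
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in registry.items():
        model = model_fn()    # model_fn takes no arguments
        data = data_gen_fn()    # data_gen_fn takes no arguments
        output = output_transform_fn(model(**data))
        if loss_fn is not None:
            loss = loss_fn(output)
```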
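For reference, the new helpers in `tests/test_shardformer/test_model/_utils.py` are meant to be driven exactly the way `check_llama` and `check_t5` above drive them. The sketch below restates that flow for the LLaMA entries; the `allclose` check on the two losses is an illustrative assertion, while the actual tests compare full outputs with `assert_hf_output_close`.

```python
import torch

from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def run_llama_entries(world_size: int) -> None:
    # Assumes colossalai.launch(...) has already initialised the distributed
    # environment, as done in check_llama above.
    sub_model_zoo = model_zoo.get_sub_registry('transformers_llama')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(world_size, model_fn)
        org_output, org_loss, shard_output, shard_loss = run_forward(
            org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)
        # illustrative check: the sharded forward should reproduce the original loss
        assert torch.allclose(org_loss, shard_loss, atol=1e-5, rtol=1e-4)
        org_loss.backward()
        shard_loss.backward()
```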
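Both test files read `layers[0].self_attn.q_proj.weight.grad` (or the T5 equivalent) from the original and the sharded model after backward. The comparison itself is not shown in this hunk; one common way to do it, assuming `q_proj` is column-parallel so each rank holds a slice of the gradient along dim 0, is sketched below. The function name and tolerances are illustrative only.

```python
import torch
import torch.distributed as dist


def assert_sharded_grad_matches(org_grad: torch.Tensor, shard_grad: torch.Tensor) -> None:
    # Gather the per-rank gradient slices and stitch them back together along
    # the sharded dimension before comparing with the unsharded gradient.
    world_size = dist.get_world_size()
    shard_grad_list = [torch.zeros_like(shard_grad) for _ in range(world_size)]
    dist.all_gather(shard_grad_list, shard_grad)
    gathered_grad = torch.cat(shard_grad_list, dim=0)
    assert torch.allclose(org_grad, gathered_grad, atol=1e-5, rtol=1e-3), \
        f"sharded grad does not match the original grad\n{org_grad}\nvs\n{gathered_grad}"
```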