From 74aaeadc3182f89666e20cad75e29f78dd970bb7 Mon Sep 17 00:00:00 2001 From: Frank Lee Date: Mon, 19 Jun 2023 17:57:37 +0800 Subject: [PATCH] [shardformer] supported T5 and its variants (#4045) --- colossalai/shardformer/README.md | 5 +- colossalai/shardformer/layer/layers.py | 26 +- colossalai/shardformer/policies/autopolicy.py | 6 + colossalai/shardformer/policies/basepolicy.py | 1 + colossalai/shardformer/policies/t5.py | 258 +++++++++--------- colossalai/shardformer/shard/sharder.py | 11 +- colossalai/testing/__init__.py | 3 +- colossalai/testing/comparison.py | 51 +++- .../test_model/test_shard_llama.py | 82 +++--- .../test_model/test_shard_t5.py | 94 ++++--- 10 files changed, 316 insertions(+), 221 deletions(-) diff --git a/colossalai/shardformer/README.md b/colossalai/shardformer/README.md index dc2946ec937f..fee4cce7a28a 100644 --- a/colossalai/shardformer/README.md +++ b/colossalai/shardformer/README.md @@ -81,8 +81,8 @@ We will follow this roadmap to develop Shardformer: - [ ] Hugging Face - [ ] NLP - [x] BERT - - [ ] T5 - - [ ] LlaMa + - [x] T5 + - [x] LlaMa - [ ] GPT2 - [ ] BLOOM - [ ] RoBERTa @@ -90,7 +90,6 @@ We will follow this roadmap to develop Shardformer: - [ ] ERNIE - [ ] GPT Neo - [ ] GPT-J - - [ ] CV - [ ] CV - [ ] ViT - [ ] BEiT diff --git a/colossalai/shardformer/layer/layers.py b/colossalai/shardformer/layer/layers.py index ad6e1896aa5e..5dbe28956d27 100644 --- a/colossalai/shardformer/layer/layers.py +++ b/colossalai/shardformer/layer/layers.py @@ -469,13 +469,14 @@ def __init__(self, dtype: torch.dtype = None, device: torch.device = None, process_group: ProcessGroup = None, + gather_output: bool = True, weight_initializer: Callable = init.normal_(), *args, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.process_group = process_group self.num_partitions = dist.get_world_size(process_group) self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions) @@ -499,7 +500,9 @@ def __init__(self, @staticmethod def from_native_module(module: nn.Embedding, - process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D": + process_group: Union[ProcessGroup, List[ProcessGroup]] = None, + *args, + **kwargs) -> "Embedding1D": r""" Build a 1D parallelized Embedding from a native nn.Embedding module. 
""" @@ -527,7 +530,9 @@ def from_native_module(module: nn.Embedding, max_norm=max_norm, norm_type=norm_type, scale_grad_by_freq=scale_grad_by_freq, - sparse=sparse) + sparse=sparse, + *args, + **kwargs) # copy the weight with torch.no_grad(): @@ -537,7 +542,7 @@ def from_native_module(module: nn.Embedding, return embedding def reset_parameters(self, weight_initializer) -> None: - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() @@ -548,9 +553,12 @@ def _fill_padding_idx_with_zero(self) -> None: def forward(self, input_: Tensor) -> Tensor: output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs) - output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) - return output + if self.gather_output: + output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group) + return output + else: + return output_parallel class VocabParallelEmbedding1D(ParallelLayer): @@ -595,7 +603,7 @@ def __init__(self, **kwargs): super().__init__() self.num_embeddings = num_embeddings - self.embed_dim = embedding_dim + self.embedding_dim = embedding_dim self.padding_idx = padding_idx self.embed_args = args self.embed_kwargs = kwargs @@ -610,7 +618,7 @@ def __init__(self, self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition self.weight = Parameter( - torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype)) + torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype)) # offset the seed with randomizer index and rank seed = torch.random.initial_seed() @@ -662,7 +670,7 @@ def _set_tensor_parallel_attributes(self): def reset_parameters(self, weight_initializer) -> None: with seed(ParallelMode.TENSOR): - fan_in, fan_out = self.num_embeddings, self.embed_dim + fan_in, fan_out = self.num_embeddings, self.embedding_dim weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out) self._fill_padding_idx_with_zero() diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index e1b3a6a815a2..6ce0b8fb3a3d 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -48,6 +48,12 @@ class PolicyLocation: PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"), # T5 + "transformers.models.t5.modeling_t5.T5Model": + PolicyLocation(file_name="t5", class_name="T5ModelPolicy"), + "transformers.models.t5.modeling_t5.T5ForConditionalGeneration": + PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"), + "transformers.models.t5.modeling_t5.T5EncoderModel": + PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"), # GPT2 } diff --git a/colossalai/shardformer/policies/basepolicy.py b/colossalai/shardformer/policies/basepolicy.py index e4f2e9432e10..175a914a84f9 100644 --- a/colossalai/shardformer/policies/basepolicy.py +++ b/colossalai/shardformer/policies/basepolicy.py @@ -27,6 +27,7 @@ class SubModuleReplacementDescription: suffix: str target_module: ParallelModule kwargs: Dict[str, Any] = None + ignore_if_not_exist: bool = False @dataclass diff --git a/colossalai/shardformer/policies/t5.py b/colossalai/shardformer/policies/t5.py index 7b013a37845a..9c8ee59b4178 100644 --- 
a/colossalai/shardformer/policies/t5.py +++ b/colossalai/shardformer/policies/t5.py @@ -1,159 +1,173 @@ -from typing import Dict - +import torch import torch.nn as nn -from torch.nn import Embedding +from transformers import T5ForConditionalGeneration from transformers.models.t5.modeling_t5 import ( T5Attention, - T5Block, T5DenseActDense, T5DenseGatedActDense, T5LayerCrossAttention, T5LayerFF, T5LayerSelfAttention, - T5Model, T5Stack, ) -import colossalai.shardformer.layer.layers as col_nn +from colossalai.shardformer.layer.dropout import Dropout1D +from colossalai.shardformer.layer.layers import Embedding1D, Linear1D_Col, Linear1D_Row -from .basepolicy import Argument, Col_Layer, Dropout_Layer, Embedding_Layer, Policy, Row_Layer +from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription + +__all__ = ["T5ModelPolicy", "T5ForConditionalGenerationPolicy", "T5EncoderPolicy"] class T5ModelPolicy(Policy): - @staticmethod - def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: - print('config heads', config.num_heads) + def preprocess(self): + # reshape the embedding layer + r""" + Reshape the Embedding layer to make the embedding dimension divisible by world_size + """ + vocab_size = self.model.config.vocab_size + world_size = self.shard_config.tensor_parallel_size + if vocab_size % world_size != 0: + new_vocab_size = vocab_size + world_size - vocab_size % world_size + self.model.resize_token_embeddings(new_vocab_size) + return self.model + + def module_policy(self): return { T5Stack: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.embedding]), - T5Block: - Argument(attr_dict={}, param_funcs=[]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5LayerSelfAttention: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ), + ]), T5LayerCrossAttention: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5Attention: - Argument(attr_dict={ - "d_model": config.d_model // world_size, - "n_heads": config.num_heads // world_size, - "inner_dim": config.num_heads * config.d_kv // world_size, + ModulePolicyDescription(attribute_replacement={ + "d_model": + self.model.config.d_model // self.shard_config.tensor_parallel_size, + "n_heads": + self.model.config.num_heads // self.shard_config.tensor_parallel_size, + "inner_dim": + self.model.config.num_heads * self.model.config.d_kv // self.shard_config.tensor_parallel_size }, - param_funcs=[T5ModelPolicy.attn_layer]), + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="q", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="k", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="v", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="o", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription(suffix="relative_attention_bias", + target_module=Embedding1D, + kwargs=dict(gather_output=False), + 
ignore_if_not_exist=True) + ]), T5LayerFF: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ), + ]), T5DenseGatedActDense: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_gated_layer]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi_0", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wi_1", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription(suffix="wo", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)), + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]), T5DenseActDense: - Argument(attr_dict={}, param_funcs=[T5ModelPolicy.dropout, T5ModelPolicy.dense_act_layer]), + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription( + suffix="wi", + target_module=Linear1D_Col, + ), + SubModuleReplacementDescription( + suffix="wo", + target_module=Linear1D_Row, + ), + SubModuleReplacementDescription( + suffix="dropout", + target_module=Dropout1D, + ) + ]) } - @staticmethod - def dense_gated_layer(): - return [ - Col_Layer( - suffix="wi_0", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="wi_1", - weight="weight", - replace_layer=col_nn.Linear1D_Row, - ), - Col_Layer(suffix="wo", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True) - ] - - @staticmethod - def dense_act_layer(): - return [ - Col_Layer( - suffix="wi", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="wo", - weight="weight", - replace_layer=col_nn.Linear1D_Row, - ) - ] - - @staticmethod - def attn_layer(): - return [ - Col_Layer( - suffix="q", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="k", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Col_Layer( - suffix="v", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Col, - ), - Row_Layer( - suffix="o", - weight="weight", - bias="bias", - replace_layer=col_nn.Linear1D_Row, - ), - ] - - @staticmethod - def dropout(): - return [Dropout_Layer( - suffix="dropout", - p="p", - replace_layer=col_nn.Dropout1D, - )] - - @staticmethod - def embedding(): - return [ - Embedding_Layer( - suffix="block[0].layer[0].SelfAttention.relative_attention_bias", - weight="weight", - replace_layer=col_nn.Embedding1D, - gather_output=False, - ) - ] + def new_model_class(self): + return None - -from transformers import T5ForConditionalGeneration + def postprocess(self): + return self.model class T5ForConditionalGenerationPolicy(T5ModelPolicy): - @staticmethod - def argument_policy(config, world_size): - base_argument = T5ModelPolicy.argument_policy(config, world_size) - argument = { - T5ForConditionalGeneration: Argument(attr_dict={}, param_funcs=[T5ForConditionalGenerationPolicy.lm_head]) + def module_policy(self): + policy = super().module_policy() + + new_item = { + T5ForConditionalGeneration: + ModulePolicyDescription(attribute_replacement={}, + param_replacement=[], + sub_module_replacement=[ + SubModuleReplacementDescription(suffix="lm_head", + target_module=Linear1D_Col, + kwargs=dict(gather_output=True)) + ]) } - 
argument.update(base_argument) - return argument - - @staticmethod - def lm_head(): - return [Col_Layer( - suffix="lm_head", - weight="weight", - replace_layer=col_nn.Linear1D_Col, - gather_output=True, - )] - -from transformers import T5EncoderModel + policy.update(new_item) + return policy -class T5EncoderModelPolicy(T5ModelPolicy): +class T5EncoderPolicy(T5ModelPolicy): pass diff --git a/colossalai/shardformer/shard/sharder.py b/colossalai/shardformer/shard/sharder.py index c948a7939d15..f6ade26b758a 100644 --- a/colossalai/shardformer/shard/sharder.py +++ b/colossalai/shardformer/shard/sharder.py @@ -175,7 +175,16 @@ def _replace_sub_module( assert target_module is not None, 'target_module should not be None' # TODO: support different parallel mode - native_sub_module = getattr_(org_layer, suffix) + native_sub_module = getattr_(org_layer, suffix, ignore=True) + + assert not isinstance(native_sub_module, target_module), \ + f"The module with suffix {suffix} has been replaced, please check the policy" + + # if it is None and we are allowed to ignore this module + # just skip + if description.ignore_if_not_exist and native_sub_module is None: + continue + replace_layer = target_module.from_native_module(native_sub_module, self.pg_manager.pg_store['tp1d'], **kwargs) diff --git a/colossalai/testing/__init__.py b/colossalai/testing/__init__.py index 9d0475ed064c..0db33361c6a0 100644 --- a/colossalai/testing/__init__.py +++ b/colossalai/testing/__init__.py @@ -3,6 +3,7 @@ assert_close_loose, assert_equal, assert_equal_in_group, + assert_hf_output_close, assert_not_equal, check_state_dict_equal, ) @@ -20,5 +21,5 @@ __all__ = [ 'assert_equal', 'assert_not_equal', 'assert_close', 'assert_close_loose', 'assert_equal_in_group', 'parameterize', 'rerun_on_exception', 'rerun_if_address_is_in_use', 'skip_if_not_enough_gpus', 'free_port', 'spawn', - 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal' + 'clear_cache_before_run', 'run_on_environment_flag', 'check_state_dict_equal', 'assert_hf_output_close' ] diff --git a/colossalai/testing/comparison.py b/colossalai/testing/comparison.py index faf61638d8bb..aeecee7f11f5 100644 --- a/colossalai/testing/comparison.py +++ b/colossalai/testing/comparison.py @@ -1,4 +1,4 @@ -from typing import OrderedDict +from typing import Any, List, OrderedDict import torch import torch.distributed as dist @@ -52,3 +52,52 @@ def check_state_dict_equal(d1: OrderedDict, d2: OrderedDict, ignore_device: bool assert torch.equal(v, d2[k]) else: assert v == d2[k] + + +def assert_hf_output_close(out1: Any, + out2: Any, + ignore_keys: List[str] = None, + track_name: str = "", + atol=1e-5, + rtol=1e-5): + """ + Check if two outputs from huggingface are equal. 
+ + Args: + out1 (Any): the first output + out2 (Any): the second output + ignore_keys (List[str]): the keys to ignore when comparing two dicts + track_name (str): the name of the value compared, used to track the path + """ + if isinstance(out1, dict) and isinstance(out2, dict): + # if two values are dict + # we recursively check the keys + assert set(out1.keys()) == set(out2.keys()) + for k in out1.keys(): + if ignore_keys is not None and k in ignore_keys: + continue + assert_hf_output_close(out1[k], + out2[k], + track_name=f"{track_name}.{k}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, (list, tuple)) and isinstance(out2, (list, tuple)): + # if two values are list + # we recursively check the elements + assert len(out1) == len(out2) + for i in range(len(out1)): + assert_hf_output_close(out1[i], + out2[i], + track_name=f"{track_name}.{i}", + ignore_keys=ignore_keys, + atol=atol, + rtol=rtol) + elif isinstance(out1, Tensor) and isinstance(out2, Tensor): + if out1.shape != out2.shape: + raise AssertionError(f"{track_name}: shape mismatch: {out1.shape} vs {out2.shape}") + assert torch.allclose( + out1, out2, atol=atol, rtol=rtol + ), f"{track_name}: tensor value mismatch\nvalue 1: {out1}\nvalue 2: {out2}, mean error: {torch.abs(out1 - out2).mean()}" + else: + assert out1 == out2, f"{track_name}: value mismatch.\nout1: {out1}\nout2: {out2}" diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py index a3c7647fafc6..b15f81aba52e 100644 --- a/tests/test_shardformer/test_model/test_shard_llama.py +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -9,7 +9,7 @@ import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") @@ -17,7 +17,11 @@ def build_model(world_size, model_fn): # create new model - config = LlamaConfig(num_hidden_layers=8) + config = LlamaConfig(num_hidden_layers=4, + hidden_size=128, + intermediate_size=256, + num_attention_heads=4, + max_position_embeddings=128) org_model = model_fn(config).cuda() # shard model @@ -30,49 +34,47 @@ def build_model(world_size, model_fn): return org_model, sharded_model -def check_forward(org_model, sharded_model): - input = 'Hello, my dog is cute' - inputs = tokenizer(input, return_tensors='pt').to('cuda') - del inputs["token_type_ids"] - del inputs["attention_mask"] - - #orgin model - org_model.eval() - org_out = org_model(**inputs) - - #shard model - sharded_model.eval() - shard_out = sharded_model(**inputs) - - assert torch.allclose( - org_out[0], shard_out[0], - atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): +def check_forward_backward(org_model, sharded_model): # prepare input input = 'Hello, my dog is cute' tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') del tokenized_input["token_type_ids"] del tokenized_input["attention_mask"] - labels = tokenized_input['input_ids'].clone() - labels[labels == tokenizer.pad_token_id] = -100 - tokenized_input['labels'] = labels - #orgin model + # switch to train mode org_model.train() - 
org_out = org_model(**tokenized_input) - org_loss = org_out.loss - org_loss.backward() - org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad - - torch.cuda.empty_cache() - #shard model sharded_model.train() - shard_out = sharded_model(**tokenized_input) - shard_loss = shard_out.loss + + if isinstance(org_model, (LlamaModel, LlamaForSequenceClassification)): + org_output = org_model(**tokenized_input) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(**tokenized_input) + shard_loss = shard_output.last_hidden_state.mean() + elif isinstance(org_model, LlamaForCausalLM): + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + org_output = org_model(**tokenized_input) + org_loss = org_output.loss + shard_output = sharded_model(**tokenized_input) + shard_loss = shard_output.loss + + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values'], rtol=1e-4) + + # run backward + org_loss.backward() shard_loss.backward() - shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad + + # check grad + if isinstance(org_model, LlamaModel): + llama_model = org_model + shard_llama_model = sharded_model + else: + llama_model = org_model.model + shard_llama_model = sharded_model.model + + org_grad = llama_model.layers[0].self_attn.q_proj.weight.grad + shard_grad = shard_llama_model.layers[0].self_attn.q_proj.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) all_shard_grad = torch.cat(shard_grad_list, dim=0) @@ -88,23 +90,23 @@ def check_llama(rank, world_size, port): colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') model_list = [ - LlamaForCausalLM, + LlamaModel, + # LlamaForCausalLM, # TODO: do not work yet - # LlamaModel, # LlamaForSequenceClassification ] for model_fn in model_list: org_model, sharded_model = build_model(world_size, model_fn) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) + check_forward_backward(org_model, sharded_model) torch.cuda.empty_cache() @pytest.mark.dist @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_llama(): spawn(check_llama, 4) diff --git a/tests/test_shardformer/test_model/test_shard_t5.py b/tests/test_shardformer/test_model/test_shard_t5.py index 9b1c2678f39b..254649409c59 100644 --- a/tests/test_shardformer/test_model/test_shard_t5.py +++ b/tests/test_shardformer/test_model/test_shard_t5.py @@ -1,71 +1,72 @@ import copy import os -import random import pytest import torch -from transformers import AutoTokenizer, BertConfig, BertForMaskedLM, T5Config, T5ForConditionalGeneration, T5Tokenizer +from transformers import T5Config, T5EncoderModel, T5ForConditionalGeneration, T5Model, T5Tokenizer, T5TokenizerFast import colossalai from colossalai.logging import disable_existing_loggers from colossalai.shardformer.shard import ShardConfig, ShardFormer -from colossalai.testing import rerun_if_address_is_in_use, spawn +from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')),) tokenizer = T5Tokenizer.from_pretrained("t5-small") -def build_model(rank, world_size): - config = T5Config.from_pretrained("t5-small") 
+def build_model(world_size, model_fn): + config = T5Config(decoder_start_token_id=0) config.dropout_rate = 0 - org_model = T5ForConditionalGeneration.from_pretrained("t5-small", config=config).to('cuda') + org_model = model_fn(config=config).to('cuda') + shard_config = ShardConfig(tensor_parallel_size=world_size) - shardconfig = ShardConfig( - rank=rank, - world_size=world_size, - gather_output=True, - ) - - org_model_for_shard = copy.deepcopy(org_model) - - sharded_model = shard_model(org_model_for_shard, shardconfig).to('cuda') + # shard model + shard_config = ShardConfig(tensor_parallel_size=world_size) + model_copy = copy.deepcopy(org_model) + shard_former = ShardFormer(shard_config=shard_config) + shard_former.init_distributed() + sharded_model = shard_former.shard_model(model_copy) return org_model, sharded_model -def check_forward(org_model, sharded_model): - - input_ids = tokenizer("translate English to German: The house is wonderful.", - return_tensors="pt").input_ids.to('cuda') - #orgin model - org_model.eval() - org_output = org_model.generate(input_ids) - - #shard model - sharded_model.eval() - shard_output = sharded_model.generate(input_ids) - assert torch.allclose( - org_output[0], shard_output[0], - atol=1e-5), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" - - -def check_backward(org_model, sharded_model): +def check_forward_backward(org_model, sharded_model): # prepare input input_ids = tokenizer("translate English to German: The house is wonderful.", return_tensors="pt").input_ids.to('cuda') labels = tokenizer("Das Haus ist wunderbar.", return_tensors="pt").input_ids.to('cuda') - #orgin model + # switch to train mode org_model.train() - org_loss = org_model(input_ids=input_ids, labels=labels).loss - org_loss.backward() - org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad - - #shard model sharded_model.train() - shard_loss = sharded_model(input_ids=input_ids, labels=labels).loss + + if isinstance(org_model, T5ForConditionalGeneration): + org_output = org_model(input_ids=input_ids, labels=labels) + org_loss = org_output.loss + shard_output = sharded_model(input_ids=input_ids, labels=labels) + shard_loss = shard_output.loss + elif isinstance(org_model, T5Model): + decoder_input_ids = org_model._shift_right(input_ids) + org_output = org_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids) + shard_loss = shard_output.last_hidden_state.mean() + elif isinstance(org_model, T5EncoderModel): + org_output = org_model(input_ids=input_ids) + org_loss = org_output.last_hidden_state.mean() + shard_output = sharded_model(input_ids=input_ids) + shard_loss = shard_output.last_hidden_state.mean() + + # key is sharded, so we ignore + assert_hf_output_close(org_output, shard_output, ignore_keys=['past_key_values']) + + # do backward + org_loss.backward() shard_loss.backward() + + # check grad equality + org_grad = org_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad = sharded_model.encoder.block[0].layer[0].SelfAttention.q.weight.grad shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(2)] @@ -82,16 +83,21 @@ def check_t5(rank, world_size, port): disable_existing_loggers() colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') - org_model, sharded_model = build_model(rank, 
world_size) - check_forward(org_model, sharded_model) - check_backward(org_model, sharded_model) + model_fn_list = [ + T5Model, + T5ForConditionalGeneration, + T5EncoderModel, + ] - torch.cuda.empty_cache() + for model_fn in model_fn_list: + org_model, sharded_model = build_model(world_size, model_fn) + check_forward_backward(org_model, sharded_model) + torch.cuda.empty_cache() @pytest.mark.dist -@pytest.mark.skip @rerun_if_address_is_in_use() +@clear_cache_before_run() def test_t5(): spawn(check_t5, 2)
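
The preprocess hook added to the T5 policies above pads the vocabulary so that the token embedding can be split evenly across tensor-parallel ranks. Below is a minimal sketch of the rounding rule it applies; the helper name pad_vocab_size and the example sizes are illustrative and not part of this patch:

    def pad_vocab_size(vocab_size: int, world_size: int) -> int:
        # same arithmetic as T5ModelPolicy.preprocess: round the vocabulary size
        # up to the next multiple of the tensor parallel size
        if vocab_size % world_size != 0:
            vocab_size = vocab_size + world_size - vocab_size % world_size
        return vocab_size

    # a 32100-token vocabulary split across 8 ranks is padded to 32104,
    # while a size that already divides evenly (32128 on 8 ranks) is unchanged
    assert pad_vocab_size(32100, 8) == 32104
    assert pad_vocab_size(32128, 8) == 32128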
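
The new assert_hf_output_close helper recursively walks Hugging Face style outputs (dicts, lists/tuples and tensors) and compares tensors with torch.allclose, using track_name to report the path of the first mismatch. The tests above pass ignore_keys=['past_key_values'] because the cached key/value tensors are sharded and are not expected to match. A small self-contained sketch of its behaviour; the plain dicts below stand in for real model outputs:

    import torch

    from colossalai.testing import assert_hf_output_close

    # two fake HF-style outputs: the hidden states agree within tolerance, while
    # the (sharded) past_key_values entry deliberately differs but is ignored
    out_ref = {'last_hidden_state': torch.ones(2, 4), 'past_key_values': torch.zeros(2)}
    out_shard = {'last_hidden_state': torch.ones(2, 4) + 1e-7, 'past_key_values': torch.ones(2)}

    assert_hf_output_close(out_ref, out_shard, ignore_keys=['past_key_values'], atol=1e-5, rtol=1e-5)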
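
Putting the pieces together: the T5 policies registered in autopolicy.py are resolved automatically from the model class when ShardFormer.shard_model is called. The sketch below loosely mirrors the flow of the new tests/test_shardformer/test_model/test_shard_t5.py as a rough usage example, assuming a 2-GPU setup; it feeds random token ids instead of tokenizer output, and the function name run_t5 is illustrative only:

    import copy

    import torch
    from transformers import T5Config, T5Model

    import colossalai
    from colossalai.shardformer.shard import ShardConfig, ShardFormer
    from colossalai.testing import spawn

    CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode='1d')))


    def run_t5(rank, world_size, port):
        # set up the distributed environment, as the tests in this patch do
        colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

        # build a small T5 model; decoder_start_token_id is needed for _shift_right
        config = T5Config(decoder_start_token_id=0)
        config.dropout_rate = 0
        org_model = T5Model(config=config).to('cuda')

        # shard a copy of the model; T5ModelPolicy is looked up from the model class
        shard_config = ShardConfig(tensor_parallel_size=world_size)
        shard_former = ShardFormer(shard_config=shard_config)
        shard_former.init_distributed()
        sharded_model = shard_former.shard_model(copy.deepcopy(org_model))

        # run a forward pass with random token ids
        input_ids = torch.randint(0, config.vocab_size, (1, 16), device='cuda')
        decoder_input_ids = org_model._shift_right(input_ids)
        output = sharded_model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
        print(output.last_hidden_state.shape)


    if __name__ == '__main__':
        spawn(run_t5, 2)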