From 3e7148296fd7048aa391488cc795e66a5ce7b352 Mon Sep 17 00:00:00 2001 From: wukong1992 Date: Fri, 9 Jun 2023 10:21:55 +0800 Subject: [PATCH] [shardformer] support llama model using shardformer adjust layer attr --- .../shardformer/layer/dist_crossentropy.py | 2 +- colossalai/shardformer/policies/autopolicy.py | 14 ++ colossalai/shardformer/policies/llama.py | 122 ++++++++++++++++++ .../test_model/test_shard_llama.py | 106 +++++++++++++++ 4 files changed, 243 insertions(+), 1 deletion(-) create mode 100644 colossalai/shardformer/policies/llama.py create mode 100644 tests/test_shardformer/test_model/test_shard_llama.py diff --git a/colossalai/shardformer/layer/dist_crossentropy.py b/colossalai/shardformer/layer/dist_crossentropy.py index 05c04bb545c1..ff05209fefe8 100644 --- a/colossalai/shardformer/layer/dist_crossentropy.py +++ b/colossalai/shardformer/layer/dist_crossentropy.py @@ -21,7 +21,7 @@ def forward(ctx, vocab_logits: torch.Tensor, target: torch.Tensor, ignore_index: and can be rewrite as: loss = log(sum(exp(x[i])) - x[class] - To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i] + To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i] Args: vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is diff --git a/colossalai/shardformer/policies/autopolicy.py b/colossalai/shardformer/policies/autopolicy.py index 54cc63ba124f..27fd09b4561b 100644 --- a/colossalai/shardformer/policies/autopolicy.py +++ b/colossalai/shardformer/policies/autopolicy.py @@ -19,6 +19,20 @@ def build_policies(): from .bert import BertForSequenceClassificationPolicy auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy + from transformers.models.llama.modeling_llama import LlamaModel + + from .llama import LlamaPolicy + auto_policy_dict[LlamaModel] = LlamaPolicy + + from transformers import LlamaForSequenceClassification + + from .llama import LlamaForSequenceClassificationPolicy + auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy + + from transformers import LlamaForCausalLM + + from .llama import LlamaForCausalLMPolicy + auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy from transformers import GPT2Model diff --git a/colossalai/shardformer/policies/llama.py b/colossalai/shardformer/policies/llama.py new file mode 100644 index 000000000000..fac6765cdcb5 --- /dev/null +++ b/colossalai/shardformer/policies/llama.py @@ -0,0 +1,122 @@ +from dataclasses import dataclass +from typing import Any, Callable, Dict, List, Tuple, Type + +import torch.nn as nn +from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel + +import colossalai.shardformer.layer.layers as col_nn + +from .basepolicy import Argument, Col_Layer, Policy, Row_Layer + + +class LlamaPolicy(Policy): + + @staticmethod + def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]: + return { + LlamaDecoderLayer: + Argument(attr_dict={ + "self_attn.hidden_size": config.hidden_size // world_size, + "self_attn.num_heads": config.num_attention_heads // world_size, + }, + param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]), + LlamaModel: + Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings]) + } + + @staticmethod + def attn_layer() -> List: + return [ + Col_Layer( + suffix="self_attn.q_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="self_attn.k_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Col_Layer( + suffix="self_attn.v_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + ), + Row_Layer( + suffix="self_attn.o_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + ) + ] + + @staticmethod + def mlp_layer() -> List: + return [ + Col_Layer( + suffix="mlp.gate_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ), + Col_Layer( + suffix="mlp.up_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Row, + gather_output=True, + ), + Col_Layer( + suffix="mlp.down_proj", + weight="weight", + bias="bias", + replace_layer=col_nn.Linear1D_Col, + gather_output=True, + ), + ] + + @staticmethod + def embeddings() -> List: + return [Col_Layer( + suffix="embed_tokens", + weight="weight", + replace_layer=col_nn.VocabParallelEmbedding1D, + )] + +from transformers import LlamaForCausalLM + + +class LlamaForCausalLMPolicy(LlamaPolicy): + + @staticmethod + def argument(config, world_size): + llamapolicy = LlamaPolicy.argument_policy(config, world_size) + argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])} + argument.update(llamapolicy) + + @staticmethod + def lm_head() -> List: + return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] + + +from transformers import LlamaForSequenceClassification + + +class LlamaForSequenceClassificationPolicy(LlamaPolicy): + + @staticmethod + def argument(config, world_size): + llamapolicy = LlamaPolicy.argument_policy(config, world_size) + argument = { + LlamaForSequenceClassification: + Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score]) + } + argument.update(llamapolicy) + + @staticmethod + def score() -> List: + return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)] diff --git a/tests/test_shardformer/test_model/test_shard_llama.py b/tests/test_shardformer/test_model/test_shard_llama.py new file mode 100644 index 000000000000..689898bbbad2 --- /dev/null +++ b/tests/test_shardformer/test_model/test_shard_llama.py @@ -0,0 +1,106 @@ +import copy +import os +import random + +import pytest +import torch +from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast + +import colossalai +from colossalai.logging import disable_existing_loggers +from colossalai.shardformer.shard import ShardConfig, shard_model +from colossalai.testing import rerun_if_address_is_in_use, spawn + +os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true' +CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),) +tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer") + + +def build_model(rank, world_size): + cfg = LlamaConfig(num_hidden_layers=16) + org_model = LlamaForCausalLM(cfg) + + shardconfig = ShardConfig( + rank=rank, + world_size=world_size, + gather_output=True, + ) + org_model = org_model.to('cuda') + + org_model_forshard = copy.deepcopy(org_model) + sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda') + + return org_model, sharded_model + + +def check_forward(org_model, sharded_model): + input = 'Hello, my dog is cute' + inputs = tokenizer(input, return_tensors='pt').to('cuda') + del inputs["token_type_ids"] + del inputs["attention_mask"] + #orgin model + org_model.eval() + org_out = org_model(**inputs) + + #shard model + sharded_model.eval() + shard_out = sharded_model(**inputs) + + assert torch.allclose( + org_out[0], shard_out[0], + atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}" + + +def check_backward(org_model, sharded_model): + # prepare input + input = 'Hello, my dog is cute' + tokenized_input = tokenizer(input, return_tensors='pt').to('cuda') + del tokenized_input["token_type_ids"] + del tokenized_input["attention_mask"] + labels = tokenized_input['input_ids'].clone() + labels[labels == tokenizer.pad_token_id] = -100 + tokenized_input['labels'] = labels + + #orgin model + org_model.train() + org_out = org_model(**tokenized_input) + org_loss = org_out.loss + org_loss.backward() + org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad + + torch.cuda.empty_cache() + #shard model + sharded_model.train() + shard_out = sharded_model(**tokenized_input) + shard_loss = shard_out.loss + shard_loss.backward() + shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad + shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)] + shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad) + all_shard_grad = torch.cat(shard_grad_list, dim=0) + + assert torch.allclose(org_loss, shard_loss, + atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}" + assert torch.allclose(org_grad, all_shard_grad, + atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}" + + +def check_llama(rank, world_size, port): + disable_existing_loggers() + colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl') + + org_model, sharded_model = build_model(rank, world_size) + check_forward(org_model, sharded_model) + check_backward(org_model, sharded_model) + + torch.cuda.empty_cache() + + +@pytest.mark.dist +@rerun_if_address_is_in_use() +def test_llama(): + spawn(check_llama, 4) + + +if __name__ == "__main__": + test_llama()