forked from hpcaitech/ColossalAI
[shardformer] support llama model using shardformer (hpcaitech#3969)
adjust layer attr
commit 2bd8f37 (1 parent: 2b5df70)
Showing 4 changed files with 243 additions and 1 deletion.
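For orientation, the "adjust layer attr" part of this change amounts to dividing the attention attributes evenly across tensor-parallel ranks. A back-of-the-envelope sketch of what that works out to, assuming the stock transformers LlamaConfig defaults (hidden_size=4096, num_attention_heads=32, which are not set anywhere in this diff) and the tensor-parallel size of 4 used by the test below:

from transformers import LlamaConfig

config = LlamaConfig()    # defaults: hidden_size=4096, num_attention_heads=32
world_size = 4            # tensor-parallel degree used in the test below

# Per-rank values that LlamaPolicy.argument_policy writes into self_attn:
per_rank_hidden_size = config.hidden_size // world_size          # 4096 // 4 = 1024
per_rank_num_heads = config.num_attention_heads // world_size    # 32 // 4 = 8
print(per_rank_hidden_size, per_rank_num_heads)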
@@ -0,0 +1,122 @@
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Policy, Row_Layer


class LlamaPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        # Shrink the per-rank attention attributes so each tensor-parallel rank
        # only sees its slice of the hidden size and attention heads.
        return {
            LlamaDecoderLayer:
                Argument(attr_dict={
                    "self_attn.hidden_size": config.hidden_size // world_size,
                    "self_attn.num_heads": config.num_attention_heads // world_size,
                },
                         param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
            LlamaModel:
                Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
        }

    @staticmethod
    def attn_layer() -> List:
        # Column-parallel q/k/v projections followed by a row-parallel output projection.
        return [
            Col_Layer(
                suffix="self_attn.q_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.k_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.v_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="self_attn.o_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
            )
        ]

    @staticmethod
    def mlp_layer() -> List:
        # MLP projections are replaced with 1D-parallel linears; outputs are gathered
        # so the surrounding unsharded ops still see full-size activations.
        return [
            Col_Layer(
                suffix="mlp.gate_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.up_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.down_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
        ]

    @staticmethod
    def embeddings() -> List:
        # Shard the token embedding over the vocabulary dimension.
        return [Col_Layer(
            suffix="embed_tokens",
            weight="weight",
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]


from transformers import LlamaForCausalLM


class LlamaForCausalLMPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
        argument.update(llamapolicy)
        return argument

    @staticmethod
    def lm_head() -> List:
        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]


from transformers import LlamaForSequenceClassification


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {
            LlamaForSequenceClassification:
                Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
        }
        argument.update(llamapolicy)
        return argument

    @staticmethod
    def score() -> List:
        return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
@@ -0,0 +1,106 @@
import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'
CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")


def build_model(rank, world_size):
    # Build a small randomly initialized llama model plus a sharded deep copy of it.
    cfg = LlamaConfig(num_hidden_layers=16)
    org_model = LlamaForCausalLM(cfg)

    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    org_model = org_model.to('cuda')

    org_model_forshard = copy.deepcopy(org_model)
    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):
    input = 'Hello, my dog is cute'
    inputs = tokenizer(input, return_tensors='pt').to('cuda')
    del inputs["token_type_ids"]
    del inputs["attention_mask"]

    # original model
    org_model.eval()
    org_out = org_model(**inputs)

    # sharded model
    sharded_model.eval()
    shard_out = sharded_model(**inputs)

    assert torch.allclose(
        org_out[0], shard_out[0],
        atol=1e-4), f"shard model output is not equal to original model output\n{org_out[0]}\n{shard_out[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
    del tokenized_input["token_type_ids"]
    del tokenized_input["attention_mask"]
    labels = tokenized_input['input_ids'].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    tokenized_input['labels'] = labels

    # original model
    org_model.train()
    org_out = org_model(**tokenized_input)
    org_loss = org_out.loss
    org_loss.backward()
    org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad

    torch.cuda.empty_cache()

    # sharded model
    sharded_model.train()
    shard_out = sharded_model(**tokenized_input)
    shard_loss = shard_out.loss
    shard_loss.backward()
    shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad

    # gather the column-parallel q_proj gradient shards from all 4 ranks and
    # concatenate them along dim 0 to rebuild the full gradient
    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to original model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to original model grad\n{org_grad}\n{all_shard_grad}"


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_llama():
    spawn(check_llama, 4)


if __name__ == "__main__":
    test_llama()
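The gradient check relies on q_proj having been replaced by a column-parallel linear, so each of the 4 ranks holds only a dim-0 slice of the weight (and of its gradient); concatenating the gathered shards along dim 0 rebuilds the full-size gradient for comparison against the unsharded model. A single-process sketch of that reconstruction, assuming the dim-0 split implied by the concatenation above and the default LlamaConfig hidden size of 4096:

import torch

world_size = 4
full_grad = torch.randn(4096, 4096)                 # stands in for the full q_proj.weight.grad
shards = torch.chunk(full_grad, world_size, dim=0)  # the [1024, 4096] slice each rank would hold
reconstructed = torch.cat(shards, dim=0)            # mirrors torch.cat(shard_grad_list, dim=0)
assert torch.equal(full_grad, reconstructed)

The test itself spawns 4 processes via spawn(check_llama, 4), so it needs 4 GPUs; it can be run through pytest (it is marked dist) or by executing the file directly.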