Feature/vitt #4096

Merged 10 commits on Jun 28, 2023
2 changes: 1 addition & 1 deletion colossalai/shardformer/README.md
@@ -91,7 +91,7 @@ We will follow this roadmap to develop Shardformer:
 - [ ] GPT Neo
 - [ ] GPT-J
 - [ ] CV
-- [ ] ViT
+- [x] ViT
 - [ ] BEiT
 - [ ] SwinTransformer
 - [ ] SwinTransformer V2
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/_operation.py
@@ -287,4 +287,4 @@ def reduce_forward(input_, process_group):


 def reduce_backward(input_, process_group):
-    return _ReduceBackward.apply(input_, process_group)
\ No newline at end of file
+    return _ReduceBackward.apply(input_, process_group)
2 changes: 1 addition & 1 deletion colossalai/shardformer/layer/layernorm.py
@@ -61,4 +61,4 @@ def from_native_module(module: nn.LayerNorm, *args, **kwargs) -> nn.Module:
         # copy weight and bias
         layernorm.weight.copy_(module.weight)
         layernorm.bias.copy_(module.bias)
-        return layernorm
\ No newline at end of file
+        return layernorm
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/bert.py
@@ -316,4 +316,4 @@ def module_policy(self):
             ])
         }
         module_policy.update(addon_module)
-        return module_policy
\ No newline at end of file
+        return module_policy
2 changes: 1 addition & 1 deletion colossalai/shardformer/policies/t5.py
@@ -167,4 +167,4 @@ def module_policy(self):


 class T5EncoderPolicy(T5ModelPolicy):
-    pass
\ No newline at end of file
+    pass
96 changes: 96 additions & 0 deletions colossalai/shardformer/policies/vit.py
@@ -0,0 +1,96 @@
from typing import Dict, Union

import torch.nn as nn

from transformers.models.vit.modeling_vit import ViTModel, ViTLayer, ViTEmbeddings, ViTAttention

from colossalai.shardformer.layer import Linear1D_Col, Linear1D_Row, Dropout1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription


class ViTPolicy(Policy):

    def preprocess(self):
        # Resize embedding so that the vocabulary size is divisible by the tensor parallel world size
        vocab_size = self.model.config.vocab_size
        world_size = self.shard_config.tensor_parallel_size

        if vocab_size % world_size != 0:
            new_vocab_size = vocab_size + world_size - vocab_size % world_size
            self.model.resize_token_embeddings(new_vocab_size)

        return self.model

    def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
        return {
            ViTEmbeddings:
                ModulePolicyDescription(
                    attribute_replacement={},
                    param_replacement=[],
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="dropout",
                            target_module=Dropout1D,
                        )
                    ]),
            ViTLayer:
                ModulePolicyDescription(
                    attribute_replacement={
                        "attention.attention.num_attention_heads":
                            self.model.config.num_attention_heads // self.shard_config.tensor_parallel_size,
                        "attention.attention.all_head_size":
                            self.model.config.hidden_size // self.shard_config.tensor_parallel_size,
                    },
                    param_replacement=[],
                    sub_module_replacement=[
                        SubModuleReplacementDescription(
                            suffix="attention.attention.query",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="attention.attention.key",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="attention.attention.value",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="attention.attention.dropout",
                            target_module=Dropout1D,
                        ),
                        SubModuleReplacementDescription(
                            suffix="attention.output.dense",
                            target_module=Linear1D_Row,
                        ),
                        SubModuleReplacementDescription(
                            suffix="attention.output.dropout",
                            target_module=Dropout1D,
                        ),
                        SubModuleReplacementDescription(
                            suffix="intermediate.dense",
                            target_module=Linear1D_Col,
                        ),
                        SubModuleReplacementDescription(
                            suffix="output.dense",
                            target_module=Linear1D_Row,
                        ),
                        SubModuleReplacementDescription(
                            suffix="output.dropout",
                            target_module=Dropout1D,
                        ),
                    ]),
        }

    def new_model_class(self):
        return None

    def postprocess(self):
        return self.model
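For context on the `ViTLayer` attribute replacement above: with a ViT-Base style configuration (`hidden_size=768`, `num_attention_heads=12`) and `tensor_parallel_size=4`, each rank keeps 3 attention heads and an `all_head_size` of 192, while `Linear1D_Col` shards the output dimension of query/key/value and `Linear1D_Row` shards the input dimension of `attention.output.dense`. The sketch below only reproduces this shape bookkeeping with plain PyTorch; the ViT-Base numbers and the names `full_query`, `col_shard`, `row_shard` are assumed for illustration and are not part of the Shardformer API.

```python
import torch
import torch.nn as nn

# ViT-Base style configuration (assumed for illustration only).
hidden_size = 768
num_attention_heads = 12
tensor_parallel_size = 4

# attribute_replacement in ViTPolicy rewrites these two attributes per rank:
heads_per_rank = num_attention_heads // tensor_parallel_size    # 12 // 4 = 3
all_head_size_per_rank = hidden_size // tensor_parallel_size    # 768 // 4 = 192

# Linear1D_Col splits the output dimension of query/key/value across ranks,
# so each rank holds a [hidden_size -> all_head_size_per_rank] projection ...
full_query = nn.Linear(hidden_size, hidden_size)
col_shard = nn.Linear(hidden_size, all_head_size_per_rank)

# ... while Linear1D_Row splits the input dimension of attention.output.dense,
# whose partial outputs are summed (all-reduced) across ranks.
row_shard = nn.Linear(all_head_size_per_rank, hidden_size)

x = torch.randn(1, 197, hidden_size)    # 196 patches + CLS token for a 224x224, patch-16 input
print(col_shard(x).shape)               # torch.Size([1, 197, 192])
print(row_shard(col_shard(x)).shape)    # torch.Size([1, 197, 768])
```

Chaining a column-parallel projection into a row-parallel one keeps per-rank activations at `all_head_size` width until the final all-reduce restores the full `hidden_size`.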
2 changes: 1 addition & 1 deletion tests/test_device/test_device_mesh.py
@@ -86,4 +86,4 @@ def test_device_mesh_from_process_group():

 if __name__ == '__main__':
     test_device_mesh()
-    test_device_mesh_from_process_group()
\ No newline at end of file
+    test_device_mesh_from_process_group()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_layer/test_layernorm.py
@@ -41,4 +41,4 @@ def test_layernorm():


 if __name__ == '__main__':
-    test_layernorm_1d()
\ No newline at end of file
+    test_layernorm_1d()
2 changes: 1 addition & 1 deletion tests/test_shardformer/test_model/test_shard_t5.py
@@ -56,4 +56,4 @@ def test_t5():


 if __name__ == "__main__":
-    test_t5()
\ No newline at end of file
+    test_t5()
55 changes: 55 additions & 0 deletions tests/test_shardformer/test_model/test_shard_vit.py
@@ -0,0 +1,55 @@
import pytest
import torch

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.testing import assert_hf_output_close, clear_cache_before_run, rerun_if_address_is_in_use, spawn
from tests.kit.model_zoo import model_zoo
from tests.test_shardformer.test_model._utils import build_model, run_forward


def check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn):
    # check forward
    org_output, org_loss, shard_output, shard_loss = run_forward(org_model, sharded_model, data_gen_fn,
                                                                 output_transform_fn, loss_fn)
    assert_hf_output_close(org_output, shard_output)

    # do backward
    org_loss.backward()
    shard_loss.backward()

    # check grad: the query projection is column-parallel, so its weight gradient is split
    # along dim 0; gather the shards from all ranks and concatenate them back together
    org_grad = org_model.encoder.layer[0].attention.attention.query.weight.grad
    shard_grad = sharded_model.encoder.layer[0].attention.attention.query.weight.grad

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(torch.distributed.get_world_size())]
    torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to origin model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to origin model grad\n{org_grad}\n{all_shard_grad}"


def check_vit(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    sub_model_zoo = model_zoo.get_sub_registry('transformers_vit')
    for name, (model_fn, data_gen_fn, output_transform_fn, loss_fn, _) in sub_model_zoo.items():
        org_model, sharded_model = build_model(world_size, model_fn)
        check_forward_backward(org_model, sharded_model, data_gen_fn, output_transform_fn, loss_fn)

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
@clear_cache_before_run()
def test_vit():
    spawn(check_vit, 4)


if __name__ == "__main__":
    test_vit()
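A note on the gradient check in this test: `Linear1D_Col` partitions the query projection's weight, and therefore its gradient, along the output dimension (dim 0 of an `nn.Linear` weight), so `all_gather` followed by `torch.cat(..., dim=0)` reconstructs the full gradient for comparison against the unsharded model. The single-process sketch below reproduces that round trip with `torch.chunk` standing in for the actual sharding and communication; the shapes and names are illustrative assumptions, not part of the test utilities.

```python
import torch

world_size = 2
full_grad = torch.randn(768, 768)    # gradient of an unsharded query weight

# Linear1D_Col-style partitioning: split the output dimension (dim 0) across ranks.
shard_grads = list(torch.chunk(full_grad, world_size, dim=0))    # each shard: [384, 768]

# What all_gather + cat does in the test, minus the communication.
reconstructed = torch.cat(shard_grads, dim=0)

assert torch.equal(reconstructed, full_grad)
print(reconstructed.shape)    # torch.Size([768, 768])
```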