Skip to content

Commit

Permalink
[shardformer] import huggingface implicitly (hpcaitech#4101)
Browse files Browse the repository at this point in the history
  • Loading branch information
FrankLeeeee authored and ver217 committed Jul 13, 2023
1 parent 0fa4c7f commit 257d468
Show file tree
Hide file tree
Showing 9 changed files with 91 additions and 38 deletions.
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@

from .basepolicy import Policy

__all__ = ["PolicyLocation", "get_autopolicy", "import_policy"]


@dataclass
class PolicyLocation:
Expand Down
2 changes: 2 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

from ..shard.shard_config import ShardConfig

__all__ = ["ParallelModule", "SubModuleReplacementDescription", "ModulePolicyDescription", "Policy"]


class ParallelModule():

Expand Down
30 changes: 21 additions & 9 deletions colossalai/shardformer/policies/bert.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,16 @@
import torch.nn as nn
from transformers.models.bert.modeling_bert import (
BertEmbeddings,
BertForMultipleChoice,
BertForSequenceClassification,
BertForTokenClassification,
BertLayer,
BertLMPredictionHead,
)

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'BertPolicy', 'BertModelPolicy', 'BertForPretrainingPolicy', 'BertLMHeadModelPolicy', 'BertForMaskedLMPolicy',
'BertForNextSentencePredictionPolicy', 'BertForSequenceClassificationPolicy', 'BertForTokenClassificationPolicy',
'BertForMultipleChoicePolicy'
]


class BertPolicy(Policy):

Expand All @@ -33,6 +31,8 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.bert.modeling_bert import BertEmbeddings, BertLayer

base_policy = {
BertLayer:
ModulePolicyDescription(
Expand Down Expand Up @@ -123,7 +123,7 @@ def module_policy(self):

def new_model_class(self):
# do nothing
return self.model
return None

def postprocess(self):
return self.model
Expand All @@ -143,6 +143,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
Expand Down Expand Up @@ -184,6 +186,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
Expand Down Expand Up @@ -221,6 +225,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertLMPredictionHead

module_policy = super().module_policy()
addon_module = {
BertLMPredictionHead:
Expand Down Expand Up @@ -261,6 +267,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForSequenceClassification

module_policy = super().module_policy()
addon_module = {
BertForSequenceClassification:
Expand All @@ -284,6 +292,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForTokenClassification

module_policy = super().module_policy()
addon_module = {
BertForTokenClassification:
Expand Down Expand Up @@ -314,6 +324,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.bert.modeling_bert import BertForMultipleChoice

module_policy = super().module_policy()
addon_module = {
BertForMultipleChoice:
Expand Down
14 changes: 12 additions & 2 deletions colossalai/shardformer/policies/gpt2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
import torch.nn as nn
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2DoubleHeadsModel, GPT2LMHeadModel, GPT2Model

import colossalai.shardformer.layer as col_nn

from .._utils import getattr_, setattr_
from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'GPT2Policy', 'GPT2ModelPolicy', 'GPT2LMHeadModelPolicy', 'GPT2DoubleHeadsModelPolicy',
'GPT2ForTokenClassificationPolicy', 'GPT2ForSequenceClassificationPolicy'
]


class GPT2Policy(Policy):

Expand All @@ -25,7 +29,9 @@ def preprocess(self):
return self.model

def module_policy(self):
base_policy = {
from transformers.models.gpt2.modeling_gpt2 import GPT2Block, GPT2Model

return {
GPT2Model:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
Expand Down Expand Up @@ -125,6 +131,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel

module_policy = super().module_policy()
addon_module = {
GPT2LMHeadModel:
Expand Down Expand Up @@ -156,6 +164,8 @@ def __init__(self) -> None:
super().__init__()

def module_policy(self):
from transformers.models.gpt2.modeling_gpt2 import GPT2DoubleHeadsModel

module_policy = super().module_policy()
addon_module = {
GPT2DoubleHeadsModel:
Expand Down
12 changes: 9 additions & 3 deletions colossalai/shardformer/policies/llama.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
from typing import Dict, Union

import torch.nn as nn
from transformers import LlamaForCausalLM, LlamaForSequenceClassification
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

from colossalai.shardformer.layer import FusedRMSNorm, Linear1D_Col, Linear1D_Row, VocabParallelEmbedding1D

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['LlamaPolicy', 'LlamaForCausalLMPolicy', 'LlamaForSequenceClassificationPolicy']


class LlamaPolicy(Policy):

Expand All @@ -26,7 +26,9 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

return {
LlamaDecoderLayer:
ModulePolicyDescription(
attribute_replacement={
Expand Down Expand Up @@ -109,6 +111,8 @@ def postprocess(self):
class LlamaForCausalLMPolicy(LlamaPolicy):

def module_policy(self):
from transformers import LlamaForCausalLM

policy = super().module_policy()
# add a new item for casual lm
new_item = {
Expand All @@ -128,6 +132,8 @@ def module_policy(self):
class LlamaForSequenceClassificationPolicy(LlamaPolicy):

def module_policy(self):
from transformers import LlamaForSequenceClassification

policy = super().module_policy()

# add a new item for sequence classification
Expand Down
17 changes: 9 additions & 8 deletions colossalai/shardformer/policies/opt.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,12 @@
from transformers.models.opt.modeling_opt import (
OPTAttention,
OPTDecoder,
OPTDecoderLayer,
OPTForCausalLM,
OPTForSequenceClassification,
)

from colossalai.shardformer.layer import Embedding1D, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = [
'OPTPolicy', 'OPTModelPolicy', 'OPTForCausalLMPolicy', 'OPTForSequenceClassificationPolicy',
'OPTForQuestionAnsweringPolicy'
]


class OPTPolicy(Policy):

Expand All @@ -29,6 +26,8 @@ def preprocess(self):
return self.model

def module_policy(self):
from transformers.models.opt.modeling_opt import OPTAttention, OPTDecoder, OPTDecoderLayer

base_policy = {
OPTDecoder:
ModulePolicyDescription(attribute_replacement={},
Expand Down Expand Up @@ -111,6 +110,8 @@ def __init__(self) -> None:
class OPTForCausalLMPolicy(OPTPolicy):

def module_policy(self):
from transformers.models.opt.modeling_opt import OPTForCausalLM

policy = super().module_policy()
new_item = {
OPTForCausalLM:
Expand Down
27 changes: 14 additions & 13 deletions colossalai/shardformer/policies/t5.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,4 @@
from transformers import T5ForConditionalGeneration
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)

from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, FusedRMSNorm, Linear1D_Col, Linear1D_Row
from colossalai.shardformer.layer import DropoutForParallelInput, Embedding1D, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

Expand All @@ -34,7 +23,17 @@ def preprocess(self):
return self.model

def module_policy(self):
base_policy = {
from transformers.models.t5.modeling_t5 import (
T5Attention,
T5DenseActDense,
T5DenseGatedActDense,
T5LayerCrossAttention,
T5LayerFF,
T5LayerSelfAttention,
T5Stack,
)

return {
T5Stack:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
Expand Down Expand Up @@ -165,6 +164,8 @@ def postprocess(self):
class T5ForConditionalGenerationPolicy(T5ModelPolicy):

def module_policy(self):
from transformers import T5ForConditionalGeneration

policy = super().module_policy()

new_item = {
Expand Down
7 changes: 5 additions & 2 deletions colossalai/shardformer/policies/vit.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import Dict, Union

import torch.nn as nn
from transformers.models.vit.modeling_vit import ViTAttention, ViTEmbeddings, ViTLayer, ViTModel

from colossalai.shardformer.layer import DropoutForReplicatedInput, FusedLayerNorm, Linear1D_Col, Linear1D_Row

from .basepolicy import ModulePolicyDescription, Policy, SubModuleReplacementDescription

__all__ = ['ViTPolicy']


class ViTPolicy(Policy):

Expand All @@ -25,7 +26,9 @@ def preprocess(self):
return self.model

def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
base_policy = {
from transformers.models.vit.modeling_vit import ViTEmbeddings, ViTLayer

return {
ViTEmbeddings:
ModulePolicyDescription(attribute_replacement={},
param_replacement=[],
Expand Down
18 changes: 17 additions & 1 deletion colossalai/shardformer/shard/shard_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ class ShardConfig:
"""
tensor_parallel_process_group: int = None
enable_fused_normalization: bool = False
enable_all_optimization: bool = False

# TODO: add support for tensor parallel
# pipeline_parallel_size: int
Expand All @@ -27,6 +28,21 @@ class ShardConfig:
# inference_only: bool = True
# gather_output: bool = True

@property
def tensor_parallel_size(self):
return self._tensor_parallel_size

def __post_init__(self):
# get the parallel size
self.tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)
self._tensor_parallel_size = dist.get_world_size(self.tensor_parallel_process_group)

# turn on all optimization if all_optimization is set to True
if self.enable_all_optimization:
self._turn_on_all_optimization()

def _turn_on_all_optimization(self):
"""
Turn on all optimization.
"""
# you can add all the optimization flag here
self.fused_layernorm = True

0 comments on commit 257d468

Please sign in to comment.