Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[shardformer] adapted llama to the new API #4036

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
134 changes: 72 additions & 62 deletions colossalai/shardformer/policies/autopolicy.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,76 @@
import importlib
from dataclasses import dataclass
from typing import Type

import torch.nn as nn

from .basepolicy import Policy


def build_policies():
r"""
Build the policies for the model

Return:
The dict for the policies
@dataclass
class PolicyLocation:
"""
auto_policy_dict = {}

from transformers import BertModel

from .bert import BertModelPolicy
auto_policy_dict[BertModel] = BertModelPolicy

from transformers import BertForPreTraining

from .bert import BertForPretrainingPolicy
auto_policy_dict[BertForPreTraining] = BertForPretrainingPolicy

from transformers import BertLMHeadModel

from .bert import BertLMHeadModelPolicy
auto_policy_dict[BertLMHeadModel] = BertLMHeadModelPolicy

from transformers import BertForMaskedLM

from .bert import BertForMaskedLMPolicy
auto_policy_dict[BertForMaskedLM] = BertForMaskedLMPolicy

from transformers import BertForNextSentencePrediction
PolicyLocation describes the location of a policy class.

from .bert import BertForNextSentencePredictionPolicy
auto_policy_dict[BertForNextSentencePrediction] = BertForNextSentencePredictionPolicy

from transformers import BertForSequenceClassification

from .bert import BertForSequenceClassificationPolicy
auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy
from transformers.models.llama.modeling_llama import LlamaModel
Args:
file_name (str): The file name of the policy under colossalai.shardformer.policies
class_name (str): The class name of the policy class
"""
file_name: str
class_name: str


# Policies are resolved lazily by name instead of being imported here:
# each policy file imports its own model zoo library, so only the policy
# file that is actually needed should be imported.
# Keys are the fully qualified class names of the supported models, values
# tell where the matching policy class lives.
_POLICY_LIST = {
    f"transformers.models.{family}.modeling_{family}.{model_name}":
        PolicyLocation(file_name=family, class_name=policy_name)
    for family, model_name, policy_name in (
        # BERT
        ("bert", "BertModel", "BertPolicy"),
        ("bert", "BertForPreTraining", "BertForPretrainingPolicy"),
        ("bert", "BertForMaskedLM", "BertForMaskedLMPolicy"),
        ("bert", "BertLMHeadModel", "BertLMHeadModelPolicy"),
        ("bert", "BertForNextSentencePrediction", "BertForNextSentencePredictionPolicy"),
        ("bert", "BertForSequenceClassification", "BertForSequenceClassificationPolicy"),
        ("bert", "BertForMultipleChoice", "BertForMultipleChoicePolicy"),

        # LLaMA
        ("llama", "LlamaModel", "LlamaPolicy"),
        ("llama", "LlamaForCausalLM", "LlamaForCausalLMPolicy"),
        ("llama", "LlamaForSequenceClassification", "LlamaForSequenceClassificationPolicy"),

        # T5

        # GPT2
    )
}


def import_policy(policy_location: PolicyLocation) -> Type[Policy]:
    """
    Dynamically import a Policy class based on the policy location.

    Args:
        policy_location (PolicyLocation): The file name and class name of the policy to load

    Return:
        The attribute named ``policy_location.class_name`` of the module
        ``colossalai.shardformer.policies.<file_name>``. This is the policy class itself,
        not an instance, so the caller is responsible for instantiating it.

    Raises:
        ImportError: If the policy module cannot be imported
        AttributeError: If the module does not define ``policy_location.class_name``
    """
    # the import happens only here, on demand, so that policy files which are
    # not needed (and the libraries they import) are left untouched
    module_name = f"colossalai.shardformer.policies.{policy_location.file_name}"
    module = importlib.import_module(module_name)
    return getattr(module, policy_location.class_name)

# from .llama import LlamaPolicy
# auto_policy_dict[LlamaModel] = LlamaPolicy
# from transformers import LlamaForSequenceClassification
# from .llama import LlamaForSequenceClassificationPolicy
# auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
# from transformers import LlamaForCausalLM
# from .llama import LlamaForCausalLMPolicy
# auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy
# from transformers import GPT2Model
# from .gpt2 import GPT2Policy
# auto_policy_dict[GPT2Model] = GPT2Policy
# from transformers import GPT2LMHeadModel
# from .gpt2 import GPT2LMHeadModelPolicy
# auto_policy_dict[GPT2LMHeadModel] = GPT2LMHeadModelPolicy

return auto_policy_dict
def _fullname(obj):
    """
    Build the dotted path of the class that ``obj`` is an instance of.

    The result is ``<module>.<qualified class name>``, except for classes
    living in ``builtins``, which are reported by their qualified name only.
    """
    cls = obj.__class__
    module_name = cls.__module__
    # avoid outputs like 'builtins.str'
    return cls.__qualname__ if module_name == 'builtins' else module_name + '.' + cls.__qualname__


def get_autopolicy(model: nn.Module) -> Policy:
Expand All @@ -71,16 +83,14 @@ def get_autopolicy(model: nn.Module) -> Policy:
Return:
:class:`Policy`: The auto policy for the model
"""
auto_policy_dict = build_policies()
policy = auto_policy_dict.get(model.__class__, None)
if policy is None:
full_name = _fullname(model)
policy_location = _POLICY_LIST.get(full_name, None)

if policy_location is None:
raise NotImplementedError(
f"Auto policy for {model.__class__.__qualname__} is not implemented\n Supported models are {[i.__qualname__ for i in auto_policy_dict.keys()]}"
f"Auto policy for {model.__class__.__qualname__} is not implemented\n. Supported models are {list(_POLICY_LIST.keys())}"
)
else:
policy = import_policy(policy_location)
return policy()
return policy()


# from transformers.models.bert.modeling_bert import BertForMaskedLM, BertForPreTraining
# model = BertForPreTraining
# policy = get_autopolicy(model)
# print(policy)
5 changes: 5 additions & 0 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ class for the example.
"""

def __init__(self) -> None:
self.shard_config = None
self.model = None
self.shard_config = None

Expand All @@ -101,6 +102,7 @@ def preprocess(self) -> nn.Module:
r"""
Perform some preprocessing of the model, like reshaping the embedding layer
"""
pass

@abstractmethod
def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
Expand Down Expand Up @@ -135,6 +137,7 @@ def module_policy(self) -> Dict[Union[str, nn.Module], ModulePolicyDescription]:
...
}
"""
pass

@abstractmethod
def new_model_class(self) -> Union[Type[nn.Module], None]:
Expand All @@ -149,10 +152,12 @@ def new_model_class(self) -> Union[Type[nn.Module], None]:
return BertModel_
```
"""
pass

@abstractmethod
def postprocess(self) -> nn.Module:
r"""
Perform some postprocessing of the model, like binding the weight of embedding layer with
the classifier layer
"""
pass
Loading