Skip to content

Commit

Permalink
[shardformer] Add dropout layer in shard model and refactor policy api (
Browse files Browse the repository at this point in the history
hpcaitech#3949)

* add dist dropout in model

* update docstring and bert policy with dropout

* refactor basepolicy and sharded, update bert

* update format

* update gpt2 policy

* update bert policy

* remove unused code

* update readme for new policy usage
  • Loading branch information
FoolPlayer authored and flybird11111 committed Jul 3, 2023
1 parent 7865eeb commit 2b5df70
Show file tree
Hide file tree
Showing 7 changed files with 255 additions and 186 deletions.
80 changes: 47 additions & 33 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ colossalai run --nproc_per_node 2 --master_port 29500 test.py --config config.py

## 💡 Policy

If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model.
If you wanna parallel the model in a custom way, just overwrite the policy class for the Hugging Face model. Please refer to any policy that we have pre-established, like [bert policy](./policies/bert.py) or [gpt2 policy](./policies/gpt2.py).

You should do:

Expand All @@ -68,7 +68,7 @@ You should do:
- Shardformer will inject the model according to this method. If you need to modify the forward or backward progress (like distributed corssentropy loss in Bert) you need to overwrite this method.
4. Overwrite or add the param functions
- These functions use a suffix to record the path of weight or bias for the layer.
- The return is a list contains some `Col_Layer` or `Row_Layer` objects, which means slice along col and row respectively.
- The return is a list contains some `Col_Layer`, `Row_Layer` or `Dropout_Layer` objects, which means slice along col and row respectively or as dropout layer, refer to CLASS `Layer` for more details.
5. Overwrite `binding_policy` (Optional)
- Overwrite to specify Shardformer will bind some weight between layers, like embedding and unembedding layers.
- This function will return a dict, the key and value are the suffix of weight need to be binded.
Expand Down Expand Up @@ -123,7 +123,7 @@ class CustomPolicy(Policy):
raise NotImplementedError

@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
r"""
Return the dict for the inject model
Expand All @@ -133,12 +133,12 @@ class CustomPolicy(Policy):
(OrignModel, CustomModel)
in `CustomModel`, we can overwrite the forward and backward process
"""
return ()
return None

@staticmethod
def binding_policy() -> Dict:
def binding_policy() -> Union[Dict[str, str], None]:
r"""
Return the dict for the binding model
Return the dict for the binding model, None means no need to bind
Return:
This method should return the binding relationship for some layers share the weight or bias,
Expand All @@ -148,69 +148,70 @@ class CustomPolicy(Policy):
"bert.embeddings.word_embeddings.weight": "cls.predictions.decoder.weight",
}
"""
return NotImplementedError
return None

@staticmethod
def attn_in() -> List:
"""
def attn_in() -> Union[List, None]:
r"""
Attention qkv layer
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
``Layer`` for no slicing, ``Col_Layer`` for col slicing, ``Row_Layer`` for row slicing. And the parameters
in ``Layer`` object can refer to the ``Layer`` class.
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
return None

@staticmethod
def attn_out() -> List:
"""
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def mlp_in() -> List:
"""
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def mlp_out() -> List:
"""
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def embedding() -> List:
"""
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def unembedding() -> List:
"""
Partially slice the embedding layer
vocab_size->vocab_size//gpu_nums
def unembedding() -> Union[List, None]:
r"""
Partially slice the embedding layer, None means there is no unembedding layer
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

```

Expand All @@ -232,51 +233,64 @@ class CustomPolicy(Policy):
- CLASS `Layer`:

Parameters:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- suffix: (str): the suffix of the layer to indicate the attribute of the layer.
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model
- reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in], but in GPT2 `Conv1D` layer is [in, out] which is reversed.
- n_cast (int): The number of weight will cast to, like q, k, v in attention layer, n_cast should be 3. commonly in TP, we just chunk the weight with the number of devices, but in multi-head attention, we need to chunk the weight with the number of $ devices * n\_head $, and each device should have a part of Q, K and V weight.

This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.
This class is a base class used to specify the replacement policy and the suffix the layer for a particular layer.

CLASS `Col_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- gather_output (bool): Whether the output of this layer can be gathered, like the last layer can be gathered, but most of the time, the intermediate layers of the model do not need to be gathered.

This class inherited from `Layer`, representing the layer will be sliced along column.
This class inherited from `Layer`, representing the layer will be sliced along colum and indicate the attributes of weight and bias. Setting `bias` to `None` means ignoring bias, regardless of whether or not it originally exists.

CLASS `Row_Layer(Layer)`:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer

This class inherited from `Layer`, representing the layer will be sliced along row.
This class inherited from `Layer`, representing the layer will be sliced along row. Just like `Col_Layer` but in tensor parrallel, there is no need to gather the output of layer sliced by row.

- CLASS `Policy`:

In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
- `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`......

These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.

- `Policy.argument_policy()`

In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.

- `Policy.inject_policy()`

This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else.

- `Policy.binding_policy()`

This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters.


- CLASS `ModelSharder(model, policy)`:

This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model.

- `ModelShard.inject_model()`

This function is used to inject the model to modify the forward and backward progress.

- `ModelShard.replace_layer()`

This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication.

- `ModelShard.bind_layer()`

This function is used to help different layers share weight or bias.


- CLASS `Slicer`:

This class is used to slice tensor according to policy.
Expand Down
68 changes: 43 additions & 25 deletions colossalai/shardformer/policies/basepolicy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# part of code modified from https://github.com/tunib-ai/parallelformers

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type
from typing import Any, Callable, Dict, List, Tuple, Union

import torch.nn as nn

Expand All @@ -25,8 +25,7 @@ class Layer:
The layer object for the policy
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
suffix: (str): the suffix of the layer.
replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
ignore (bool): Whether to ignore this layer if it is not in the model
reversed (bool): Whether the weight in layer is reversed, commonly the weight in `torch.nn.Linear` is [out, in],
Expand All @@ -35,8 +34,7 @@ class Layer:
but in multi-head attention, we need to chunk the weight with the number of devices * n_head, and
each device should have a part of Q, K and V weight.
"""
weight: str = None
bias: str = None
suffix: str = None
replace_layer: Any = None
ignore: bool = False
reversed: bool = False
Expand All @@ -46,20 +44,40 @@ class Layer:
@dataclass
class Col_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Class for col shard layer in tensor parrallel
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
gather_output (bool): Whether to gather the output of the layer
"""
weight: str = None
bias: str = None
gather_output: bool = False


@dataclass
class Row_Layer(Layer):
r"""
Class for col shard layer in MegatronLM
Class for col shard layer in tensor parrallel
Args:
weight (str): The weight suffix of the layer
bias (str): The bias suffix of the layer
"""
pass
weight: str = None
bias: str = None


@dataclass
class Dropout_Layer(Layer):
r"""
Class for dropout layer in tensor parrallel
Args:
p (str): The dropout rate suffix of the layer
"""
p: str = None


class Policy():
Expand All @@ -82,14 +100,14 @@ class for the example.
"""

@staticmethod
def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument]:
def argument_policy(model_config, world_size: int) -> Dict[nn.Module, Argument]:
r"""
Return the dict for the modify policy, the key is the original layer class and the value is the
argument for the modify layer
Args:
model_config (:class:`tansformer.Config`): The config of transformer model
shard_config (:class:`ShardConfig`): The config for sharding model
world_size (int)): The world size of sharding model
Return:
Dict for the modify policy,
Expand Down Expand Up @@ -126,7 +144,7 @@ def argument_policy(model_config, shard_config: int) -> Dict[nn.Module, Argument
raise NotImplementedError

@staticmethod
def inject_policy() -> Tuple[nn.Module, nn.Module]:
def inject_policy() -> Union[Tuple[nn.Module, nn.Module], None]:
r"""
Return the dict for the inject model
Expand All @@ -139,9 +157,9 @@ def inject_policy() -> Tuple[nn.Module, nn.Module]:
return None

@staticmethod
def binding_policy() -> Dict:
def binding_policy() -> Union[Dict[str, str], None]:
r"""
Return the dict for the binding model
Return the dict for the binding model, None means no need to bind
Return:
This method should return the binding relationship for some layers share the weight or bias,
Expand All @@ -154,7 +172,7 @@ def binding_policy() -> Dict:
return None

@staticmethod
def attn_in() -> List:
def attn_in() -> Union[List, None]:
r"""
Attention qkv layer
In this kind of method, we should return the list of ``Layer`` object, each ``Layer`` object should be
Expand All @@ -164,52 +182,52 @@ def attn_in() -> List:
Returns:
List[Layer]: List of layer object, each layer is the new
"""
return NotImplementedError
return None

@staticmethod
def attn_out() -> List:
def attn_out() -> Union[List, None]:
r"""
Attention output projection layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def mlp_in() -> List:
def mlp_in() -> Union[List, None]:
r"""
h -> 4h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def mlp_out() -> List:
def mlp_out() -> Union[List, None]:
r"""
4h -> h mlp layer
Returns:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def embedding() -> List:
def embedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
Return:
List[Layer]: List of layer object
"""
return NotImplementedError
return None

@staticmethod
def unembedding() -> List:
def unembedding() -> Union[List, None]:
r"""
Partially slice the embedding layer
Partially slice the embedding layer, None means there is no unembedding layer
Return:
List[Layer]: List of layer object
Expand Down
Loading

0 comments on commit 2b5df70

Please sign in to comment.