Skip to content

Commit

Permalink
[shardformer] update readme with modules implement doc (hpcaitech#3834)
Browse files Browse the repository at this point in the history
* update readme with modules content

* remove img
  • Loading branch information
FoolPlayer authored and FrankLeeeee committed Jun 8, 2023
1 parent 537a52b commit 997544c
Showing 1 changed file with 69 additions and 0 deletions.
69 changes: 69 additions & 0 deletions colossalai/shardformer/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
- [🔨 Usage](#-usage)
- [🔮 Simple example](#-simple-example)
- [💡 Policy](#-policy)
- [😊 Module](#-module)


## 🔗 Introduction

Expand Down Expand Up @@ -188,3 +190,70 @@ CustomPolicy(Policy):
return NotImplementedError

```


## 😊 Module

1. Flowchart

<p align="center">
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/shardformer/shardformer_flowchart.png" width="600" />
</p>

2. Important Modules

- CLASS `shard_model`:

This is the user api to use shardformer, just create a model from transformers and define a custom policy or use shardformer autopolicy to make a shard model.

- CLASS `Layer`:

Parameters:
- weight (str): The weight suffix of the layer
- bias (str): The bias suffix of the layer
- replace_layer (:class:`colosalai.nn`): The layer to replace the original layer
- ignore (bool): Whether to ignore this layer if it is not in the model

This class is used to specify the replacement policy for a particular layer. If `replace_layer` is None, only parameter partitioning will be performed without replacing the layer class.

CLASS `Col_Layer(Layer)`:
- gather_output (bool): Whether to gather the output of the layer

This class inherited from `Layer`, representing the layer will be sliced along column.

CLASS `Row_Layer(Layer)`:

This class inherited from `Layer`, representing the layer will be sliced along row.

- CLASS `Policy`:

In Shardformer, this class holds significant importance as it defines the model partitioning methods, required parameter modifications, and model injection techniques all within a single Policy class.
- `Policy.attn_in()/attn_out()/mlp_in()/mlp_out()/embedding()/unembedding()`......

These functions define the partitioning methods of the parameters at different locations in the model. Each function returns a list of objects of Layer class that specify the replacement approach for these parameters. Shardformer also supports user-defined functions for modifying their models, in addition to the listed functions.
- `Policy.argument_policy()`

In this function, the user should use multiple dict to define which class of layers will require replacement. This includes the attributes and parameters that need to be modified or replaced. Attributes are stored in the form of a "suffix-string: value" dict, while parameters are stored via multiple static methods that return the replacement approach.
- `Policy.inject_policy()`

This function will return the injected model to replace the original model. The new model should be a nn.Module class which includes modified forward or backward functions or anything else.
- `Policy.binding_policy()`

This function will return the weight sharing information in the model in some dict. The key and value are both the suffixes of the shared parameters.

- CLASS `ModelSharder(model, policy)`:

This class helps shard the model, the parameter is the created transformers model and the custom policy. If custom policy is None, shardformer will automatically get already defined policy for the model.
- `ModelShard.inject_model()`

This function is used to inject the model to modify the forward and backward progress.
- `ModelShard.replace_layer()`

This function is used to replace the original layers with colossalai layer to make them paralleled and can do distributed communication.
- `ModelShard.bind_layer()`

This function is used to help different layers share weight or bias.

- CLASS `Slicer`:

This class is used to slice tensor according to policy.

0 comments on commit 997544c

Please sign in to comment.