[shardformer] supported T5 and its variants (hpcaitech#4045)
FrankLeeeee authored and FoolPlayer committed Jun 21, 2023
1 parent 36936c2 commit 74aaead
Showing 10 changed files with 316 additions and 221 deletions.
5 changes: 2 additions & 3 deletions colossalai/shardformer/README.md
@@ -81,16 +81,15 @@ We will follow this roadmap to develop Shardformer:
 - [ ] Hugging Face
   - [ ] NLP
     - [x] BERT
-    - [ ] T5
-    - [ ] LlaMa
+    - [x] T5
+    - [x] LlaMa
     - [ ] GPT2
     - [ ] BLOOM
     - [ ] RoBERTa
     - [ ] ALBERT
     - [ ] ERNIE
     - [ ] GPT Neo
     - [ ] GPT-J
-  - [ ] CV
   - [ ] CV
     - [ ] ViT
     - [ ] BEiT
26 changes: 17 additions & 9 deletions colossalai/shardformer/layer/layers.py
@@ -469,13 +469,14 @@ def __init__(self,
                  dtype: torch.dtype = None,
                  device: torch.device = None,
                  process_group: ProcessGroup = None,
+                 gather_output: bool = True,
                  weight_initializer: Callable = init.normal_(),
                  *args,
                  **kwargs):
         super().__init__()
 
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.process_group = process_group
         self.num_partitions = dist.get_world_size(process_group)
         self.embed_dim_per_partition = divide(embedding_dim, self.num_partitions)
@@ -499,7 +500,9 @@ def __init__(self,
 
     @staticmethod
     def from_native_module(module: nn.Embedding,
-                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None) -> "Embedding1D":
+                           process_group: Union[ProcessGroup, List[ProcessGroup]] = None,
+                           *args,
+                           **kwargs) -> "Embedding1D":
         r"""
         Build a 1D parallelized Embedding from a native nn.Embedding module.
         """
@@ -527,7 +530,9 @@ def from_native_module(module: nn.Embedding,
                                 max_norm=max_norm,
                                 norm_type=norm_type,
                                 scale_grad_by_freq=scale_grad_by_freq,
-                                sparse=sparse)
+                                sparse=sparse,
+                                *args,
+                                **kwargs)
 
         # copy the weight
         with torch.no_grad():
@@ -537,7 +542,7 @@ def from_native_module(module: nn.Embedding,
         return embedding
 
     def reset_parameters(self, weight_initializer) -> None:
-        fan_in, fan_out = self.num_embeddings, self.embed_dim
+        fan_in, fan_out = self.num_embeddings, self.embedding_dim
         weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
         self._fill_padding_idx_with_zero()
 
@@ -548,9 +553,12 @@ def _fill_padding_idx_with_zero(self) -> None:
 
     def forward(self, input_: Tensor) -> Tensor:
         output_parallel = F.embedding(input_, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
-        output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
 
-        return output
+        if self.gather_output:
+            output = gather_forward_split_backward(output_parallel, dim=-1, process_group=self.process_group)
+            return output
+        else:
+            return output_parallel
 
 
 class VocabParallelEmbedding1D(ParallelLayer):
@@ -595,7 +603,7 @@ def __init__(self,
                  **kwargs):
         super().__init__()
         self.num_embeddings = num_embeddings
-        self.embed_dim = embedding_dim
+        self.embedding_dim = embedding_dim
         self.padding_idx = padding_idx
         self.embed_args = args
         self.embed_kwargs = kwargs
@@ -610,7 +618,7 @@ def __init__(self,
         self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition
 
         self.weight = Parameter(
-            torch.empty((self.num_embeddings_per_partition, self.embed_dim), device=device, dtype=dtype))
+            torch.empty((self.num_embeddings_per_partition, self.embedding_dim), device=device, dtype=dtype))
 
         # offset the seed with randomizer index and rank
         seed = torch.random.initial_seed()
@@ -662,7 +670,7 @@ def _set_tensor_parallel_attributes(self):
 
     def reset_parameters(self, weight_initializer) -> None:
         with seed(ParallelMode.TENSOR):
-            fan_in, fan_out = self.num_embeddings, self.embed_dim
+            fan_in, fan_out = self.num_embeddings, self.embedding_dim
             weight_initializer(self.weight, fan_in=fan_in, fan_out=fan_out)
             self._fill_padding_idx_with_zero()
 
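The functional change in layers.py, beyond the embed_dim to embedding_dim rename, is the new gather_output switch on Embedding1D: when it is False, forward() returns this rank's slice of the embedding dimension instead of all-gathering the full output. Below is a minimal usage sketch rather than code from the repository; the process group, vocabulary size, and embedding size are placeholder assumptions, and a distributed context is presumed to be initialized already.

import torch.nn as nn
import torch.distributed as dist

from colossalai.shardformer.layer.layers import Embedding1D

# Assumption: torch.distributed (e.g. via colossalai.launch) is already initialized.
pg = dist.group.WORLD    # placeholder: any ProcessGroup spanning the tensor-parallel ranks

native = nn.Embedding(32000, 512)    # placeholder vocabulary and embedding sizes

# gather_output=True (the default): forward() all-gathers the per-rank slices,
# so every rank sees the full 512-dimensional embedding output.
emb_gathered = Embedding1D.from_native_module(native, process_group=pg, gather_output=True)

# gather_output=False: forward() returns only this rank's partition
# (512 // world_size columns); a downstream parallel layer is expected to consume it.
emb_partial = Embedding1D.from_native_module(native, process_group=pg, gather_output=False)

Passing gather_output through from_native_module relies on the new *args/**kwargs forwarding added to that method in this commit.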
6 changes: 6 additions & 0 deletions colossalai/shardformer/policies/autopolicy.py
@@ -48,6 +48,12 @@ class PolicyLocation:
         PolicyLocation(file_name="llama", class_name="LlamaForSequenceClassificationPolicy"),
 
     # T5
+    "transformers.models.t5.modeling_t5.T5Model":
+        PolicyLocation(file_name="t5", class_name="T5ModelPolicy"),
+    "transformers.models.t5.modeling_t5.T5ForConditionalGeneration":
+        PolicyLocation(file_name="t5", class_name="T5ForConditionalGenerationPolicy"),
+    "transformers.models.t5.modeling_t5.T5EncoderModel":
+        PolicyLocation(file_name="t5", class_name="T5EncoderPolicy"),
 
     # GPT2
 }
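With these entries, transformers' T5Model, T5ForConditionalGeneration, and T5EncoderModel all resolve to policies defined in the t5 policy file. The helper below is a hypothetical sketch of how such a qualified-name registry can be turned into a policy class; it is not the code in autopolicy.py, and the colossalai.shardformer.policies.<file_name> module layout is an assumption based on the file paths in this commit.

import importlib

def resolve_policy(model_cls: type, policy_table: dict):
    # Hypothetical helper: map a model class to its registered policy class.
    qualified_name = f"{model_cls.__module__}.{model_cls.__qualname__}"
    location = policy_table.get(qualified_name)
    if location is None:
        raise ValueError(f"no shardformer policy registered for {qualified_name}")
    # Assumption: PolicyLocation(file_name="t5", class_name="T5ModelPolicy")
    # points at colossalai.shardformer.policies.t5.T5ModelPolicy.
    module = importlib.import_module(f"colossalai.shardformer.policies.{location.file_name}")
    return getattr(module, location.class_name)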
1 change: 1 addition & 0 deletions colossalai/shardformer/policies/basepolicy.py
@@ -27,6 +27,7 @@ class SubModuleReplacementDescription:
     suffix: str
     target_module: ParallelModule
     kwargs: Dict[str, Any] = None
+    ignore_if_not_exist: bool = False
 
 
 @dataclass
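The new ignore_if_not_exist flag lets a policy declare a submodule replacement that is silently skipped when the named submodule is absent, which is useful when one policy file serves several model variants. A hedged illustration follows; the suffix, target module, and kwargs are placeholder values, not taken from this commit.

from colossalai.shardformer.layer.layers import Embedding1D
from colossalai.shardformer.policies.basepolicy import SubModuleReplacementDescription

# Placeholder description: replace an assumed "decoder.embed_tokens" submodule with the
# tensor-parallel Embedding1D, and skip quietly if the model variant has no such submodule.
desc = SubModuleReplacementDescription(
    suffix="decoder.embed_tokens",      # hypothetical submodule path
    target_module=Embedding1D,          # illustrative choice of parallel layer
    kwargs={"gather_output": False},    # assumed to be forwarded to the replacement layer
    ignore_if_not_exist=True,
)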