)

from ._operation import gather_forward_split_backward, reduce_forward
-from .parallel_module import ParallelModule
+from .parallel_module import PaddingParallelModule, ParallelModule
from .utils import create_randomizer_with_offset

-__all__ = ["Embedding1D", "VocabParallelEmbedding1D"]
+__all__ = ["Embedding1D", "VocabParallelEmbedding1D", "PaddingEmbedding"]


class Embedding1D(ParallelModule):
@@ -161,7 +161,80 @@ def forward(self, input_: Tensor) -> Tensor:
        return output_parallel


-class VocabParallelEmbedding1D(ParallelModule):
+class PaddingEmbedding(PaddingParallelModule):
+    def __init__(
+        self,
+        num_embeddings: int,
+        embedding_dim: int,
+        padding_idx: int = None,
+        dtype: torch.dtype = None,
+        device: torch.device = None,
+        weight: Optional[nn.Parameter] = None,
+        make_vocab_size_divisible_by: int = 64,
+        *args,
+        **kwargs,
+    ):
+        self.num_embeddings = num_embeddings
+        self.embedding_dim = embedding_dim
+        self.embed_args = args
+        self.embed_kwargs = kwargs
+        self.padding_idx = padding_idx
+        if num_embeddings % make_vocab_size_divisible_by != 0:
+            self.num_embeddings = (
+                num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
+            )
+        # create weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        if weight is None:
+            self.reset_parameters()
+
+    def reset_parameters(self) -> None:
+        init.normal_(self.weight)
+        self._fill_padding_idx_with_zero()
+
+    def _fill_padding_idx_with_zero(self) -> None:
+        if self.padding_idx is not None:
+            with torch.no_grad():
+                self.weight[self.padding_idx].fill_(0)
+
+    def forward(self, input: Tensor) -> Tensor:
+        return F.embedding(input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs)
+
+    @staticmethod
+    def from_native_module(
+        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
+    ) -> PaddingParallelModule:
+        r"""
+        Convert a native pytorch embedding module to a parallel module.
+        """
+        LazyInitContext.materialize(module)
+        # get the origin attributes
+        num_embeddings = module.num_embeddings
+        embedding_dim = module.embedding_dim
+        padding_idx = module.padding_idx
+        device = module.weight.device
+        # create the parallel module
+        padding_embedding = PaddingEmbedding(
+            num_embeddings=num_embeddings,
+            embedding_dim=embedding_dim,
+            padding_idx=padding_idx,
+            device=device,
+            weight=module.weight,
+            *args,
+            **kwargs,
+        )
+
+        return padding_embedding
+
+
+class VocabParallelEmbedding1D(PaddingParallelModule):
    r"""Embedding parallelized in the vocabulary dimension.

    Args:
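
For reference, the round-up that PaddingEmbedding applies to num_embeddings can be checked in isolation. A minimal sketch of the same arithmetic (the 50257-entry GPT-2 vocabulary is only an illustrative assumption, not part of this diff):

    def pad_vocab_size(num_embeddings: int, make_vocab_size_divisible_by: int = 64) -> int:
        # Same round-up as in PaddingEmbedding.__init__: grow the vocabulary
        # to the next multiple of make_vocab_size_divisible_by.
        if num_embeddings % make_vocab_size_divisible_by != 0:
            return num_embeddings + make_vocab_size_divisible_by - (num_embeddings % make_vocab_size_divisible_by)
        return num_embeddings

    assert pad_vocab_size(50257) == 50304  # rounded up to the next multiple of 64
    assert pad_vocab_size(50304) == 50304  # already aligned, left unchanged

The padded value is what gets passed to PaddingParallelModule as the new vocabulary size, while the original num_embeddings is passed as the second argument, presumably so the un-padded size remains recoverable.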
@@ -201,10 +274,10 @@ def __init__(
        process_group: ProcessGroup = None,
        weight: Optional[nn.Parameter] = None,
        weight_initializer: Callable = init.normal_(),
+        make_vocab_size_divisible_by: int = 64,
        *args,
        **kwargs,
    ):
-        super().__init__()
        self.num_embeddings = num_embeddings
        self.embedding_dim = embedding_dim
        self.embed_args = args
@@ -214,8 +287,23 @@ def __init__(
        tensor_parallel_size = dist.get_world_size(group=process_group)
        tensor_parallel_rank = dist.get_rank(group=process_group)

-        self.num_embeddings_per_partition = divide(num_embeddings, tensor_parallel_size)
-        self.num_embeddings = self.num_embeddings_per_partition
+        # generate weight and bias
+        if weight is None:
+            factory_kwargs = {"device": device, "dtype": dtype}
+            weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
+        else:
+            weight.data = weight.data.to(device=device, dtype=dtype)
+
+        # calculate new padding size
+        multiple = make_vocab_size_divisible_by * tensor_parallel_size
+        if num_embeddings % multiple != 0:
+            self.num_embeddings = num_embeddings + multiple - (num_embeddings % multiple)
+
+        # resize vocabulary size
+        super().__init__(self.num_embeddings, num_embeddings, weight)
+
+        # deal with tensor parallelism
+        self.num_embeddings_per_partition = divide(self.num_embeddings, tensor_parallel_size)
        self.vocab_start_index = tensor_parallel_rank * self.num_embeddings_per_partition
        self.vocab_end_index = self.vocab_start_index + self.num_embeddings_per_partition

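To make the new partitioning concrete, here is a small sketch of the same arithmetic outside the module (the world size and vocabulary size are illustrative assumptions):

    # Hypothetical numbers: a 50257-token vocabulary on a 4-way tensor-parallel group.
    num_embeddings = 50257
    make_vocab_size_divisible_by = 64
    tensor_parallel_size = 4

    # Padding is computed against make_vocab_size_divisible_by * tp_size,
    # so every rank receives an equal, 64-aligned shard.
    multiple = make_vocab_size_divisible_by * tensor_parallel_size    # 256
    padded = num_embeddings + multiple - (num_embeddings % multiple)  # 50432 (round-up branch applies: 50257 % 256 != 0)
    per_partition = padded // tensor_parallel_size                    # 12608

    for rank in range(tensor_parallel_size):
        vocab_start_index = rank * per_partition
        vocab_end_index = vocab_start_index + per_partition
        print(rank, vocab_start_index, vocab_end_index)  # rank 0 owns ids [0, 12608), rank 1 [12608, 25216), ...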
@@ -226,13 +314,6 @@ def __init__(
        seed = torch.random.initial_seed()
        self.randomizer = create_randomizer_with_offset(seed, process_group=self.process_group)

-        # parameter
-        if weight is None:
-            factory_kwargs = {"device": device, "dtype": dtype}
-            self.weight = nn.Parameter(torch.empty((num_embeddings, self.embedding_dim), **factory_kwargs))
-        else:
-            weight.data = weight.data.to(device=device, dtype=dtype)
-            self.weight = weight
        if not is_distributed_tensor(self.weight):
            sharded_weight = shard_rowwise(self.weight.data, process_group)
            sharded_tensor_to_existing_param(sharded_weight, self.weight)
@@ -243,7 +324,7 @@ def __init__(
    @staticmethod
    def from_native_module(
        module: nn.Embedding, process_group: Union[ProcessGroup, List[ProcessGroup]], *args, **kwargs
-    ) -> ParallelModule:
+    ) -> PaddingParallelModule:
        r"""
        Convert a native pytorch embedding module to a parallel module.
        """
@@ -303,11 +384,9 @@ def forward(self, input_: Tensor) -> Tensor:
        # Mask the input.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
-
        output_parallel = F.embedding(
            masked_input, self.weight, self.padding_idx, *self.embed_args, **self.embed_kwargs
        )
-
        # Mask the output embedding.
        embedding_output = output_parallel.clone()
        embedding_output[input_mask, :] = 0.0
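
For intuition, the masking around F.embedding can be traced with concrete numbers. A standalone sketch, using the illustrative shard boundaries from the 4-way example above (not values taken from this diff):

    import torch
    import torch.nn.functional as F

    # Suppose this rank owns vocabulary ids [12608, 25216).
    vocab_start_index, vocab_end_index = 12608, 25216
    weight = torch.zeros(12608, 8)  # this rank's shard of the embedding table

    input_ = torch.tensor([5, 13000, 30000])
    input_mask = (input_ < vocab_start_index) | (input_ >= vocab_end_index)

    # Shift into the local index space; ids this rank does not own are clamped to 0.
    masked_input = input_.clone() - vocab_start_index
    masked_input[input_mask] = 0

    output_parallel = F.embedding(masked_input, weight)

    # Zero the rows for out-of-range ids so the subsequent reduction across the
    # tensor-parallel group (reduce_forward) sums to the correct embeddings.
    embedding_output = output_parallel.clone()
    embedding_output[input_mask, :] = 0.0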