
Commit

Merge branch 'add_dist_opt_in_clip' into 'internal/main'
Add dist opt in clip

See merge request dl/JoC/nemo_multimodal!38
Yu Yao committed May 26, 2023
2 parents 2329912 + 8e41778 commit 366c8ec
Showing 3 changed files with 118 additions and 30 deletions.
70 changes: 64 additions & 6 deletions nemo/collections/multimodal/models/clip/megatron_clip_models.py
@@ -391,6 +391,7 @@ def configure_optimizers(self):
module = self.model[0] # only the first virtual rank has the embeddings
else:
module = self.model
# TODO (yuya): text transformer's embedding needs to be taken care of when PP>1
# if module.share_token_embeddings:
# param = module.word_embeddings_weight()
# param._disable_greedy_grad_copy = not self.megatron_amp_O2
@@ -408,10 +409,48 @@ def configure_optimizers(self):
# Disable overlapped grad sync for layer norm grads when
# sequence parallelism is enabled
for param in self.parameters():
if getattr(param, 'sequence_parallel_enabled', False):
if getattr(param, 'sequence_parallel', False):
param._disable_greedy_grad_copy = not self.megatron_amp_O2
param._disable_overlap_grad_sync = True

# Initialize parameter buckets for overlapped grad and param syncs
# Note: Params with disabled overlapping are put in the
# last param bucket
buckets = []
if self.cfg.get('virtual_pipeline_model_parallel_size', None) is not None:
# Initialize a bucket for each virtual pipeline stage
for module in self.model:
if isinstance(module, Float16Module):
module = module.module
stage_bucket = []
for layer in itertools.chain(
module.vision_encoder.backbone.transformer.layers,
module.text_encoder.language_model.encoder.layers,
):
stage_bucket.extend(
p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
)
buckets.append(stage_bucket)
else:
# Initialize a bucket for each Transformer layer
modules = self.model if isinstance(self.model, list) else [self.model]
for module in modules:
if isinstance(module, Float16Module):
module = module.module
for layer in itertools.chain(
module.vision_encoder.backbone.transformer.layers,
module.text_encoder.language_model.encoder.layers,
):
buckets.append(
[p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)]
)
buckets.reverse()
used_params = set()
for bucket in buckets:
used_params.update(bucket)
buckets[-1].extend(p for p in self.parameters() if p not in used_params)
self.distributed_adam_buckets = buckets

return super().configure_optimizers()

def forward(self, image, text):
@@ -431,6 +470,24 @@ def training_step(self, dataloader_iter, batch_idx):
# we zero grads here because we also call backward in the megatron-core fwd/bwd functions
self._optimizer.zero_grad()

if self.with_distributed_adam:
# hack to enable overlapping param sync and forward compute
# note: the distributed optimizer monkey-patches each
# parameter's __getattribute__ function so that it can
# launch parameter all-gathers the first time the
# parameter is accessed after the optimizer step. However,
            # PyTorch directly passes embedding parameters into a C++ function,
# bypassing this process. A quick-and-dirty hack is to
# manually interact with the parameter.
modules = self.model if isinstance(self.model, list) else [self.model]
for module in modules:
if isinstance(module, Float16Module):
module = module.module
module = module.text_encoder.language_model
if hasattr(module, 'embedding'):
for param in module.embedding.parameters():
param.data_ptr()

# TODO (yuya): fix this shape
tensor_shape = None

@@ -465,20 +522,21 @@ def training_step(self, dataloader_iter, batch_idx):
self.allreduce_sequence_parallel_gradients()

if self.with_distributed_adam:
# gradients are reduced internally in distributed optimizer
pass
# synchronize asynchronous grad reductions
# note: not necessary, but reduces performance degradation
# from multiple simultaneous NCCL calls
self._optimizer._finish_bucket_grad_sync()
elif self.megatron_amp_O2:
# # when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
# when using pipeline parallelism grads must be all-reduced after the pipeline (not asynchronously)
# if self.cfg.get('pipeline_model_parallel_size', 1) > 1 or self.cfg.get('sequence_parallel', False):
# # main grads are stored in the MainParamsOptimizer wrapper
# self._optimizer.allreduce_main_grads()
self._optimizer.allreduce_main_grads()
else:
# async grad allreduce is not currently implemented for O1/autocasting mixed precision training
# so we all-reduce gradients after the pipeline
self.allreduce_gradients() # @sangkug we think this is causing memory to blow up (hurts perf)

# TODO (yuya): check if this is needed in text transformer
# TODO (yuya): check if this is needed in text transformer when PP>1
# if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
# # when using pipeline parallelism the first and last stage must keep embeddings in sync
# self.allreduce_first_last_embeddings()
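The param.data_ptr() calls added to training_step above lean on the distributed optimizer's lazy all-gather hook: the first Python attribute access on a parameter after an optimizer step launches that parameter's all-gather, but the embedding weight normally reaches the C++ embedding kernel without any such access. A minimal, self-contained sketch of the idea (a toy stand-in class, not NeMo or Apex code):

# Toy stand-in (illustrative only): mimics an optimizer that patches parameter
# attribute access so the first touch after a step starts the all-gather.
class LazySyncedParam:
    def __init__(self, start_all_gather):
        self._synced = False
        self._start_all_gather = start_all_gather

    def __getattribute__(self, name):
        # First attribute access after an optimizer step kicks off the
        # parameter all-gather; later accesses are free.
        if name not in ('_synced', '_start_all_gather'):
            if not object.__getattribute__(self, '_synced'):
                object.__setattr__(self, '_synced', True)
                object.__getattribute__(self, '_start_all_gather')()
        return object.__getattribute__(self, name)

    def data_ptr(self):
        return id(self)  # placeholder for torch.Tensor.data_ptr()

param = LazySyncedParam(lambda: print('param all-gather launched'))
param.data_ptr()  # first touch triggers the sync, as in the hack above
param.data_ptr()  # later touches do nothing extra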
74 changes: 52 additions & 22 deletions nemo/collections/multimodal/models/multimodal_base_model.py
@@ -441,10 +441,14 @@ def reduce_overlap_gradients(self):
stages, the grad sync is deferred until the bubble overhead.
"""
if self.with_distributed_adam and self._optimizer.overlap_grad_sync:
if params is None:
params = self._optimizer.parameters()
self._optimizer.try_grad_sync(params)

def sync_overlap_parameters(self, params=None):
if self.with_distributed_adam:
self._optimizer.try_grad_sync(
p for p in self._optimizer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
)
self._optimizer._try_start_bucket_param_sync(params)

def on_train_batch_end(self, outputs, dataloader_iter: Any, batch_idx: int, unused: Optional[int] = 0) -> None:
super().on_train_batch_end(outputs, dataloader_iter, batch_idx)
@@ -489,23 +493,33 @@ def setup_optimization(
optim_kwargs = {} if optim_kwargs is None else optim_kwargs.copy()
if self.with_distributed_adam:

# Allocate grads since we are storing between microbatches
# Allocate contiguous buffers to avoid extra copies
optim_kwargs['contiguous_grad_buffer'] = True
optim_kwargs['contiguous_param_buffer'] = True

if self.megatron_amp_O2:
# Match param allgather with model dtype
if hasattr(self, 'autocast_dtype'):
optim_kwargs['param_sync_dtype'] = self.autocast_dtype
if self.autocast_dtype == torch.float:
optim_kwargs['store_params'] = False
elif self.autocast_dtype == torch.float16:
optim_kwargs['store_params'] = True
elif self.autocast_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
# Assume FP32 params, so no need to store main params
# Make sure optimizer state is in FP32
optim_dtype = torch.float32
optim_kwargs['dtype'] = optim_dtype

# Make sure embedding grad reductions are in FP32
for name, param in self.named_parameters():
if 'word_embedding' in name or 'position_embedding' in name:
param._with_fp32_optimizer = True

# Match param allgather with model dtype
model_dtype = torch.float32
if self.megatron_amp_O2 and hasattr(self, 'autocast_dtype'):
model_dtype = self.autocast_dtype
optim_kwargs['param_sync_dtype'] = model_dtype

# Determine whether to store master params in optimizer
if optim_dtype == model_dtype:
optim_kwargs['store_params'] = False
elif optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
optim_kwargs['store_params'] = False
optim_kwargs['store_param_remainders'] = True
else:
optim_kwargs['store_params'] = True

return super().setup_optimization(optim_config=optim_config, optim_kwargs=optim_kwargs)
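
The store_params / store_param_remainders selection above reduces to a small dtype mapping. Here is an illustrative standalone helper (not repository code) that mirrors it; the BF16 branch works because a bfloat16 value matches the upper 16 bits of its FP32 master, so only the lower 16 "remainder" bits need to be kept.

import torch

def master_param_kwargs(model_dtype, optim_dtype=torch.float32):
    # Mirrors the dtype logic above: decide how the distributed optimizer
    # should keep master copies of the parameters.
    if optim_dtype == model_dtype:
        # Params already live in the optimizer dtype; no master copy needed.
        return {'store_params': False}
    if optim_dtype == torch.float32 and model_dtype == torch.bfloat16:
        # BF16 is the upper half of the FP32 bit pattern, so storing only the
        # lower 16 "remainder" bits reconstructs FP32 masters cheaply.
        return {'store_params': False, 'store_param_remainders': True}
    # e.g. FP16 params with FP32 optimizer state: keep full FP32 masters.
    return {'store_params': True}

print(master_param_kwargs(torch.bfloat16))  # store only the remainder bits
print(master_param_kwargs(torch.float16))   # full FP32 master copies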

@@ -562,12 +576,28 @@ def configure_optimizers(self):

# Configure distributed optimizer
if self.with_distributed_adam:
# Initialize params so that main grads are available

# Initialize param buckets if explicitly provided
if hasattr(self, 'distributed_adam_buckets'):
for bucket in self.distributed_adam_buckets:
self._optimizer.init_params_bucket(bucket)
del self.distributed_adam_buckets

# Make sure all params are initialized so main grads are
# available
# Note: Consolidate grads without overlap
self._optimizer.init_params(
p for p in self.parameters() if getattr(p, '_disable_overlap_grad_sync', False)
)
self._optimizer.init_params(self.parameters())
overlap_params = []
no_overlap_params = []
for p in self.parameters():
if getattr(p, '_disable_overlap_grad_sync', False):
no_overlap_params.append(p)
else:
overlap_params.append(p)
self._optimizer.init_params(reversed(overlap_params))
self._optimizer.init_params(reversed(no_overlap_params))

# Initialize contiguous parameter buffer
self._optimizer.init_param_buffer()

if self._scheduler is None:
return self._optimizer
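Putting the two files together: the CLIP model above builds one bucket per Transformer layer (vision and text encoders), reverses the list so bucket order matches the backward pass, parks opt-out params in the last bucket, and stashes the result on self.distributed_adam_buckets; this base class then registers each bucket with the optimizer via init_params_bucket before initializing the remaining params and the contiguous param buffer. A toy, self-contained version of the bucketing step (plain nn.Linear layers standing in for the encoders; not NeMo code):

import torch.nn as nn

def build_layer_buckets(layers):
    # One bucket per layer, keeping only params that allow overlapped grad sync.
    buckets = [
        [p for p in layer.parameters()
         if not getattr(p, '_disable_overlap_grad_sync', False)]
        for layer in layers
    ]
    # Backward pass reaches the last layers first, so reverse the bucket order;
    # params that opted out of overlapping ride along in the final bucket.
    buckets.reverse()
    used = {p for bucket in buckets for p in bucket}
    all_params = [p for layer in layers for p in layer.parameters()]
    buckets[-1].extend(p for p in all_params if p not in used)
    return buckets

layers = [nn.Linear(8, 8), nn.Linear(8, 8), nn.Linear(8, 2)]
layers[0].weight._disable_overlap_grad_sync = True  # e.g. a sequence-parallel param
buckets = build_layer_buckets(layers)
print([len(b) for b in buckets])  # [2, 2, 2]; the opted-out weight sits in the last bucket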
4 changes: 2 additions & 2 deletions
@@ -246,7 +246,7 @@ def configure_optimizers(self):
if isinstance(module, Float16Module):
module = module.module
stage_bucket = []
#for layer in module.language_model.encoder.layers:
# for layer in module.language_model.encoder.layers:
for layer in module.backbone.transformer.layers:
stage_bucket.extend(
p for p in layer.parameters() if not getattr(p, '_disable_overlap_grad_sync', False)
@@ -258,7 +258,7 @@ def configure_optimizers(self):
for module in modules:
if isinstance(module, Float16Module):
module = module.module
#for layer in module.language_model.encoder.layers:
# for layer in module.language_model.encoder.layers:
for layer in module.backbone.transformer.layers:

buckets.append(
