From 544108dcba4e87f36f6eb2dca7e4bf5e08c62a34 Mon Sep 17 00:00:00 2001
From: Jake Tae <>
Date: Wed, 24 Nov 2021 01:26:26 -0500
Subject: [PATCH 1/6] refactor: compute model param count once

---
 megatron/training.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index bc4223dc4..7f232f6b9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -134,8 +134,10 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    print_rank_0(f'estimated model parameters: {get_parameters_in_billions(model)}')
-    print_rank_0(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
+    args.parameters_in_billions = get_parameters_in_billions(model)
+    args.parameters_in_billions_no_embedding = get_parameters_in_billions(model, exclude_embeddings=True)
+    print_rank_0(f'estimated model parameters: {args.parameters_in_billions}')
+    print_rank_0(f'estimated model parameters without embeddings: {args.parameters_in_billions_no_embedding}')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -740,7 +742,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             tp_rank = mpu.get_tensor_model_parallel_rank()
             pp_rank = mpu.get_pipeline_model_parallel_rank()
             preamble = f"[{tp_rank:0>3d}-{pp_rank:0>3d}]"
-            print(f"{preamble} {get_parameters_in_billions(model):.4f}B / {get_parameters_in_billions(model, exclude_embeddings=True):.4f}B")
+            print(f"{preamble} {args.parameters_in_billions:.4f}B / {args.parameters_in_billions_no_embedding:.4f}B")
             torch.distributed.barrier()
         else:
             torch.distributed.barrier()
@@ -815,7 +817,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             args.consumed_train_tokens += new_samples * args.curriculum_seqlen
         else:
            args.consumed_train_tokens += new_samples * args.seq_length
-        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))
+        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * args.parameters_in_billions_no_embedding)
 
         # Logging.
         if args.deepspeed:
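
Note (not part of the patch series): the gigaflos_no_embeds update in the last hunk relies on the common 6*N*D estimate, roughly 6 FLOPs per non-embedding parameter per token for the forward plus backward pass of a dense transformer. A minimal sketch of that arithmetic, with an assumed helper name and illustrative numbers:

    def gigaflos_increment(new_samples, seq_length, params_in_billions_no_embedding):
        # ~6 FLOPs per parameter per token (forward + backward); the parameter
        # count is already in billions, so the result is directly in gigaFLOPs
        tokens = new_samples * seq_length
        return 6 * tokens * params_in_billions_no_embedding

    # e.g. a 1.3B non-embedding parameter model consuming 512 samples of length 2048
    print(gigaflos_increment(new_samples=512, seq_length=2048, params_in_billions_no_embedding=1.3))
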
From 816b867023c65c6676d98cabd556847f170a916f Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 24 Nov 2021 18:03:42 -0500
Subject: [PATCH 2/6] Update megatron/training.py

Co-authored-by: Stas Bekman
---
 megatron/training.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/training.py b/megatron/training.py
index 7f232f6b9..9d9b36287 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -136,8 +136,8 @@ def pretrain(train_valid_test_dataset_provider,
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
     args.parameters_in_billions = get_parameters_in_billions(model)
     args.parameters_in_billions_no_embedding = get_parameters_in_billions(model, exclude_embeddings=True)
-    print_rank_0(f'estimated model parameters: {args.parameters_in_billions}')
-    print_rank_0(f'estimated model parameters without embeddings: {args.parameters_in_billions_no_embedding}')
+    print_rank_0(f'estimated model parameters: {args.parameters_in_billions:.4f}B')
+    print_rank_0(f'estimated model parameters without embeddings: {args.parameters_in_billions_no_embedding:.4f}B')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')

From c2d6390322f20b24d33cfaba517fc7230c421e81 Mon Sep 17 00:00:00 2001
From: Stas Bekman
Date: Thu, 25 Nov 2021 09:48:32 -0800
Subject: [PATCH 3/6] Update megatron/training.py

Co-authored-by: Jake Tae
---
 megatron/training.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/megatron/training.py b/megatron/training.py
index 9d9b36287..ec63560d9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -742,7 +742,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             tp_rank = mpu.get_tensor_model_parallel_rank()
             pp_rank = mpu.get_pipeline_model_parallel_rank()
             preamble = f"[{tp_rank:0>3d}-{pp_rank:0>3d}]"
-            print(f"{preamble} {args.parameters_in_billions:.4f}B / {args.parameters_in_billions_no_embedding:.4f}B")
+            print(f"{preamble} {get_parameters_in_billions(model):.4f}B / {get_parameters_in_billions(model, exclude_embeddings=True):.4f}B")
             torch.distributed.barrier()
         else:
             torch.distributed.barrier()

From f4c7c67e7cb8be483fd3269509c7218d22ac2bcb Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 1 Dec 2021 06:03:34 +0000
Subject: [PATCH 4/6] fix: use deepspeed param count method

---
 megatron/utils.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index f234d1650..f11a7ff16 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -265,16 +265,28 @@ def non_embedding_params(module):
     return unique_param_count(non_embedding_parameters)
 
 
+def _param_count_from_deepspeed(model):
+    # from https://github.com/microsoft/DeepSpeed/blob/7a132a9f4b37959f951b7c04a05207aba6054965/deepspeed/runtime/pipe/engine.py#L134-L157
+
+    model_parameters = filter(lambda p: p.requires_grad, model.module.parameters())
+    num_params = sum([p.numel() for p in model_parameters])
+    unique_params = num_params
+    # Subtract tied parameters if we don't own them
+    if model.module.tied_comms:
+        tied_params = 0
+        for key, d in model.module.tied_comms.items():
+            if model.global_rank != min(d['ranks']):
+                tied_params += sum(p.numel() for p in d['module'].parameters())
+        unique_params -= tied_params
+    return unique_params
+
 def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())
 
     if exclude_embeddings:
         approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
     else:
-        args = get_args()
-        if args.rank == 0:
-            warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")
-        approx_parameters_in_billions = unique_param_count([p for model_module in model for p in model_module.parameters()])
+        approx_parameters_in_billions = sum([_param_count_from_deepspeed(model_module) for model_module in model])
 
     return approx_parameters_in_billions*gpus_per_model/(1e9)
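
Note (not part of the patch series): the tied_comms subtraction in _param_count_from_deepspeed above prevents double counting of weights that are replicated across pipeline stages (for example, input and output embeddings tied between the first and last stage); only the lowest rank in each tie group keeps the count. A standalone sketch of that bookkeeping with hypothetical data structures, not the DeepSpeed API:

    def unique_param_count_for_rank(local_param_count, tied_groups, global_rank):
        # every rank except the owner (the lowest rank in the group) subtracts
        # the tied parameters from its local total
        for group in tied_groups.values():
            if global_rank != min(group["ranks"]):
                local_param_count -= group["numel"]
        return local_param_count

    # a 250k-parameter embedding tied between pipeline ranks 0 and 3
    tied = {"embed": {"ranks": [0, 3], "numel": 250_000}}
    print(unique_param_count_for_rank(1_000_000, tied, global_rank=0))  # owner keeps it: 1000000
    print(unique_param_count_for_rank(1_000_000, tied, global_rank=3))  # replica removed: 750000
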
From a7b10b7cfc28971eb37ff956ebbd921a618be531 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 1 Dec 2021 07:07:56 +0000
Subject: [PATCH 5/6] refactor: replace filter w/ list comp, generator to list

---
 megatron/utils.py | 10 ++++------
 1 file changed, 4 insertions(+), 6 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index f11a7ff16..1a3065228 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -268,17 +268,15 @@ def non_embedding_params(module):
 def _param_count_from_deepspeed(model):
     # from https://github.com/microsoft/DeepSpeed/blob/7a132a9f4b37959f951b7c04a05207aba6054965/deepspeed/runtime/pipe/engine.py#L134-L157
 
-    model_parameters = filter(lambda p: p.requires_grad, model.module.parameters())
-    num_params = sum([p.numel() for p in model_parameters])
-    unique_params = num_params
+    num_params = sum([p.numel() for p in model.module.parameters() if p.requires_grad])
     # Subtract tied parameters if we don't own them
     if model.module.tied_comms:
         tied_params = 0
         for key, d in model.module.tied_comms.items():
             if model.global_rank != min(d['ranks']):
-                tied_params += sum(p.numel() for p in d['module'].parameters())
-        unique_params -= tied_params
-    return unique_params
+                tied_params += sum([p.numel() for p in d['module'].parameters()])
+        num_params -= tied_params
+    return num_params
 
 def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())

From ac3e138be3f5e4845f5098f04e9b5fa103ad7b17 Mon Sep 17 00:00:00 2001
From: Jake Tae
Date: Wed, 1 Dec 2021 07:11:13 +0000
Subject: [PATCH 6/6] refactor: use set for constant time lookup

---
 megatron/utils.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/megatron/utils.py b/megatron/utils.py
index 1a3065228..1dc5895c3 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -256,9 +256,9 @@ def unique_param_count(param_list):
 
 
 def non_embedding_params(module):
-    embedding_param_names = [
+    embedding_param_names = set([
         f"{name}.weight" for name, module_type in module.named_modules() if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
-    ]
+    ])
     non_embedding_parameters = [
         parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
     ]
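
Note (not part of the patch series): the set([...]) change in patch 6 matters because every entry returned by named_parameters() is checked against embedding_param_names, and set membership is O(1) versus O(n) for a list. A minimal sketch of the same name-based embedding filter on a plain PyTorch module (toy model, not Megatron code):

    import torch.nn as nn

    model = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 4))

    # collect the names of embedding weights, then drop those parameters by name
    embedding_param_names = {
        f"{name}.weight"
        for name, m in model.named_modules()
        if isinstance(m, nn.Embedding)
    }
    non_embedding = [
        p for name, p in model.named_parameters() if name not in embedding_param_names
    ]

    # the embedding weight (100*16) is excluded; the Linear weight and bias remain
    print(sum(p.numel() for p in non_embedding))  # 68
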