diff --git a/megatron/training.py b/megatron/training.py
index bc4223dc4..ec63560d9 100644
--- a/megatron/training.py
+++ b/megatron/training.py
@@ -134,8 +134,10 @@ def pretrain(train_valid_test_dataset_provider,
     # Model, optimizer, and learning rate.
     timers('model-and-optimizer-setup').start()
     model, optimizer, lr_scheduler = setup_model_and_optimizer(model_provider)
-    print_rank_0(f'estimated model parameters: {get_parameters_in_billions(model)}')
-    print_rank_0(f'estimated model parameters without embeddings: {get_parameters_in_billions(model, exclude_embeddings=True)}')
+    args.parameters_in_billions = get_parameters_in_billions(model)
+    args.parameters_in_billions_no_embedding = get_parameters_in_billions(model, exclude_embeddings=True)
+    print_rank_0(f'estimated model parameters: {args.parameters_in_billions:.4f}B')
+    print_rank_0(f'estimated model parameters without embeddings: {args.parameters_in_billions_no_embedding:.4f}B')
     timers('model-and-optimizer-setup').stop()
     print_datetime('after model, optimizer, and learning rate '
                    'scheduler are built')
@@ -815,7 +817,7 @@ def train(forward_step_func, model, optimizer, lr_scheduler,
             args.consumed_train_tokens += new_samples * args.curriculum_seqlen
         else:
             args.consumed_train_tokens += new_samples * args.seq_length
-        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * get_parameters_in_billions(model, exclude_embeddings=True))
+        args.gigaflos_no_embeds += (6 * new_samples * args.seq_length * args.parameters_in_billions_no_embedding)

         # Logging.
         if args.deepspeed:
diff --git a/megatron/utils.py b/megatron/utils.py
index f234d1650..1dc5895c3 100644
--- a/megatron/utils.py
+++ b/megatron/utils.py
@@ -256,25 +256,35 @@ def unique_param_count(param_list):


 def non_embedding_params(module):
-    embedding_param_names = [
+    embedding_param_names = set([
         f"{name}.weight" for name, module_type in module.named_modules()
         if isinstance(module_type, nn.Embedding) or isinstance(module_type, VocabParallelEmbedding)
-    ]
+    ])
     non_embedding_parameters = [
         parameter for name, parameter in module.named_parameters() if name not in embedding_param_names
     ]
     return unique_param_count(non_embedding_parameters)

+def _param_count_from_deepspeed(model):
+    # from https://github.com/microsoft/DeepSpeed/blob/7a132a9f4b37959f951b7c04a05207aba6054965/deepspeed/runtime/pipe/engine.py#L134-L157
+
+    num_params = sum([p.numel() for p in model.module.parameters() if p.requires_grad])
+    # Subtract tied parameters if we don't own them
+    if model.module.tied_comms:
+        tied_params = 0
+        for key, d in model.module.tied_comms.items():
+            if model.global_rank != min(d['ranks']):
+                tied_params += sum([p.numel() for p in d['module'].parameters()])
+        num_params -= tied_params
+    return num_params
+
 def get_parameters_in_billions(model, exclude_embeddings=False):
     gpus_per_model = torch.distributed.get_world_size(group=mpu.get_model_parallel_group())

     if exclude_embeddings:
         approx_parameters_in_billions = sum([non_embedding_params(model_module) for model_module in model])
     else:
-        args = get_args()
-        if args.rank == 0:
-            warnings.warn("Parameter count with the embeddings will be inaccurate with PP > 1, as the first and last stage hold several copies of the embeddings")
-        approx_parameters_in_billions = unique_param_count([p for model_module in model for p in model_module.parameters()])
+        approx_parameters_in_billions = sum([_param_count_from_deepspeed(model_module) for model_module in model])

     return approx_parameters_in_billions*gpus_per_model/(1e9)
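
The patched train() loop reuses args.parameters_in_billions_no_embedding, computed once during setup, instead of re-counting parameters on every iteration. The quantity it accumulates follows the 6 * tokens * params estimate of forward-plus-backward compute; because the parameter count is expressed in billions, the running total is in gigaFLOPs. Below is a minimal standalone sketch of that accounting, not part of the patch; the helper name and the example numbers are hypothetical.

# Standalone sketch of the per-step FLOPs accounting that train() adds to
# args.gigaflos_no_embeds (hypothetical helper, not Megatron code).

def gigaflos_for_step(new_samples, seq_length, params_in_billions):
    """Approximate train-step compute via the 6 * tokens * params rule.

    params_in_billions is the non-embedding parameter count in billions,
    so the returned value is in gigaFLOPs.
    """
    tokens = new_samples * seq_length
    return 6 * tokens * params_in_billions


if __name__ == "__main__":
    # Hypothetical numbers: 2048 sequences of length 2048 (~4.2M tokens)
    # on a model with 100B non-embedding parameters.
    step_gigaflos = gigaflos_for_step(new_samples=2048, seq_length=2048,
                                      params_in_billions=100.0)
    print(f"~{step_gigaflos:.3e} gigaFLOPs per step "
          f"(~{step_gigaflos * 1e9:.3e} FLOPs)")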