GPT Prompt Learning Improvements (#4496)
* Updated pipeline parallel code to speed up training

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Load global batch size not local mini batch size

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Python reformatting

Signed-off-by: Virginia Adams <vadams@nvidia.com>
vadam5 authored Jul 6, 2022
1 parent 4638799 commit ad61479
Showing 2 changed files with 42 additions and 25 deletions.
File 1 of 2 (prompt-learning config, YAML):

@@ -104,7 +104,7 @@ model:
     sched:
       name: CosineAnnealing
       warmup_steps: 50
-      min_lr: 0.0 # min_lr must be 0.0 for prompt learning when pipeline parallel > 1
+      constant_steps: 0 # Constant steps should also be 0 when min_lr=0
+      min_lr: 0.0 # min_lr must be 0.0 for prompt learning
       monitor: val_loss
       reduce_on_plateau: false
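Why the config insists on these zeros: the frozen GPT weights sit in an optimizer param group whose learning rate is pinned to 0.0 (see the optimizer changes below), but a cosine-annealing schedule anneals every group's lr toward min_lr, and per the new comment a nonzero constant_steps would then hold it there. A simplified sketch of the schedule math (illustrative only, not NeMo's exact CosineAnnealing implementation):

import math

def cosine_annealing_lr(step, max_steps, base_lr, min_lr):
    # Simplified cosine annealing: decay base_lr toward min_lr over max_steps.
    progress = min(step / max_steps, 1.0)
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))

# The frozen model's param group has base_lr = 0.0. With a nonzero min_lr,
# the schedule drags that group's lr up toward min_lr, silently un-freezing it:
print(cosine_annealing_lr(1000, 1000, base_lr=0.0, min_lr=1e-5))  # 1e-05
print(cosine_annealing_lr(1000, 1000, base_lr=0.0, min_lr=0.0))   # 0.0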
File 2 of 2 (prompt-learning model, Python):

@@ -104,21 +104,20 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             override_config_path=frozen_model_cfg,
         )

-        if self.frozen_model.cfg.precision == 16:
-            self.float_type = torch.float16
-        elif self.frozen_model.cfg.precision == 'bf16':
-            self.float_type = torch.bfloat16
-        else:
-            self.float_type = torch.float

         # TODO: Enable amp_o2 training
         self.megatron_amp_o2 = False
+        self.pipeline_parallel = self.cfg.get('pipeline_model_parallel_size', 1) > 1
         self.tokenizer = self.frozen_model.tokenizer
         self.hidden_size = self.frozen_model.cfg.hidden_size
         self.existing_tasks = list(self.cfg.get('existing_tasks', []))
         self.new_tasks = list(self.cfg.get('new_tasks', []))
         self.virtual_prompt_style = VirtualPromptStyle(cfg.virtual_prompt_style)

+        if self.pipeline_parallel:
+            assert (
+                self.cfg.optim.sched.get("min_lr", 0.0) == 0.0
+            ), "Minimum lr must be 0.0 when pipeline parallel size is > 1"

         # Load templates for assigning virtual prompt token positions
         self.load_task_templates(self.cfg.task_templates)
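The deleted float_type block is the same precision-to-dtype mapping the model now reads from self.autocast_dtype in forward() (further down in this diff), presumably supplied by the Megatron base class. As a standalone sketch of that mapping:

import torch

def dtype_from_precision(precision):
    # Mirrors the removed block: trainer precision 16 -> fp16,
    # 'bf16' -> bfloat16, anything else -> full fp32.
    if precision == 16:
        return torch.float16
    elif precision == 'bf16':
        return torch.bfloat16
    return torch.float32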

@@ -348,16 +347,33 @@ def setup_optimizer_param_groups(self):
         to be passed around in pipeline parallel models. The prompt-encoder
         and/or prompt table will use the learning rate set by the user.
         """
-        virtual_prompt_params = {'params': []}
-        frozen_model_params = {'params': [param for param in self.frozen_model.parameters()], 'lr': 0.0}
+        # Freeze frozen model
+        for param in self.frozen_model.parameters():
+            param.requires_grad = False

-        if self.frozen_model.model.pre_process:
-            virtual_prompt_params['params'].extend([param for param in self.prompt_table.parameters()])
+        # Need to handle frozen model freezing differently when pp > 1
+        if self.pipeline_parallel:
+            virtual_prompt_params = {'params': []}
+            frozen_model_params = {'params': [], 'lr': 0.0}

-            if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
-                virtual_prompt_params['params'].extend([param for param in self.prompt_encoder.parameters()])
+            if self.frozen_model.model.pre_process:
+                virtual_prompt_params['params'].extend([param for param in self.prompt_table.parameters()])

-        self._optimizer_param_groups = virtual_prompt_params, frozen_model_params
+                if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
+                    virtual_prompt_params['params'].extend([param for param in self.prompt_encoder.parameters()])
+
+            # Unfreeze one part of each transformer layer setting lr to 0.0 so DDP
+            # and AMP won't complain but model still remains frozen
+            for layer in self.frozen_model.model.language_model.encoder.layers:
+                for param in layer.input_layernorm.parameters():
+                    param.requires_grad = True
+
+            frozen_model_params['params'].extend([param for param in self.frozen_model.parameters()])
+
+            self._optimizer_param_groups = virtual_prompt_params, frozen_model_params
+
+        else:
+            super().setup_optimizer_param_groups()

     def forward(
         self,
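The pipeline-parallel branch above encodes a trick worth spelling out: every frozen-model weight gets requires_grad = False, then one small piece of each transformer layer (input_layernorm) is re-enabled and placed in a param group with lr = 0.0, giving DDP gradients to reduce and AMP something to scale while the GPT weights still never move. A toy, self-contained sketch of the same pattern (the modules here are stand-ins, not NeMo classes):

import torch
import torch.nn as nn

frozen_model = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8))  # stand-in for GPT
prompt_encoder = nn.Linear(8, 8)                                # stand-in for prompt params

# Freeze everything, then re-enable grads on the LayerNorm only.
for param in frozen_model.parameters():
    param.requires_grad = False
for param in frozen_model[1].parameters():  # plays the role of input_layernorm
    param.requires_grad = True

optimizer = torch.optim.Adam(
    [
        {'params': list(prompt_encoder.parameters())},           # trains at the real lr
        {'params': list(frozen_model.parameters()), 'lr': 0.0},  # grads flow, no updates
    ],
    lr=1e-4,
)

loss = frozen_model(prompt_encoder(torch.randn(2, 8))).sum()
loss.backward()   # only prompt_encoder and the LayerNorm receive grads
optimizer.step()  # lr 0.0 -> the LayerNorm is numerically unchanged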
Expand Down Expand Up @@ -388,7 +404,7 @@ def forward(
encoder_input = None

# Call forward on GPT model with preprocessed embeddings
if self.float_type == torch.float32:
if self.autocast_dtype == torch.float32:
output = self.frozen_model.model(
input_ids=None,
position_ids=None,
@@ -399,7 +415,7 @@ def forward(
                inference_max_sequence_len=inference_max_sequence_len,
            )
        else:
-           with torch.autocast(device_type="cuda", dtype=self.float_type):
+           with torch.autocast(device_type="cuda", dtype=self.autocast_dtype):
                output = self.frozen_model.model(
                    input_ids=None,
                    position_ids=None,
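For reference, the autocast pattern used above in miniature (a toy module rather than the frozen GPT call; assumes a CUDA device is available):

import torch

layer = torch.nn.Linear(16, 16).cuda()
x = torch.randn(2, 16, device="cuda")

autocast_dtype = torch.float16  # e.g. from trainer precision 16
if autocast_dtype == torch.float32:
    out = layer(x)  # fp32 path: no autocast context needed
else:
    with torch.autocast(device_type="cuda", dtype=autocast_dtype):
        out = layer(x)  # matmuls run in fp16/bf16 inside the context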
@@ -524,7 +540,7 @@ def fwd_bwd_step(self, batch, batch_idx, forward_only):
        _, seq_length = batch[0].shape
        tensor_shape = [seq_length, self.cfg.micro_batch_size, self.hidden_size]

-       if self.cfg.get('pipeline_model_parallel_size', 1) > 1:
+       if self.pipeline_parallel:
            losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving(
                forward_step_func=self.get_forward_output_and_loss_func(),
                batch=batch,
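Here tensor_shape describes the activation tensor exchanged between adjacent pipeline stages, in Megatron's [sequence, micro-batch, hidden] layout, so the schedule can size its point-to-point send/recv buffers up front. With some hypothetical numbers:

seq_length = 2048       # tokens per (padded) sample in the batch
micro_batch_size = 4    # cfg.micro_batch_size
hidden_size = 4096      # frozen_model.cfg.hidden_size

# One activation tensor of this shape crosses each stage boundary per micro-batch.
tensor_shape = [seq_length, micro_batch_size, hidden_size]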
@@ -580,7 +596,8 @@ def training_step(self, batch, batch_idx):

        # Need to make sure the frozen model param learning rate stays 0.0
        # so forcing lr to be 0.0 for gpt layers before param update
-       self._optimizer.param_groups[1]['lr'] = 0.0
+       if self.pipeline_parallel:
+           self._optimizer.param_groups[1]['lr'] = 0.0

        return loss_mean
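The re-pinning guards against the LR scheduler rewriting the frozen group's lr on its own step. A small demonstration of the failure mode with a stock PyTorch scheduler (illustrative, not NeMo's scheduler):

import torch

param = torch.nn.Parameter(torch.zeros(1))
opt = torch.optim.SGD([{'params': [param], 'lr': 0.0}], lr=0.0)
sched = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=10, eta_min=1e-5)

sched.step()
print(opt.param_groups[0]['lr'])  # > 0.0: the "frozen" group would start training
opt.param_groups[0]['lr'] = 0.0   # the fix: force it back before the next update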

@@ -712,24 +729,24 @@ def build_virtual_prompt_dataset(
            task_templates=self.task_templates,
            pseudo_tokens=self.pseudo_tokens,
            pad_token_id=self.pad_token_id,
-           max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings),
+           max_seq_length=self.frozen_model.cfg.encoder_seq_length,
            min_seq_length=self.cfg.data.get('min_seq_length', 1),
            add_bos=self.cfg.data.get('add_bos', False),
            add_eos=self.cfg.data.get('add_eos', True),
            for_train=for_train,
        )

        rank = parallel_state.get_data_parallel_rank()
-       world_size = parallel_state.get_data_parallel_world_size()
+       data_parallel_size = parallel_state.get_data_parallel_world_size()
        sampler = torch.utils.data.distributed.DistributedSampler(
-           dataset, num_replicas=world_size, rank=rank, shuffle=shuffle
+           dataset, num_replicas=data_parallel_size, rank=rank, shuffle=shuffle
        )

        dataloader = torch.utils.data.DataLoader(
            dataset,
            collate_fn=dataset.collate_fn,
            sampler=sampler,
-           batch_size=batch_size,
+           batch_size=batch_size // data_parallel_size,
            drop_last=drop_last,
            num_workers=num_workers,
            pin_memory=pin_memory,
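This is the "load global batch size, not local mini batch size" fix from the commit message: batch_size now arrives as the global batch size, and each data-parallel rank loads only its share. Roughly, with hypothetical numbers (the config key names here are assumptions):

global_batch_size = 256   # what build_virtual_prompt_dataset now receives
data_parallel_size = 8    # parallel_state.get_data_parallel_world_size()
micro_batch_size = 4      # cfg.micro_batch_size

per_rank_batch_size = global_batch_size // data_parallel_size  # 32, the DataLoader batch_size
num_micro_batches = per_rank_batch_size // micro_batch_size    # 8 micro-batches per pipeline pass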
@@ -771,7 +788,7 @@ def dummy():
            task_templates=self.task_templates,
            pseudo_tokens=self.pseudo_tokens,
            pad_token_id=self.pad_token_id,
-           max_seq_length=self.cfg.data.get('max_seq_length', self.frozen_model.cfg.max_position_embeddings),
+           max_seq_length=self.frozen_model.cfg.encoder_seq_length,
            min_seq_length=self.cfg.data.get('min_seq_length', 1),
            add_bos=sampling_params["add_BOS"],
            add_eos=False,
@@ -820,7 +837,7 @@ def set_input_tensor(self, input_tensor):
        model's forward_step_func won't have it. This function is thus
        used by internal code to bypass the input provided by the
        forward_step_func"""
-       # self.input_tensor = input_tensor
+
        self.frozen_model.model.set_input_tensor(input_tensor)

    def get_forward_output_and_loss_func(self):
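The removed commented-out line was dead code: this wrapper simply delegates to the frozen model, which is where the pipeline schedule's injected activations must land. A toy sketch of the set_input_tensor contract (stand-in module, not NeMo code):

import torch
import torch.nn as nn

class ToyStage(nn.Module):
    """Toy sketch of the contract used by Megatron-style pipeline schedules:
    stages after the first ignore token ids and consume activations the
    schedule injected via set_input_tensor."""

    def __init__(self, pre_process):
        super().__init__()
        self.pre_process = pre_process      # True only on the first stage
        self.embed = nn.Embedding(100, 16)
        self.layer = nn.Linear(16, 16)
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # Called by the pipeline schedule, not by forward_step_func
        self.input_tensor = input_tensor

    def forward(self, input_ids=None):
        hidden = self.embed(input_ids) if self.pre_process else self.input_tensor
        return self.layer(hidden)

stage = ToyStage(pre_process=False)
stage.set_input_tensor(torch.randn(4, 16))  # activations from the previous stage
out = stage(input_ids=None)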
