Pipeline parallel working
Signed-off-by: Virginia Adams <vadams@nvidia.com>
vadam5 committed Jun 1, 2022
1 parent 9a8c7bd commit b13448d
Showing 1 changed file with 21 additions and 18 deletions.
@@ -301,10 +301,11 @@ def state_dict(self):
         nemo checkpoints at the end of training will contain prompt table parameters only.
         """
         state_dict_ = {}
-        state_dict_[self._prompt_table_key] = self.prompt_table.state_dict()
+        if self.frozen_model.model.pre_process:
+            state_dict_[self._prompt_table_key] = self.prompt_table.state_dict()
 
-        if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
-            state_dict_[self._prompt_encoder_key] = self.prompt_encoder.state_dict()
+            if self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
+                state_dict_[self._prompt_encoder_key] = self.prompt_encoder.state_dict()
 
         return state_dict_
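With pipeline parallelism, the prompt table and prompt encoder live only on the first pipeline stage, which Megatron-style models mark with pre_process=True, so only that stage should contribute their weights to the checkpoint; later stages now return an empty dict instead of touching modules they never built. A minimal sketch of the gated pattern, assuming a hypothetical pre_process flag and a plain nn.Embedding standing in for the prompt table (not the NeMo class itself):

from collections import OrderedDict

import torch.nn as nn


class PromptModuleSketch(nn.Module):
    """Hypothetical stand-in for the prompt learning model, not the NeMo class."""

    def __init__(self, pre_process: bool):
        super().__init__()
        self.pre_process = pre_process  # True only on the first pipeline stage
        if pre_process:
            # Stand-in for the prompt table owned by the first stage.
            self.prompt_table = nn.Embedding(10, 4)

    def state_dict(self):
        # Later pipeline stages return an empty dict instead of failing on
        # a prompt table they never instantiated.
        state_dict_ = OrderedDict()
        if self.pre_process:
            state_dict_["prompt_table"] = self.prompt_table.state_dict()
        return state_dict_


print(PromptModuleSketch(pre_process=True).state_dict().keys())   # odict_keys(['prompt_table'])
print(PromptModuleSketch(pre_process=False).state_dict().keys())  # odict_keys([])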

@@ -313,21 +314,22 @@ def load_state_dict(self, state_dict, strict: bool = True):
         Custom load state dict method that only loads prompt table and prompt encoder
         parameters. Matching load method for this class' custom state dict method.
         """
-        if self._prompt_table_key in state_dict:
-            state_dict_ = state_dict[self._prompt_table_key]
-        else:
-            # Handle loading state dict before weight saving change for backward compatibility.
-            state_dict_ = OrderedDict()
-            for key in state_dict.keys():
-                if key.startswith(self._prompt_table_key):
-                    key_substring = ".".join(key.split(".")[1:])
-                    state_dict_[key_substring] = state_dict[key]
+        if self.frozen_model.model.pre_process:
+            if self._prompt_table_key in state_dict:
+                state_dict_ = state_dict[self._prompt_table_key]
+            else:
+                # Handle loading state dict before weight saving change for backward compatibility.
+                state_dict_ = OrderedDict()
+                for key in state_dict.keys():
+                    if key.startswith(self._prompt_table_key):
+                        key_substring = ".".join(key.split(".")[1:])
+                        state_dict_[key_substring] = state_dict[key]
 
-        self.prompt_table.load_state_dict(state_dict_, strict)
+            self.prompt_table.load_state_dict(state_dict_, strict)
 
-        if self._prompt_encoder_key in state_dict and self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
-            state_dict_ = state_dict[self._prompt_encoder_key]
-            self.prompt_encoder.load_state_dict(state_dict_, strict)
+            if self._prompt_encoder_key in state_dict and self.virtual_prompt_source == VirtualPromptSource.PROMPT_ENCODER:
+                state_dict_ = state_dict[self._prompt_encoder_key]
+                self.prompt_encoder.load_state_dict(state_dict_, strict)
 
     def setup_optimizer_param_groups(self):
         """
@@ -670,7 +672,7 @@ def allreduce_gradients(self):
 
     def on_train_end(self):
         # Save p-tuned prompts to prompt table for inference or future task training
-        if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.model.model.pre_process:
+        if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process:
             self.add_ptuned_prompts_to_prompt_table()
             logging.info(f"All p-tuned prompts where moved to the prompt table.")

@@ -862,7 +864,8 @@ def set_input_tensor(self, input_tensor):
         model's forward_step_func won't have it. This function is thus
         used by internal code to bypass the input provided by the
         forward_step_func"""
-        self.input_tensor = input_tensor
+        #self.input_tensor = input_tensor
+        self.frozen_model.model.set_input_tensor(input_tensor)
 
     def get_forward_output_and_loss_func(self):
         def fwd_output_and_loss_func(batch, model):
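With pipeline parallelism, stages after the first receive the previous stage's activations through set_input_tensor rather than through the data passed to forward_step_func, so the wrapper must hand the tensor on to the frozen GPT stage instead of only storing it on itself. A rough sketch of that delegation, using hypothetical stand-in classes:

import torch


class FrozenStageSketch:
    """Hypothetical stand-in for the wrapped GPT pipeline stage."""

    def __init__(self):
        self.input_tensor = None

    def set_input_tensor(self, input_tensor):
        # The stage consumes this tensor in place of its own embedding output.
        self.input_tensor = input_tensor


class PromptLearningWrapperSketch:
    """Sketch of the fix: forward the incoming activation to the frozen model,
    where it is actually used, instead of keeping it on the wrapper."""

    def __init__(self, frozen_stage):
        self.frozen_model = frozen_stage

    def set_input_tensor(self, input_tensor):
        # Previously the wrapper only did `self.input_tensor = input_tensor`,
        # which the frozen stage never saw.
        self.frozen_model.set_input_tensor(input_tensor)


stage = FrozenStageSketch()
wrapper = PromptLearningWrapperSketch(stage)
wrapper.set_input_tensor(torch.zeros(2, 3))
print(stage.input_tensor.shape)  # torch.Size([2, 3])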
