Skip to content

Commit

Permalink
Fix for prompt table restore error (NVIDIA#5393)
Browse files Browse the repository at this point in the history
* Fix for prompt table restore error

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Added more saftey checks

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Added more condition checks

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

Signed-off-by: Virginia Adams <vadams@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
vadam5 and pre-commit-ci[bot] authored Nov 14, 2022
1 parent 4c9c858 commit dbe41af
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -436,6 +436,12 @@ def save_checkpoint_as_nemo_file(self):
self.virtual_prompt_style = current_virtual_prompt_style
self.virtual_prompt_source = current_virtual_prompt_source

# Revert prompt table back to previous state
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.first_stage_of_pipeline():
for taskname in current_new_tasks:
if taskname in self.prompt_table.prompt_table:
del self.prompt_table.prompt_table[taskname]

with open_dict(self.cfg):
self.cfg.existing_tasks = current_existing_tasks
self.cfg.new_tasks = current_new_tasks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -707,6 +707,12 @@ def save_checkpoint_as_nemo_file(self):
self.virtual_prompt_style = current_virtual_prompt_style
self.virtual_prompt_source = current_virtual_prompt_source

# Revert prompt table back to previous state
if self.virtual_prompt_style == VirtualPromptStyle.P_TUNING and self.frozen_model.model.pre_process:
for taskname in current_new_tasks:
if taskname in self.prompt_table.prompt_table:
del self.prompt_table.prompt_table[taskname]

with open_dict(self.cfg):
self.cfg.existing_tasks = current_existing_tasks
self.cfg.new_tasks = current_new_tasks
Expand Down

0 comments on commit dbe41af

Please sign in to comment.