Prompt tuning bug fix (#3780)
* Making updated code backwards compatible with previous prompt tuned models

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Fixed backward compatibility bug

Signed-off-by: Virginia Adams <vadams@nvidia.com>

* Removed random import

Signed-off-by: Virginia Adams <vadams@nvidia.com>

Co-authored-by: Eric Harper <complex451@gmail.com>
vadam5 and ericharper authored Mar 1, 2022
1 parent afba754 commit 256236f
Showing 1 changed file with 15 additions and 4 deletions.
@@ -139,10 +139,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
         )
 
-        # TODO: Not sure how to use lists of modules with PTL.
-        # This means we can only use pipeline parallelism without the interleaved schedule.
-        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
-
         # Prompt tuning initialization
         self.use_soft_prompts = self.cfg.get('use_soft_prompts', False)
 
@@ -156,12 +152,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self.num_prompt_tokens = cfg.get('num_prompt_tokens', 100)
 
             if self.cfg.get('existing_prompt_tags', None):
+                # Assign prompt tag ids if none were present in the config
+                if type(self.cfg.existing_prompt_tags[0]) == str:
+                    existing_prompt_tags = self.cfg.existing_prompt_tags
+                    num_prompt_tags = len(existing_prompt_tags)
+                    existing_prompt_tags = [
+                        (existing_prompt_tags[tag_id], tag_id + 1) for tag_id in range(num_prompt_tags)
+                    ]
+
+                    with open_dict(self.cfg):
+                        self.cfg.existing_prompt_tags = existing_prompt_tags
+
                 # Fill table with prev tuned prompt tags and their ids
                 self.prompt_table = set(self.cfg.existing_prompt_tags)
 
                 # Get max prompt id from table for starting point of new prompt ids
                 self.next_prompt_id = max(self.prompt_table, key=lambda x: x[1])[1]
 
+        # TODO: Not sure how to use lists of modules with PTL.
+        # This means we can only use pipeline parallelism without the interleaved schedule.
+        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
+
         self.setup_optimizer_param_groups()
 
         self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
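
For reference, a minimal standalone sketch of the backward-compatibility conversion the second hunk performs: legacy checkpoints stored existing_prompt_tags as bare strings, while the updated code expects (tag, id) tuples. The function name and example tag names below are hypothetical, not part of the commit.

from typing import List, Set, Tuple, Union

def normalize_prompt_tags(tags: List[Union[str, Tuple[str, int]]]) -> Tuple[Set[Tuple[str, int]], int]:
    """Convert legacy string-only prompt tags to (tag, id) tuples,
    then build the prompt table and find the highest id in use."""
    if isinstance(tags[0], str):
        # Legacy format: bare tag names; assign ids 1..N as the hunk does
        tags = [(tag, tag_id + 1) for tag_id, tag in enumerate(tags)]
    prompt_table = set(tags)
    # New prompt ids continue after the largest existing id
    next_prompt_id = max(prompt_table, key=lambda x: x[1])[1]
    return prompt_table, next_prompt_id

# A legacy config would carry plain strings:
table, next_id = normalize_prompt_tags(["sentiment", "intent"])
assert table == {("sentiment", 1), ("intent", 2)} and next_id == 2

The sketch uses isinstance where the hunk compares type(...) == str; both accept plain strings here, but isinstance is the idiomatic form.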