From 256236f684cd08abd27008d6a9c41d77993be138 Mon Sep 17 00:00:00 2001
From: Virginia Adams <78445382+vadam5@users.noreply.github.com>
Date: Tue, 1 Mar 2022 14:27:32 -0800
Subject: [PATCH] Prompt tuning bug fix (#3780)

* Making updated code backwards compatible with previous prompt tuned models

Signed-off-by: Virginia Adams

* Fixed backward compatibility bug

Signed-off-by: Virginia Adams

* Removed random import

Signed-off-by: Virginia Adams

Co-authored-by: Eric Harper
---
 .../language_modeling/megatron_gpt_model.py   | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
index 59f87fa3c174..14847b3d9b24 100755
--- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
+++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py
@@ -139,10 +139,6 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             tensor_model_parallel_size=cfg.get('tensor_model_parallel_size', 1),
         )
 
-        # TODO: Not sure how to use lists of modules with PTL.
-        # This means we can only use pipeline parallelism without the interleaved schedule.
-        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
-
         # Prompt tuning initialization
         self.use_soft_prompts = self.cfg.get('use_soft_prompts', False)
 
@@ -156,12 +152,27 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
             self.num_prompt_tokens = cfg.get('num_prompt_tokens', 100)
 
             if self.cfg.get('existing_prompt_tags', None):
+                # Assign prompt tag ids if none were present in the config
+                if type(self.cfg.existing_prompt_tags[0]) == str:
+                    existing_prompt_tags = self.cfg.existing_prompt_tags
+                    num_prompt_tags = len(existing_prompt_tags)
+                    existing_prompt_tags = [
+                        (existing_prompt_tags[tag_id], tag_id + 1) for tag_id in range(num_prompt_tags)
+                    ]
+
+                    with open_dict(self.cfg):
+                        self.cfg.existing_prompt_tags = existing_prompt_tags
+
                 # Fill table with prev tuned prompt tags and their ids
                 self.prompt_table = set(self.cfg.existing_prompt_tags)
 
                 # Get max prompt id from table for starting point of new prompt ids
                 self.next_prompt_id = max(self.prompt_table, key=lambda x: x[1])[1]
 
+        # TODO: Not sure how to use lists of modules with PTL.
+        # This means we can only use pipeline parallelism without the interleaved schedule.
+        self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0]
+
         self.setup_optimizer_param_groups()
 
         self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False)
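
For illustration, here is a minimal standalone sketch of the backward-compatibility conversion this patch adds: old-style configs store existing_prompt_tags as bare strings, and the new code assigns them ids 1..N in listed order (tag_id + 1, so ids start at 1) before building the prompt table. This is a simplified, assumed rendering, not the patch itself: normalize_prompt_tags is a hypothetical helper, and a plain Python list stands in for the OmegaConf DictConfig field the real code mutates under open_dict.

    # Sketch of the backward-compat logic above (simplified: plain list
    # instead of an OmegaConf DictConfig; normalize_prompt_tags is a
    # hypothetical helper, not part of the patch).

    def normalize_prompt_tags(existing_prompt_tags):
        """Return (tag, id) pairs whether tags arrive as bare strings or pairs."""
        if existing_prompt_tags and isinstance(existing_prompt_tags[0], str):
            # Old-style config: assign ids 1..N in listed order, as the patch does.
            return [(tag, i + 1) for i, tag in enumerate(existing_prompt_tags)]
        return [tuple(tag) for tag in existing_prompt_tags]

    # Old-style config, as a previously prompt-tuned model might have saved it.
    prompt_table = set(normalize_prompt_tags(['boolq', 'rte']))

    # Max existing id is the starting point for ids of newly added prompts.
    next_prompt_id = max(prompt_table, key=lambda x: x[1])[1]

    print(sorted(prompt_table))  # [('boolq', 1), ('rte', 2)]
    print(next_prompt_id)        # 2

The conversion is written back into the config under open_dict so that checkpoints saved from here on carry (tag, id) pairs, and the string-only branch only ever runs once per old model.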