From 3e342f83f54ab1c009f75ef136085d297072c7be Mon Sep 17 00:00:00 2001 From: Virginia Adams <78445382+vadam5@users.noreply.github.com> Date: Thu, 17 Feb 2022 15:49:31 -0800 Subject: [PATCH] Parallel prompt tuning (#3670) * Started combined tensor parallel and pipeline parallel changes Signed-off-by: Virginia Adams * Gets through validation sanity checks Signed-off-by: Virginia Adams * Still working through bugs Signed-off-by: Virginia Adams * Able to run training but virtual token parameters don't get updated Signed-off-by: Virginia Adams * params weren't updating because they weren't setup w/ optimizer Signed-off-by: Virginia Adams * Parallel with single GPU is working! Signed-off-by: Virginia Adams * Tensor parallel = 2 is working Signed-off-by: Virginia Adams * Tensor parallel working and code cleaned up Signed-off-by: Virginia Adams * Added prompt tuning testing back in Signed-off-by: Virginia Adams * Complete method works again for prompt tuned mdoels Signed-off-by: Virginia Adams * removed random imports Signed-off-by: Virginia Adams --- Jenkinsfile | 47 +++--- .../conf/megatron_gpt_config.yaml | 0 .../conf/megatron_prompt_tuning_gpt.yaml | 129 +++++++++++++++ .../megatron_gpt_prompt_tuning.py | 84 +++++++--- .../megatron/gpt_prompt_tuning_dataset.py | 108 ++++++++----- .../language_modeling/megatron/gpt_model.py | 16 +- .../language_modeling/megatron_gpt_model.py | 151 +++++++++++++----- .../nlp/modules/common/megatron/clip_grads.py | 3 + .../modules/common/megatron/language_model.py | 147 ++++++++++------- tests/collections/nlp/test_prompt_tuning.py | 11 +- 10 files changed, 494 insertions(+), 202 deletions(-) mode change 100644 => 100755 examples/nlp/language_modeling/conf/megatron_gpt_config.yaml create mode 100755 examples/nlp/language_modeling/conf/megatron_prompt_tuning_gpt.yaml mode change 100644 => 100755 examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py mode change 100644 => 100755 nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_tuning_dataset.py mode change 100644 => 100755 nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py mode change 100644 => 100755 nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py mode change 100644 => 100755 nemo/collections/nlp/modules/common/megatron/language_model.py diff --git a/Jenkinsfile b/Jenkinsfile index dbf1750aa952a..b6d115d9e7956 100644 --- a/Jenkinsfile +++ b/Jenkinsfile @@ -2085,43 +2085,38 @@ pipeline { 16" } } + stage('L2: Megatron GPT Prompt Tuning and Inference') { when { anyOf { - branch 'r1.6.1' - changeRequest target: 'r1.6.1' + branch 'main' + changeRequest target: 'main' } } failFast true steps { sh "python tests/collections/nlp/test_prompt_tuning.py" sh "python examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py \ - --config-name=megatron_gpt_config \ - trainer.gpus=1 \ - trainer.max_steps=10 \ - trainer.val_check_interval=1 \ - exp_manager.name='megatron_gpt125M_prompt_tuning' \ - exp_manager.checkpoint_callback_params.save_top_k=2 \ - exp_manager.checkpoint_callback_params.save_nemo_on_train_end=True \ - restore_from_path='/home/TestData/nlp/megatron_gpt/125M/megatron_gpt.nemo' \ - +model.use_soft_prompts=True \ - +model.num_prompt_tokens=10 \ - +model.new_prompt_tags=['Winogrande, BoolQ'] \ - +model.new_prompt_init_text=['logic choose person name, None'] \ - +model.new_prompt_init_methods=['text, random'] \ - model.data.data_prefix=None \ - +model.data.train_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_train.json' \ - 
+model.data.valid_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_val.json' \ - +model.data.test_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_val.json' \ - +model.data.batch_size=8 \ - model.optim.lr=2e-2 \ - model.optim.sched.min_lr=2e-3 \ - model.optim.sched.warmup_steps=2 \ - model.optim.sched.constant_steps=8 \ - model.encoder_seq_length=2048" + --config-name=megatron_prompt_tuning_gpt \ + restore_from_path='/home/TestData/nlp/megatron_gpt/125M/megatron_gpt.nemo' \ + trainer.val_check_interval=2 \ + trainer.max_steps=5 \ + model.new_prompt_tags=['Winogrande, BoolQ'] \ + model.new_prompt_init_text=['logic choose person name, None'] \ + model.new_prompt_init_methods=['text, random'] \ + model.data.train_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_train.json' \ + model.data.valid_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_val.json' \ + +model.data.test_ds='/home/TestData/nlp/prompt_tuning/wino_bool_prompt_tuning_val.json' \ + model.micro_batch_size=2 \ + model.global_batch_size=4 \ + model.optim.lr=2e-2 \ + model.optim.sched.min_lr=2e-3 \ + model.optim.sched.warmup_steps=2 \ + model.optim.sched.constant_steps=8 \ + model.encoder_seq_length=2048" sh "python examples/nlp/language_modeling/megatron_gpt_eval.py \ --use_soft_prompts \ - --model_file=nemo_experiments/megatron_gpt125M_prompt_tuning/checkpoints/megatron_gpt125M_prompt_tuning.nemo \ + --model_file=nemo_experiments/PromptTuning/checkpoints/PromptTuning.nemo \ --tokens_to_generate=3 \ --prompt_tag='Winogrande' \ --prompt='option1: wood option2: bag sentence: The _ is soft. answer:'" diff --git a/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml b/examples/nlp/language_modeling/conf/megatron_gpt_config.yaml old mode 100644 new mode 100755 diff --git a/examples/nlp/language_modeling/conf/megatron_prompt_tuning_gpt.yaml b/examples/nlp/language_modeling/conf/megatron_prompt_tuning_gpt.yaml new file mode 100755 index 0000000000000..a89f0b07e99e1 --- /dev/null +++ b/examples/nlp/language_modeling/conf/megatron_prompt_tuning_gpt.yaml @@ -0,0 +1,129 @@ +name: PromptTuning +restore_from_path: ??? 
# used when starting from a .nemo file + +trainer: + gpus: 1 + num_nodes: 1 + accelerator: ddp + precision: 32 + logger: False # logger provided by exp_manager + checkpoint_callback: False + replace_sampler_ddp: False + max_epochs: null + max_steps: 1000 # consumed_samples = global_step * micro_batch_size * data_parallel_size * accumulate_grad_batches + log_every_n_steps: 10 + val_check_interval: 50 + limit_val_batches: 50 + limit_test_batches: 500 + accumulate_grad_batches: 1 # do not modify, grad acc is automatic for training megatron models + gradient_clip_val: null + +exp_manager: + explicit_log_dir: null + exp_dir: null + name: PromptTuning + create_wandb_logger: False + wandb_logger_kwargs: + project: None + name: None + resume_if_exists: True + resume_ignore_no_checkpoint: True + create_checkpoint_callback: True + checkpoint_callback_params: + monitor: val_loss + save_top_k: 3 + mode: min + always_save_nemo: False # saves nemo file during validation, not implemented for model parallel + save_nemo_on_train_end: True # not recommended when training large models on clusters with short time limits + filename: 'megatron_gpt--{val_loss:.2f}-{step}-{consumed_samples}' + model_parallel_size: ${multiply:${model.tensor_model_parallel_size}, ${model.pipeline_model_parallel_size}} + + +model: + # specify micro_batch_size, global_batch_size, and model parallelism + # gradient accumulation will be done automatically based on data_parallel_size + micro_batch_size: 4 # limited by GPU memory + global_batch_size: 16 # will use more micro batches to reach global batch size + tensor_model_parallel_size: 1 # intra-layer model parallelism + pipeline_model_parallel_size: 1 # inter-layer model parallelism + + # model architecture + encoder_seq_length: 2048 + max_position_embeddings: ${.encoder_seq_length} + num_layers: 12 + hidden_size: 768 + ffn_hidden_size: 3072 # Transformer FFN hidden size. Usually 4 * hidden_size. + num_attention_heads: 12 + init_method_std: 0.02 # Standard deviation of the zero mean normal distribution used for weight initialization.') + hidden_dropout: 0.1 # Dropout probability for hidden state transformer. + kv_channels: null # Projection weights dimension in multi-head attention. Set to hidden_size // num_attention_heads if null + apply_query_key_layer_scaling: True # scale Q * K^T by 1 / layer-number. + layernorm_epsilon: 1e-5 + make_vocab_size_divisible_by: 128 # Pad the vocab size to be divisible by this value for computation efficiency. + pre_process: True # add embedding + post_process: True # add pooler + persist_layer_norm: True # Use of persistent fused layer norm kernel. + gradient_as_bucket_view: False # Allocate gradients in a contiguous bucket to save memory (less fragmentation and buffer memory) + + tokenizer: + library: 'megatron' + type: 'GPT2BPETokenizer' + model: null + vocab_file: null + merge_file: null + + # Prompt Tuning + use_soft_prompts: True + num_prompt_tokens: 150 + existing_prompt_tags: [] + new_prompt_tags: ??? 
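  # Editor's note (illustrative, not part of the shipped defaults): new_prompt_tags,
  # new_prompt_init_text, and new_prompt_init_methods are parallel lists -- tag i is
  # initialized with method i, and the init text at position i is only consulted when
  # that method is 'text'. For example, mirroring the values used in the Jenkins test:
  #   new_prompt_tags: ['Winogrande', 'BoolQ']
  #   new_prompt_init_text: ['logic choose person name', 'None']
  #   new_prompt_init_methods: ['text', 'random']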
+ new_prompt_init_text: ['some initialization text goes here'] + new_prompt_init_methods: ['text'] + calc_loss_on_answer_only: False + + + # precision + native_amp_init_scale: 4294967296 # 2 ** 32 + native_amp_growth_interval: 1000 + hysteresis: 2 # Gradient scale hysteresis + fp32_residual_connection: False # Move residual connections to fp32 + fp16_lm_cross_entropy: False # Move the cross entropy unreduced loss calculation for lm head to fp16 + + # Megatron O2-style half-precision + megatron_amp_O2: False # Enable O2-level automatic mixed precision using master parameters + + # miscellaneous + seed: 1234 + use_cpu_initialization: False # Init weights on the CPU (slow for large models) + onnx_safe: False # Use work-arounds for known problems with Torch ONNX exporter. + apex_transformer_log_level: 30 # Python logging level displays logs with severity greater than or equal to this + + activations_checkpoint_method: null # 'uniform', 'block' + activations_checkpoint_num_layers: 1 + + data: + data_prefix: None + train_ds: ??? + valid_ds: ??? + data_impl: mmap + splits_string: 900,50,50 + seq_length: ${model.encoder_seq_length} + skip_warmup: True + num_workers: 0 + dataloader_type: single # cyclic + reset_position_ids: False # Reset position ids after end-of-document token + reset_attention_mask: False # Reset attention mask after end-of-document token + eod_mask_loss: False # Mask loss for the end of document tokens + + optim: + name: fused_adam + lr: 2e-4 + weight_decay: 0.01 + betas: + - 0.9 + - 0.98 + sched: + name: CosineAnnealing + warmup_steps: 50 + constant_steps: 10 + min_lr: 2e-5 diff --git a/examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py b/examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py old mode 100644 new mode 100755 index 0af5c07d5b9a9..c14049381999a --- a/examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py +++ b/examples/nlp/language_modeling/megatron_gpt_prompt_tuning.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,21 @@ from omegaconf.omegaconf import OmegaConf, open_dict from pytorch_lightning import Trainer +from pytorch_lightning.callbacks.timer import Timer +from pytorch_lightning.plugins.environments.torchelastic_environment import TorchElasticEnvironment from nemo.collections.nlp.models.language_modeling.megatron_gpt_model import MegatronGPTModel -from nemo.collections.nlp.parts.nlp_overrides import NLPDDPPlugin +from nemo.collections.nlp.modules.common.megatron.megatron_init import fake_initialize_model_parallel +from nemo.collections.nlp.parts.nlp_overrides import ( + GradScaler, + MegatronHalfPrecisionPlugin, + NLPDDPPlugin, + PipelineMixedPrecisionPlugin, +) from nemo.core.config import hydra_runner from nemo.utils import logging -from nemo.utils.exp_manager import exp_manager +from nemo.utils.app_state import AppState +from nemo.utils.exp_manager import StatelessTimer, exp_manager """ @@ -27,9 +36,9 @@ run inference with multiple soft-prompts/tasks within a batch. 
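As a rough illustration (file name and record values below are hypothetical, not taken from this PR), each training/validation file holds one JSON record per line in the format described below, which could be produced with something like:

# Illustrative only: write two hypothetical records in the prompt-tuning JSONL
# format (one JSON object per line with "prompt_tag", "text", and "answer").
import json

records = [
    {"prompt_tag": "BoolQ", "text": "passage: ... question: ...", "answer": "True"},
    {"prompt_tag": "Winogrande", "text": "option1: wood option2: bag sentence: The _ is soft.", "answer": "bag"},
]

with open("prompt_tuning_train_example.jsonl", "w") as f:
    for record in records:
        f.write(json.dumps(record) + "\n")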
Datasets should be formatted with in a json file like: -{"prompt_tag": , "text": } -{"prompt_tag": , "text": } -{"prompt_tag": , "text": } +{"prompt_tag": , "text": , "answer": } +{"prompt_tag": , "text": , "answer": } +{"prompt_tag": , "text": , "answer": } Example Usage for first prompt tuning task: @@ -139,38 +148,63 @@ """ -@hydra_runner(config_path="conf", config_name="megatron_gpt_config") +@hydra_runner(config_path="conf", config_name="megatron_prompt_tuning_gpt") def main(cfg) -> None: logging.info("\n\n************** Experiment configuration ***********") logging.info(f'\n{OmegaConf.to_yaml(cfg)}') - plugins = [NLPDDPPlugin(num_nodes=cfg.trainer.num_nodes)] + megatron_amp_o2 = cfg.model.get('megatron_amp_O2', False) + plugins = [ + NLPDDPPlugin( + num_nodes=cfg.trainer.num_nodes, + no_ddp_communication_hook=True, + gradient_as_bucket_view=cfg.model.gradient_as_bucket_view, + ) + ] + if cfg.trainer.precision in [16, 'bf16']: + scaler = None + if cfg.trainer.precision == 16: + scaler = GradScaler( + init_scale=cfg.model.get('native_amp_init_scale', 2 ** 32), + growth_interval=cfg.model.get('native_amp_growth_interval', 1000), + hysteresis=cfg.model.get('hysteresis', 2), + ) + if megatron_amp_o2: + plugins.append(MegatronHalfPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) + else: + plugins.append(PipelineMixedPrecisionPlugin(precision=cfg.trainer.precision, device='cuda', scaler=scaler)) - trainer = Trainer(plugins=plugins, **cfg.trainer) + if cfg.get('cluster_type', None) == 'BCP': + plugins.append(TorchElasticEnvironment()) + trainer = Trainer(plugins=plugins, **cfg.trainer) exp_manager(trainer, cfg.exp_manager) + app_state = AppState() + if cfg.model.tensor_model_parallel_size > 1 or cfg.model.pipeline_model_parallel_size > 1: + app_state.model_parallel_size = cfg.model.tensor_model_parallel_size * cfg.model.pipeline_model_parallel_size + ( + app_state.tensor_model_parallel_rank, + app_state.pipeline_model_parallel_rank, + app_state.model_parallel_size, + _, + ) = fake_initialize_model_parallel( + world_size=app_state.model_parallel_size, + rank=trainer.global_rank, + tensor_model_parallel_size_=cfg.model.tensor_model_parallel_size, + pipeline_model_parallel_size_=cfg.model.pipeline_model_parallel_size, + ) + + # Override timer callback to a stateless one + for idx, callback in enumerate(trainer.callbacks): + if isinstance(callback, Timer): + trainer.callbacks[idx] = StatelessTimer(cfg.trainer.max_time,) + # hydra interpolation does not work here as the interpolation key is lost when PTL saves hparams with open_dict(cfg): cfg.model.precision = cfg.trainer.precision model = MegatronGPTModel.restore_from(cfg.restore_from_path, cfg.model, trainer=trainer) - - # Init all new prompts - for idx, tag in enumerate(cfg.model.new_prompt_tags): - init_method = cfg.model.new_prompt_init_methods[idx] - - if init_method == "text": - init_text = cfg.model.new_prompt_init_text[idx] - model.init_prompt_from_text(tag, init_text) - - elif init_method == 'random': - model.init_prompt_from_random(tag) - - else: - logging.info(f'\n Soft prompt init method {init_method} is not recognized, please use text or random') - - logging.info(f'\nCurrent soft prompts include {model.get_prompt_table()}') trainer.fit(model) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_tuning_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_tuning_dataset.py old mode 100644 new mode 100755 index 48fec64d867e2..323ae390643be --- 
a/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_tuning_dataset.py +++ b/nemo/collections/nlp/data/language_modeling/megatron/gpt_prompt_tuning_dataset.py @@ -36,20 +36,28 @@ def __init__( self, dataset_path, tokenizer, + prompt_table, num_prompt_tokens: int, + micro_batch_size: int, max_seq_length: int, min_seq_length: int = 1, - add_bos_eos: bool = True, - calc_loss_on_answer_only=True, + add_bos: bool = False, + add_eos: bool = True, + calc_loss_on_answer_only=False, ): self.tokenizer = tokenizer - self.add_bos_eos = add_bos_eos + self.prompt_tag_to_id = dict(prompt_table) + self.add_bos = add_bos + self.add_eos = add_eos self.calc_loss_on_answer_only = calc_loss_on_answer_only self.max_seq_length = max_seq_length self.min_seq_length = min_seq_length self.num_prompt_tokens = num_prompt_tokens + self.micro_batch_size = micro_batch_size self.max_sent_length = max_seq_length - num_prompt_tokens - self.tags_and_tokens = [] + self.prompt_ids_and_tokens = [] + + print(f"\n\nMicro batch size: {micro_batch_size}") assert min_seq_length <= max_seq_length, "Min sequence length should be less than or equal to max" assert max_seq_length > 0, "Max sequence length should be greater than 0" @@ -72,44 +80,70 @@ def __init__( answer_ids = tokenizer.text_to_ids(answer) answer_len = len(answer_ids) - if self.add_bos_eos: - sent_ids = [tokenizer.bos_id] + sent_ids + [tokenizer.eos_id] + if self.add_bos: + sent_ids = [tokenizer.bos_id] + sent_ids + + if self.add_eos: + sent_ids = sent_ids + [tokenizer.eos_id] answer_len += 1 # To account for EOS token # Need to leave space for prompt tokens in sequence if self.min_seq_length <= len(sent_ids) <= self.max_sent_length: - self.tags_and_tokens.append((prompt_tag, sent_ids, answer_len)) - + prompt_id = self.prompt_tag_to_id[prompt_tag] + self.prompt_ids_and_tokens.append((prompt_id, sent_ids, answer_len)) else: skipped += 1 logging.info(f'Skipped {skipped} sentences, sequence length too long or too short') def __len__(self): - return len(self.tags_and_tokens) + return len(self.prompt_ids_and_tokens) def __getitem__(self, idx): - return self.tags_and_tokens[idx] + return self.prompt_ids_and_tokens[idx] def collate_fn(self, batch): - """Build masks and position id for left to right model with prompt tuning.""" + """ Prepares global batch, then splits into micro batches if pipeline parallel is > 1""" + + prompt_ids, input_ids, answer_lens = zip(*batch) + prompt_ids = torch.tensor(prompt_ids) + + # Prepare global batch + tokens, labels, loss_mask, attention_mask, text_position_ids = self.process_global_batch( + input_ids, answer_lens, + ) - prompt_tags, input_ids, answer_lens = zip(*batch) + return tokens, labels, loss_mask, attention_mask, text_position_ids, prompt_ids + def process_global_batch(self, input_ids, answer_lens): + """ Perpare tokens, labels, loss mask, attention_mask, and position ids for global batch """ # Get max sequence length of batch batch_size = len(input_ids) batch_max = max(len(ids) for ids in input_ids) + tokens, loss_mask = self.pad_batch_and_build_loss_mask(input_ids, answer_lens, batch_max) - # Add prompt token length - batch_max_with_prompt = batch_max + self.num_prompt_tokens + # Labels for prompt tokens, just padding because the loss mask masks these out + prompt_token_labels = torch.full( + size=(batch_size, self.num_prompt_tokens - 1), fill_value=self.tokenizer.bos_id, dtype=torch.long, + ) + + # Should be a label for every token in batch, label is the next token, starting with the virtual tokens + labels = 
torch.cat((prompt_token_labels, tokens.contiguous()), dim=1) + tokens = tokens[:, :-1].contiguous() + text_position_ids, attention_mask = self.get_ltor_attention_mask_and_position_ids(batch_size, tokens) - # Pad tokens in batch to max batch length while building loss mask - loss_masks = [] + return tokens, labels, loss_mask, attention_mask, text_position_ids + + def pad_batch_and_build_loss_mask(self, input_ids, answer_lens, batch_max): + """ Pad tokens in batch to max batch length while building loss mask """ + loss_mask = [] for idx, ids in enumerate(input_ids): text_length = len(ids) answer_length = answer_lens[idx] - prompt_loss_mask = [0.0] * self.num_prompt_tokens + # Loss mask should match labels + # Subtracting one because loss mask should align with labels + prompt_loss_mask = [0.0] * (self.num_prompt_tokens - 1) # Loss mask everything except the answer if self.calc_loss_on_answer_only: @@ -122,38 +156,38 @@ def collate_fn(self, batch): text_loss_mask = [1.0] * text_length text_loss_mask = prompt_loss_mask + text_loss_mask - padding_length = batch_max - text_length - # Pad loss mask and text tokens + padding_length = batch_max - text_length ids.extend([self.tokenizer.eos_id] * padding_length) text_loss_mask.extend([0.0] * padding_length) - loss_masks.append(torch.tensor(text_loss_mask, dtype=torch.float)) + loss_mask.append(torch.tensor(text_loss_mask, dtype=torch.float)) + # Make into a torch tensor tokens = torch.tensor(input_ids, dtype=torch.long) - loss_mask = torch.stack(loss_masks) + loss_mask = torch.stack(loss_mask) + + return tokens, loss_mask + + def get_ltor_attention_mask_and_position_ids(self, batch_size, tokens): + """ Makes prompt tuning left to right attention mask and position ids. + position ids for text start after soft tokens. 
Position ids for soft + prompts are always the same so they are automatically infered during + the forward pass + """ + + # Full length of every sequence in the batch + full_seq_length = len(tokens[0]) + self.num_prompt_tokens # Position ids for text - text_position_ids = torch.arange(start=self.num_prompt_tokens, end=batch_max_with_prompt, dtype=torch.long,) + text_position_ids = torch.arange(start=self.num_prompt_tokens, end=full_seq_length, dtype=torch.long,) text_position_ids = text_position_ids.unsqueeze(0).expand_as(tokens).clone() # Attention mask (lower triangular) starting with prompt tokens - attention_mask = torch.tril(torch.ones((batch_size, batch_max_with_prompt, batch_max_with_prompt))).view( - batch_size, 1, batch_max_with_prompt, batch_max_with_prompt + attention_mask = torch.tril(torch.ones((batch_size, full_seq_length, full_seq_length))).view( + batch_size, 1, full_seq_length, full_seq_length ) # Convert attention mask to binary: attention_mask = attention_mask < 0.5 - # Labels for prompt tokens - prompt_token_labels = torch.full( - size=(batch_size, self.num_prompt_tokens), fill_value=self.tokenizer.bos_id, dtype=torch.long, - ) - - # Should be a label for every token in batch - labels = torch.cat((prompt_token_labels, tokens[:, 1:].contiguous()), dim=1) - final_label = torch.full(size=(batch_size, 1), fill_value=self.tokenizer.eos_id, dtype=torch.long,) - - # Last label should be eos, even for longest sequence in batch - labels = torch.cat((labels, final_label), dim=1) - - return tokens, labels, prompt_tags, attention_mask, loss_mask, text_position_ids + return text_position_ids, attention_mask diff --git a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py old mode 100644 new mode 100755 index 35d88b9b2a506..fef5198caa441 --- a/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron/gpt_model.py @@ -101,7 +101,7 @@ def __init__( onnx_safe=False, use_soft_prompts=False, num_prompt_tokens=10, - prompt_tags=None, + existing_prompt_tags=None, ): super(GPTModel, self).__init__() @@ -147,7 +147,7 @@ def __init__( onnx_safe=onnx_safe, use_soft_prompts=use_soft_prompts, num_prompt_tokens=num_prompt_tokens, - prompt_tags=prompt_tags, + existing_prompt_tags=existing_prompt_tags, ) self.initialize_word_embeddings( @@ -164,7 +164,7 @@ def forward( position_ids, attention_mask, labels=None, - prompt_tags=None, + prompt_ids=None, tokentype_ids=None, layer_past=None, get_key_value=False, @@ -176,7 +176,7 @@ def forward( input_ids, position_ids, attention_mask, - prompt_tags=prompt_tags, + prompt_ids=prompt_ids, layer_past=layer_past, get_key_value=get_key_value, encoder_input=encoder_input, @@ -219,8 +219,8 @@ def load_state_dict(self, state_dict, strict=True): state_dict = state_dict[self._language_model_key] self.language_model.load_state_dict(state_dict, strict=strict) - def _init_prompt_from_random(self, prompt_tag): - self.language_model._init_prompt_from_random(prompt_tag) + def _init_prompt_from_random(self, prompt_tag, prompt_id): + self.language_model._init_prompt_from_random(prompt_tag, prompt_id) - def _init_prompt_from_text(self, prompt_tag, init_token_ids): - self.language_model._init_prompt_from_text(prompt_tag, init_token_ids) + def _init_prompt_from_text(self, prompt_tag, prompt_id, init_token_ids): + self.language_model._init_prompt_from_text(prompt_tag, prompt_id, init_token_ids) diff --git 
a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py old mode 100644 new mode 100755 index d58a2e80f4f26..03693fcf2848e --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -114,17 +114,25 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): # This means we can only use pipeline parallelism without the interleaved schedule. self.model = build_model(model_provider_func=self.model_provider_func, wrap_with_ddp=False)[0] - self.use_soft_prompts = False + # Prompt tuning initialization + self.use_soft_prompts = self.cfg.get('use_soft_prompts', False) + + if self.use_soft_prompts: + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: + raise NotImplementedError("Prompt tuning is not yet supported for pipeline parallel > 1") - if self.cfg.get('use_soft_prompts', False): - self.use_soft_prompts = True self.prompts_to_tune = set([]) self.prompt_table = set([]) - self.num_prompt_tokens = cfg.get('num_prompt_tokens', 10) + self.next_prompt_id = 0 + self.num_prompt_tokens = cfg.get('num_prompt_tokens', 100) if self.cfg.get('existing_prompt_tags', None): + # Fill table with prev tuned prompt tags and their ids self.prompt_table = set(self.cfg.existing_prompt_tags) - raise ValueError('prompt tuning is temporarily disabled. Please use NeMo 1.6') + + # Get max prompt id from table for starting point of new prompt ids + self.next_prompt_id = max(self.prompt_table, key=lambda x: x[1])[1] + self.setup_optimizer_param_groups() self.megatron_amp_o2 = cfg.get('megatron_amp_O2', False) @@ -172,15 +180,15 @@ def model_provider_func(self, pre_process, post_process): layernorm_epsilon=self.cfg.get('layernorm_epsilon', 1e-5), onnx_safe=self.cfg.get('onnx_safe', False), use_soft_prompts=self.cfg.get('use_soft_prompts', False), - num_prompt_tokens=self.cfg.get('num_prompt_tokens', 10), - prompt_tags=self.cfg.get('existing_prompt_tags', None), + num_prompt_tokens=self.cfg.get('num_prompt_tokens', 100), + existing_prompt_tags=self.cfg.get('existing_prompt_tags', None), persist_layer_norm=self.cfg.get('persist_layer_norm', False), ) return model - def forward(self, tokens, text_position_ids, attention_mask, labels, prompt_tags=None): - output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels, prompt_tags=prompt_tags,) + def forward(self, tokens, text_position_ids, attention_mask, labels, prompt_ids=None): + output_tensor = self.model(tokens, text_position_ids, attention_mask, labels=labels, prompt_ids=prompt_ids,) return output_tensor def setup_optimizer_param_groups(self): @@ -203,10 +211,11 @@ def training_step(self, batch, batch_idx): if self.use_soft_prompts: # The micro batches are already prepared for apex by the prompt tuning dataclass batch_for_pipeline = batch + tensor_shape = [len(batch_for_pipeline[0][0]), self.cfg.micro_batch_size, self.cfg.hidden_size] else: # we prepare the micro batches for the apex fwd/bwd function batch_for_pipeline = self.process_global_batch(batch) - tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] + tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] if self.cfg.get('pipeline_model_parallel_size', 1) > 1: @@ -273,6 +282,7 @@ def training_step(self, batch, batch_idx): prog_bar=True, rank_zero_only=True, ) + return loss_mean def on_train_batch_end(self, outputs, batch, batch_idx: int, 
unused: Optional[int] = 0) -> None: @@ -332,7 +342,6 @@ def allreduce_gradients(self): """ # Bucketize and all-reduce buckets = {} - # Pack the buckets. for param in self.parameters(): if param.requires_grad and param.grad is not None: tp = param.data.type() @@ -373,9 +382,14 @@ def allreduce_first_last_embeddings(self): def get_forward_output_and_loss_func(self): def fwd_output_and_loss_func(batch, model): batch = [x.cuda() for x in batch] - tokens, labels, loss_mask, attention_mask, position_ids = batch - attention_mask = attention_mask[0:1] - output_tensor = model(tokens, position_ids, attention_mask, labels) + + if self.use_soft_prompts: + tokens, labels, loss_mask, attention_mask, position_ids, prompt_ids = batch + output_tensor = model(tokens, position_ids, attention_mask, labels, prompt_ids=prompt_ids) + else: + tokens, labels, loss_mask, attention_mask, position_ids = batch + attention_mask = attention_mask[0:1] + output_tensor = model(tokens, position_ids, attention_mask, labels) def loss_func(output_tensor): loss = self.loss_func(loss_mask, output_tensor) @@ -389,9 +403,14 @@ def loss_func(output_tensor): def get_forward_output_only_func(self): def fwd_output_only_func(batch, model): batch = [x.cuda() for x in batch] - tokens, attention_mask, position_ids = batch - attention_mask = attention_mask[0:1] - output_tensor = model(tokens, position_ids, attention_mask) + + if self.use_soft_prompts: + tokens, attention_mask, position_ids, prompt_ids = batch + output_tensor = model(tokens, position_ids, attention_mask, prompt_ids=prompt_ids) + else: + tokens, attention_mask, position_ids = batch + attention_mask = attention_mask[0:1] + output_tensor = model(tokens, position_ids, attention_mask) def id_func(output_tensor): return output_tensor, {'logits': output_tensor} @@ -408,8 +427,14 @@ def validation_step(self, batch, batch_idx): The list of microbatches is then piped through the pipeline using Apex fwd/bwd functions. 
""" - batch_for_pipeline = self.process_global_batch(batch) - tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] + if self.use_soft_prompts: + # The micro batches are already prepared for apex by the prompt tuning dataclass + batch_for_pipeline = batch + tensor_shape = [len(batch_for_pipeline[0][0]), self.cfg.micro_batch_size, self.cfg.hidden_size] + else: + batch_for_pipeline = self.process_global_batch(batch) + tensor_shape = [self.cfg.encoder_seq_length, self.cfg.micro_batch_size, self.cfg.hidden_size] + if self.cfg.get('pipeline_model_parallel_size', 1) > 1: losses_reduced_per_micro_batch = forward_backward_pipelining_without_interleaving( forward_step_func=self.get_forward_output_and_loss_func(), @@ -592,19 +617,23 @@ def build_prompt_tuning_dataset(self, dataset_path): dataset = GPTPromptTuningDataset( dataset_path=dataset_path, tokenizer=self.tokenizer, + prompt_table=self.prompt_table, num_prompt_tokens=self.cfg.num_prompt_tokens, micro_batch_size=self.cfg.micro_batch_size, - max_seq_length=self.cfg.data.get('max_seq_length', 512), + max_seq_length=self.cfg.data.get('max_seq_length', self.cfg.max_position_embeddings), min_seq_length=self.cfg.data.get('min_seq_length', 1), - add_bos_eos=self.cfg.data.get('add_bos_eos', True), - calc_loss_on_answer_only=self.cfg.get('calc_loss_on_answer_only', True), + add_bos=self.cfg.data.get('add_bos', False), + add_eos=self.cfg.data.get('add_eos', True), + calc_loss_on_answer_only=self.cfg.get('calc_loss_on_answer_only', False), ) dataloader = torch.utils.data.DataLoader( dataset, - batch_size=self.cfg.data.batch_size, + batch_size=self.cfg.global_batch_size, collate_fn=dataset.collate_fn, num_workers=self.cfg.data.num_workers, + drop_last=True, + shuffle=True, pin_memory=True, ) @@ -617,6 +646,10 @@ def setup(self, stage=None): Args: stage (str, optional): Can be 'fit', 'validate', 'test' or 'predict'. Defaults to None. 
""" + # Initalize soft prompts before loading datasets and training + if self.use_soft_prompts: + self.init_new_prompts() + if stage == 'predict': return else: @@ -638,8 +671,8 @@ def setup_training_data(self, cfg): else: raise AttributeError('No prompt tuning train dataset was specified in the cfg file') - # Freeze all weights except prompt embeddings - self.prompt_tuning_freeze() + # Freeze all weights except prompt embeddings and setup optimizer with prompt embedding params + self.prompt_tuning_param_freeze_and_optimizer_setup() elif hasattr(self, '_train_ds'): resume_checkpoint_path = self.trainer.checkpoint_connector.resume_from_checkpoint_fit_path @@ -741,6 +774,8 @@ def configure_gradient_clipping(self, *args, **kwargs): if self.megatron_amp_o2: # grep fp32 master parameters for gradient clipping + if self.use_soft_prompts: + raise NotImplementedError("Prompt tuning is not implemented for amp_o2") parameters = self._optimizer.get_parameters() else: parameters = self.get_parameters() @@ -749,21 +784,27 @@ def configure_gradient_clipping(self, *args, **kwargs): self.log('grad_norm', grad_norm, rank_zero_only=True) - def prompt_tuning_freeze(self): + def prompt_tuning_param_freeze_and_optimizer_setup(self): """Freeze weights of word embeddings and decoder, leaving only prompt embeddings unfrozen """ + weight_decay_params = {'params': []} + no_weight_decay_params = {'params': [], 'weight_decay': 0.0} + for param in self.model.parameters(): param.requires_grad = False # Only want new prompt tags to be tunable, leave existing prompt tags alone for prompt_tag in self.model.language_model.prompt_table.prompt_table.keys(): if prompt_tag in self.prompts_to_tune: - for param in self.model.language_model.prompt_table.prompt_table[prompt_tag].parameters(): - param.requires_grad = True + for params in self.model.language_model.prompt_table.prompt_table[prompt_tag].parameters(): + params.requires_grad = True + weight_decay_params['params'].append(params) else: for param in self.model.language_model.prompt_table.prompt_table[prompt_tag].parameters(): param.requires_grad = False + self._optimizer_param_groups = weight_decay_params, no_weight_decay_params + @classmethod def _bucketize_gpt_inference(cls, batch, use_soft_prompts=False): batch_tokens, lens, tokens_to_generate, compute_logprobs = batch[:4] @@ -861,8 +902,10 @@ def complete(self, request: Dict, positions: List, tokens_to_generate: int): if self.cfg.get('pipeline_model_parallel_size', 1) > 1: raise ValueError('complete method is not yet supported for pipeline with soft prompts') prompt_tags = request["prompt_tags"][idx] + prompt_tags_to_ids = dict(self.prompt_table) + prompt_ids = torch.tensor([prompt_tags_to_ids[tag] for tag in prompt_tags]) else: - prompt_tags = None + prompt_ids = None logsoftmaxlayer = torch.nn.LogSoftmax(dim=-1) @@ -891,8 +934,10 @@ def complete(self, request: Dict, positions: List, tokens_to_generate: int): reset_attention_mask=self.cfg.get('reset_attention_mask', False), eod_mask_loss=self.cfg.get('eod_mask_loss', False), ) - - batch = [tokens, attention_mask, position_ids] + if self.use_soft_prompts: + batch = [tokens, attention_mask, position_ids, prompt_ids] + else: + batch = [tokens, attention_mask, position_ids] tensor_shape = [tokens.shape[1], 1, self.cfg.hidden_size] if self.cfg.get('pipeline_model_parallel_size', 1) > 1: output_tensor = forward_backward_pipelining_without_interleaving( @@ -981,8 +1026,10 @@ def compute_logprobs(self, request: Dict, positions: List): if 
self.cfg.get('pipeline_model_parallel_size', 1) > 1: raise ValueError('compute_logprobs method is not yet supported for pipeline with soft prompts') prompt_tags = request["prompt_tags"][idx] + prompt_tags_to_ids = dict(self.prompt_table) + prompt_ids = torch.tensor([prompt_tags_to_ids[tag] for tag in prompt_tags]) else: - prompt_tags = None + prompt_ids = None if self.use_soft_prompts: batch_size = len(tokens_cut) @@ -1007,7 +1054,10 @@ def compute_logprobs(self, request: Dict, positions: List): eod_mask_loss=self.cfg.get('eod_mask_loss', False), ) - batch = [tokens_cut, attention_mask, position_ids] + if self.use_soft_prompts: + batch = [tokens, attention_mask, position_ids, prompt_ids] + else: + batch = [tokens, attention_mask, position_ids] tensor_shape = [tokens_cut.shape[1], 1, self.cfg.hidden_size] if self.cfg.get('pipeline_model_parallel_size', 1) > 1: output_tensor = forward_backward_pipelining_without_interleaving( @@ -1062,14 +1112,33 @@ def compute_logprobs(self, request: Dict, positions: List): return response + def init_new_prompts(self): + for idx, tag in enumerate(self.cfg.new_prompt_tags): + init_method = self.cfg.new_prompt_init_methods[idx] + + if init_method == "text": + init_text = self.cfg.new_prompt_init_text[idx] + self.init_prompt_from_text(tag, init_text) + + elif init_method == 'random': + self.init_prompt_from_random(tag) + + else: + raise AttributeError( + f'\n Soft prompt init method {init_method} is not recognized\ + please use text or random' + ) + def init_prompt_from_random(self, prompt_tag): - self.model._init_prompt_from_random(prompt_tag) - self._add_prompt_tag(prompt_tag) + prompt_id = self._get_next_prompt_id() + self.model._init_prompt_from_random(prompt_tag, prompt_id) + self._add_prompt_tag(prompt_tag, prompt_id) def init_prompt_from_text(self, prompt_tag, init_text): + prompt_id = self._get_next_prompt_id() init_token_ids = self.tokenizer.text_to_ids(init_text) - self.model._init_prompt_from_text(prompt_tag, init_token_ids) - self._add_prompt_tag(prompt_tag) + self.model._init_prompt_from_text(prompt_tag, prompt_id, init_token_ids) + self._add_prompt_tag(prompt_tag, prompt_id) def get_prompt_table(self): if hasattr(self, 'prompt_table'): @@ -1078,11 +1147,15 @@ def get_prompt_table(self): def list_available_models(self): return None - def _add_prompt_tag(self, prompt_tag): + def _get_next_prompt_id(self): + self.next_prompt_id += 1 + return self.next_prompt_id + + def _add_prompt_tag(self, prompt_tag, prompt_id): if not hasattr(self, 'prompt_table'): raise AttributeError('Please set "use_soft_prompts" in cfg to True') - self.prompt_table.add(prompt_tag) + self.prompt_table.add((prompt_tag, prompt_id)) self.prompts_to_tune.add(prompt_tag) # Add new prompt tag to cfg for loading prompt table at inference diff --git a/nemo/collections/nlp/modules/common/megatron/clip_grads.py b/nemo/collections/nlp/modules/common/megatron/clip_grads.py index da44618db02e2..ea5c522783ebd 100644 --- a/nemo/collections/nlp/modules/common/megatron/clip_grads.py +++ b/nemo/collections/nlp/modules/common/megatron/clip_grads.py @@ -67,6 +67,9 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2): if grad_not_none and is_not_shared and is_not_tp_duplicate: grads_for_norm.append(grad) + if not grads_for_norm: + raise ValueError("No grads found, please disable gradient clipping") + # Norm parameters. 
max_norm = float(max_norm) norm_type = float(norm_type) diff --git a/nemo/collections/nlp/modules/common/megatron/language_model.py b/nemo/collections/nlp/modules/common/megatron/language_model.py old mode 100644 new mode 100755 index b61c9cde6fad1..5cb0d067771ec --- a/nemo/collections/nlp/modules/common/megatron/language_model.py +++ b/nemo/collections/nlp/modules/common/megatron/language_model.py @@ -1,4 +1,4 @@ -# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. +# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -73,7 +73,7 @@ def get_language_model( onnx_safe=False, use_soft_prompts=False, num_prompt_tokens=10, - prompt_tags=None, + existing_prompt_tags=None, ): """Build language model and return along with the key to save.""" @@ -122,7 +122,7 @@ def get_language_model( onnx_safe=onnx_safe, use_soft_prompts=use_soft_prompts, num_prompt_tokens=num_prompt_tokens, - prompt_tags=prompt_tags, + existing_prompt_tags=existing_prompt_tags, ) # key used for checkpoints. language_model_key = 'language_model' @@ -226,10 +226,15 @@ def add_tokentype_embeddings(self, num_tokentypes): # Initialize the token-type embeddings. self.init_method(self.tokentype_embeddings.weight) - def forward(self, input_ids, position_ids, tokentype_ids=None): + def forward(self, input_ids, position_ids, tokentype_ids=None, separate_embeddings=False): # Embeddings. words_embeddings = self.word_embeddings(input_ids) position_embeddings = self.position_embeddings(position_ids) + + # Want word embeddings and position embeddings before addition for soft prompt initalization + if separate_embeddings: + return words_embeddings, position_embeddings + embeddings = words_embeddings + position_embeddings if tokentype_ids is not None: assert self.tokentype_embeddings is not None @@ -332,30 +337,36 @@ def __init__( self.hidden_size = hidden_size self.num_prompt_tokens = num_prompt_tokens + + # Randomly init token and position embeddings self.prompt_embeddings = torch.nn.Embedding(self.num_prompt_tokens, self.hidden_size) + self.position_embeddings = torch.nn.Embedding(self.num_prompt_tokens, self.hidden_size) init_method(self.prompt_embeddings.weight) + init_method(self.position_embeddings.weight) + # Set embedding weights to be embeddings from prompt tokens if init_from_prompt_text: - - # Set embedding weights to be embeddings from prompt tokens self.prompt_embeddings.weight = nn.Parameter(word_embedding_weights) + if position_embedding_weights != None: + self.position_embeddings.weight = nn.Parameter(position_embedding_weights) + # Set keys for loading and saving weights self._prompt_embeddings_key = 'prompt_embeddings' - - self.position_embeddings = torch.nn.Embedding(self.num_prompt_tokens, self.hidden_size) self._position_embeddings_key = 'position_embeddings' - if position_embedding_weights != None: - self.position_embeddings.weight = nn.Parameter(position_embedding_weights) + # Set ids needed for forward pass and broadcast them + # ids = {'ids': torch.arange(self.num_prompt_tokens, dtype=torch.int64)} + # ids_b = tensor_parallel.broadcast_data(['ids'], ids, torch.int64) + # self.ids = ids_b['ids'].long() + self.ids = torch.arange(self.num_prompt_tokens, dtype=torch.int64) - self.prompt_ids = torch.tensor([i for i in range(self.num_prompt_tokens)]) self.embedding_dropout = torch.nn.Dropout(prompt_embedding_dropout_prob) def forward(self, tokentype_ids=None): # 
Embeddings. device = next(self.prompt_embeddings.parameters()).device - prompt_embeddings = self.prompt_embeddings(self.prompt_ids.to(device)) - position_embeddings = self.position_embeddings(self.prompt_ids.to(device)) + prompt_embeddings = self.prompt_embeddings(self.ids.to(device)) + position_embeddings = self.position_embeddings(self.ids.to(device)) embeddings = prompt_embeddings + position_embeddings # Dropout. @@ -403,50 +414,57 @@ def load_state_dict(self, state_dict, strict=True): class PromptTable(torch.nn.Module): def __init__( - self, prompt_tags, num_prompt_tokens, hidden_size, + self, existing_prompt_tags, num_prompt_tokens, hidden_size, ): super().__init__() self.num_prompt_tokens = num_prompt_tokens self.hidden_size = hidden_size self.prompt_table = torch.nn.ModuleDict() + self.prompt_id_to_tag = {} - if prompt_tags: - for tag in enumerate(prompt_tags): - _, tag = tag + if existing_prompt_tags: + for tag, prompt_id in existing_prompt_tags: + self.prompt_id_to_tag[prompt_id] = tag self.prompt_table[tag] = PromptEmbedding( init_from_prompt_text=False, hidden_size=self.hidden_size, num_prompt_tokens=self.num_prompt_tokens, ) - def forward(self, prompt_tag): + def forward(self, prompt_id): + prompt_id = prompt_id.item() + prompt_tag = self.prompt_id_to_tag[prompt_id] return self.prompt_table[prompt_tag]() def remove_prompt(self, prompt_tag): + if prompt_tag not in prompt_table: + return + + # find the prompt_id assocaited with the tag to delete + prompt_id = None + for key, value in prompt_id_to_tag.items(): + if value == prompt_tag: + prompt_id = key + break + + del self.prompt_id_to_tag[prompt_id] del self.prompt_table[prompt_tag] - def init_prompt_from_random(self, prompt_tag, position_embeddings): + def init_prompt_from_random(self, prompt_tag, prompt_id, embeddings): """Add new soft prompt to be tuned. Intialize prompt weights using pytorch init method """ - device = next(position_embeddings.parameters()).device - position_weights = ( - position_embeddings(torch.tensor([i for i in range(self.num_prompt_tokens)]).to(device)).detach().clone() - ) - # Initalize prompt embeddings from a pytorch random init method prompt_embeddings = PromptEmbedding( - init_from_prompt_text=False, - hidden_size=self.hidden_size, - num_prompt_tokens=self.num_prompt_tokens, - position_embedding_weights=position_weights, + init_from_prompt_text=False, hidden_size=self.hidden_size, num_prompt_tokens=self.num_prompt_tokens, ) self.prompt_table[prompt_tag] = prompt_embeddings + self.prompt_id_to_tag[prompt_id] = prompt_tag - def init_prompt_from_text(self, prompt_tag, init_token_ids, word_embeddings, position_embeddings): + def init_prompt_from_text(self, prompt_tag, prompt_id, init_token_ids, embeddings): """Add new soft prompt to be tuned. Intialize prompt weights from existing embeddings from specific vocab tokens. 
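For orientation, here is a minimal, self-contained sketch (toy sizes, plain torch; not the NeMo classes themselves, and it omits the per-prompt position embeddings and dropout) of the id -> tag -> soft-prompt lookup that PromptTable now performs, and the concatenation the language model does in its forward pass:

# Sketch of the prompt_id lookup and soft-prompt concatenation introduced here.
import torch
import torch.nn as nn

hidden_size, num_prompt_tokens, seq_len, batch = 8, 4, 6, 2

# One embedding table per prompt tag, keyed by tag; integer ids map back to tags.
prompt_table = nn.ModuleDict({
    "Winogrande": nn.Embedding(num_prompt_tokens, hidden_size),
    "BoolQ": nn.Embedding(num_prompt_tokens, hidden_size),
})
prompt_id_to_tag = {1: "Winogrande", 2: "BoolQ"}

def lookup_prompt(prompt_id: int) -> torch.Tensor:
    tag = prompt_id_to_tag[prompt_id]
    token_ids = torch.arange(num_prompt_tokens)
    return prompt_table[tag](token_ids)              # (num_prompt_tokens, hidden)

prompt_ids = torch.tensor([1, 2])                    # one prompt id per example in the micro batch
word_embeddings = torch.randn(batch, seq_len, hidden_size)

prompt_embeddings = torch.stack([lookup_prompt(pid.item()) for pid in prompt_ids])
encoder_input = torch.cat((prompt_embeddings, word_embeddings), dim=1)
print(encoder_input.shape)                           # torch.Size([2, 10, 8]): soft tokens prepended

Because the dataset collates one prompt_id per example, each row of a micro batch can draw on a different soft prompt in the same forward pass.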
@@ -460,26 +478,34 @@ def init_prompt_from_text(self, prompt_tag, init_token_ids, word_embeddings, pos elif num_text_tokens < num_prompt_tokens: num_reps = math.ceil(num_prompt_tokens / num_text_tokens) init_token_ids = init_token_ids * num_reps + + # Set dictionary item keys and datatypes for broadcasting + keys = ['text'] + datatype = torch.int64 + + # Broadcast int ids across gpus for tensor parallel init_token_ids = init_token_ids[:num_prompt_tokens] + init_token_ids = {'text': torch.tensor(init_token_ids, dtype=torch.int64)} + init_token_ids_b = tensor_parallel.broadcast_data(keys, init_token_ids, datatype) + init_token_ids = init_token_ids_b['text'].long() + init_position_ids = torch.arange(self.num_prompt_tokens, dtype=torch.long, device=init_token_ids.device) # Use a copy of token embedding weights to initalize the prompt embeddings - device = next(word_embeddings.parameters()).device - embedding_weights = word_embeddings(torch.tensor(init_token_ids, device=device)).detach().clone() - position_weights = ( - position_embeddings(torch.tensor([i for i in range(self.num_prompt_tokens)], device=device)) - .detach() - .clone() - ) + word_embeddings, position_embeddings = embeddings(init_token_ids, init_position_ids, separate_embeddings=True) + + word_embeddings = word_embeddings.detach().clone() + position_embeddings = position_embeddings.detach().clone() prompt_embeddings = PromptEmbedding( init_from_prompt_text=True, hidden_size=self.hidden_size, num_prompt_tokens=self.num_prompt_tokens, - word_embedding_weights=embedding_weights, - position_embedding_weights=position_weights, + word_embedding_weights=word_embeddings, + position_embedding_weights=position_embeddings, ) self.prompt_table[prompt_tag] = prompt_embeddings + self.prompt_id_to_tag[prompt_id] = prompt_tag def load_state_dict(self, state_dict_, strict): for prompt_tag in self.prompt_table: @@ -540,8 +566,8 @@ def __init__( openai_gelu=False, onnx_safe=False, use_soft_prompts=False, - num_prompt_tokens=10, - prompt_tags=None, + num_prompt_tokens=100, + existing_prompt_tags=None, ): super(TransformerLanguageModel, self).__init__() @@ -560,7 +586,7 @@ def __init__( self.hidden_dropout = hidden_dropout self.output_layer_init_method = output_layer_init_method self.use_soft_prompts = use_soft_prompts - self.prompt_tags = prompt_tags + self.existing_prompt_tags = existing_prompt_tags self.num_prompt_tokens = num_prompt_tokens if kv_channels is None: @@ -586,7 +612,9 @@ def __init__( # Soft Prompts if self.use_soft_prompts: self.prompt_table = PromptTable( - prompt_tags=self.prompt_tags, num_prompt_tokens=self.num_prompt_tokens, hidden_size=self.hidden_size, + existing_prompt_tags=self.existing_prompt_tags, + num_prompt_tokens=self.num_prompt_tokens, + hidden_size=self.hidden_size, ) self._prompt_table_key = 'prompt_table' @@ -671,7 +699,7 @@ def forward( enc_input_ids, enc_position_ids, enc_attn_mask, - prompt_tags=None, + prompt_ids=None, dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None, @@ -689,12 +717,10 @@ def forward( embedding_output = self.embedding(enc_input_ids, enc_position_ids, tokentype_ids=tokentype_ids) # Soft prompts - if self.use_soft_prompts and prompt_tags: - prompt_embeddings = [self.prompt_table(tag) for tag in prompt_tags] + if self.use_soft_prompts and prompt_ids != None: + prompt_embeddings = [self.prompt_table(prompt_id) for prompt_id in prompt_ids] prompt_embeddings = torch.stack(prompt_embeddings) - encoder_input = torch.cat((prompt_embeddings, embedding_output), dim=1) - else: encoder_input = 
embedding_output else: @@ -822,29 +848,26 @@ def load_state_dict(self, state_dict, strict=True): assert 'decoder' in state_dict, 'could not find data for pooler in the checkpoint' self.decoder.load_state_dict(state_dict[self._decoder_key], strict=strict) - def _init_prompt_from_random(self, prompt_tag): + def _init_prompt_from_random(self, prompt_tag, prompt_id): """Add new soft prompt to be tuned. Intialize prompt weights using pytorch init method """ + if self.pre_process: + if not hasattr(self, 'prompt_table'): + raise AttributeError('Please set "use_soft_prompts" in the config to True') - if not hasattr(self, 'prompt_table'): - raise AttributeError('Please set "use_soft_prompts" in the config to True') - - self.prompt_table.init_prompt_from_random(prompt_tag, self.embedding.position_embeddings) + self.prompt_table.init_prompt_from_random(prompt_tag, prompt_id, embeddings=self.embedding) - def _init_prompt_from_text(self, prompt_tag, init_token_ids): + def _init_prompt_from_text(self, prompt_tag, prompt_id, init_token_ids): """Add new soft prompt to be tuned. Intialize prompt weights from existing embeddings from specific vocab tokens. """ + if self.pre_process: + if not hasattr(self, 'prompt_table'): + raise AttributeError('Please set "use_soft_prompts" in the config to True') - if not hasattr(self, 'prompt_table'): - raise AttributeError('Please set "use_soft_prompts" in the config to True') - - self.prompt_table.init_prompt_from_text( - prompt_tag, - init_token_ids, - word_embeddings=self.embedding.word_embeddings, - position_embeddings=self.embedding.position_embeddings, - ) + self.prompt_table.init_prompt_from_text( + prompt_tag, prompt_id, init_token_ids, embeddings=self.embedding, + ) diff --git a/tests/collections/nlp/test_prompt_tuning.py b/tests/collections/nlp/test_prompt_tuning.py index 4528d3e16cf9f..1fd02a22c7b06 100644 --- a/tests/collections/nlp/test_prompt_tuning.py +++ b/tests/collections/nlp/test_prompt_tuning.py @@ -26,11 +26,11 @@ def get_prompt_tuning_dataset(tokenizer, dataset_path, num_prompt_tokens): dataset = GPTPromptTuningDataset( dataset_path=dataset_path, tokenizer=tokenizer, + prompt_table=[('A', 1)], num_prompt_tokens=num_prompt_tokens, + micro_batch_size=4, max_seq_length=512, min_seq_length=1, - add_bos_eos=True, - calc_loss_on_answer_only=True, ) return dataset @@ -78,12 +78,13 @@ def test_prompt_tuning_dataset_collate_fn(self): assert len(batch) == 6 - tokens, labels, prompt_tags, attention_mask, loss_mask, text_position_ids = batch + tokens, labels, loss_mask, attention_mask, text_position_ids, prompt_tags = batch assert len(tokens) == len(loss_mask) == len(attention_mask) == len(text_position_ids) assert len(tokens) == len(prompt_tags) - assert len(tokens[0]) + num_prompt_tokens == len(loss_mask[0]) - assert len(tokens[0]) + num_prompt_tokens == attention_mask[0].size()[-1] + assert len(labels) == len(tokens) + assert len(labels[0]) == len(loss_mask[0]) + assert len(tokens[0]) + (num_prompt_tokens) == attention_mask[0].size()[-1] os.remove(dataset_path)
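As a rough, standalone sanity check (toy sizes, simplified loss mask; not the actual dataset code), the shape bookkeeping that the reworked collate_fn and the updated unit test rely on can be reproduced like this:

# With num_prompt_tokens = P and padded text length T: labels/loss_mask cover the
# P - 1 virtual-token positions plus the T text tokens, tokens drop the last position
# (they are next-token inputs), and the attention mask spans len(tokens) + P positions.
import torch

P, T, batch = 10, 16, 4                             # toy sizes, not NeMo defaults

padded_text = torch.randint(0, 100, (batch, T))     # stands in for the padded token ids
prompt_token_labels = torch.zeros(batch, P - 1, dtype=torch.long)   # bos_id fill in the real code

labels = torch.cat((prompt_token_labels, padded_text), dim=1)       # (batch, P - 1 + T)
tokens = padded_text[:, :-1]                                        # (batch, T - 1)
loss_mask = torch.cat((torch.zeros(batch, P - 1), torch.ones(batch, T)), dim=1)

full_seq_length = tokens.size(1) + P
attention_mask = torch.tril(torch.ones(batch, full_seq_length, full_seq_length)) < 0.5

assert labels.shape == loss_mask.shape == (batch, P - 1 + T)
assert attention_mask.shape[-1] == tokens.size(1) + P               # matches the updated unit test

The real loss mask additionally zeros out padding positions and, when calc_loss_on_answer_only is set, everything except the answer tokens.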