From 5119eaef7a495c9eda68e896d5279f6a0775785a Mon Sep 17 00:00:00 2001
From: Alexey Panteleev
Date: Tue, 9 Apr 2024 13:22:38 -0700
Subject: [PATCH] P-tuning related fixes:

- Remember the vtoken counts for each p-tuning table when the tables are added;
- Prepend the right number of vtokens to each query based on its task_id;
- Preserve the dtype of the p-tuning table when it is padded;
- Validate that all p-tuning tables fit into the max_prompt_embedding_table_size limit.

Signed-off-by: Alexey Panteleev
---
 nemo/export/tensorrt_llm.py             | 22 ++++++++--
 nemo/export/trt_llm/tensorrt_llm_run.py | 55 ++++++++++++++++++++-----
 2 files changed, 64 insertions(+), 13 deletions(-)

diff --git a/nemo/export/tensorrt_llm.py b/nemo/export/tensorrt_llm.py
index 473fefaea6a2f..2c6a0f5d100d9 100644
--- a/nemo/export/tensorrt_llm.py
+++ b/nemo/export/tensorrt_llm.py
@@ -94,6 +94,7 @@ def __init__(self, model_dir: str, lora_ckpt_list: List[str] = None, load_model:
         self.ptuning_tables = []
         self.p_table = None
         self.task_vocab_size = 0
+        self.task_vtoken_counts = []
         self.task_ids = {}
 
         if load_model:
@@ -337,12 +338,15 @@ def forward(
                     prompt_embeddings_table, prompt_embeddings_checkpoint_path
                 )
                 tv_size = prompt_table.size(dim=0)
+                task_vtoken_counts = [tv_size]
             elif len(self.ptuning_tables) > 0:
                 prompt_table = self.p_table
                 tv_size = self.task_vocab_size
+                task_vtoken_counts = self.task_vtoken_counts
             else:
                 prompt_table = None
                 tv_size = None
+                task_vtoken_counts = None
 
             if task_ids is None:
                 assert prompt_table is None, "There is a prompt embedding table and task_ids cannot be None"
@@ -383,6 +387,7 @@ def forward(
                 temperature=temperature,
                 prompt_table=prompt_table,
                 task_vocab_size=tv_size,
+                task_vtoken_counts=task_vtoken_counts,
                 task_ids=input_task_ids,
                 lora_uids=lora_uids,
                 stop_words_list=stop_words_list,
@@ -402,6 +407,7 @@ def forward(
                 temperature=temperature,
                 prompt_table=prompt_table,
                 task_vocab_size=tv_size,
+                task_vtoken_counts=task_vtoken_counts,
                 task_ids=input_task_ids,
                 lora_uids=lora_uids,
                 stop_words_list=stop_words_list,
@@ -557,19 +563,29 @@ def _prep_ptuning_table(self):
             if self.task_vocab_size < pt["table"].size(dim=0):
                 self.task_vocab_size = pt["table"].size(dim=0)
 
-        # pad tasks to longest task embedding table
+        # pad tasks to the longest task embedding table and remember the original vtoken counts
         vtokens_embeddings = []
+        self.task_vtoken_counts = []
         self.task_ids = {}
         tid = 0
         for i, ptuning_table in enumerate(self.ptuning_tables):
-            padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size))
-            padded_table[: ptuning_table["table"].size(dim=0), :] = ptuning_table["table"]
+            original_table = ptuning_table["table"]
+            vtoken_count = original_table.size(dim=0)
+            padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size), dtype=original_table.dtype)
+            padded_table[:vtoken_count, :] = original_table
             vtokens_embeddings.append(padded_table)
             self.task_ids[ptuning_table["task_name"]] = tid
+            self.task_vtoken_counts.append(vtoken_count)
             tid = tid + 1
 
         if len(vtokens_embeddings) > 0:
             self.p_table = torch.stack(vtokens_embeddings, dim=0).view(-1, self.get_hidden_size)
+
+            max_prompt_embedding_table_size = self.config["builder_config"]["max_prompt_embedding_table_size"]
+            actual_prompt_table_size = self.p_table.shape[0]
+
+            if actual_prompt_table_size > max_prompt_embedding_table_size:
+                raise RuntimeError(f"The size of the combined prompt embedding table ({actual_prompt_table_size}) is greater than max_prompt_embedding_table_size ({max_prompt_embedding_table_size}).")
         else:
             self.p_table = None
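The padding and validation logic added above can be exercised in isolation. The following is a minimal sketch, not the NeMo implementation: `hidden_size`, the two toy tables, and the `max_prompt_embedding_table_size` value are made-up stand-ins for what `_prep_ptuning_table` reads from `self.ptuning_tables` and `self.config`.

    import torch

    # Toy stand-ins (not from the NeMo code): two p-tuning tables with
    # different vtoken counts and a half-precision dtype.
    hidden_size = 8
    tables = [
        {"task_name": "task_a", "table": torch.randn(3, hidden_size).to(torch.float16)},
        {"task_name": "task_b", "table": torch.randn(5, hidden_size).to(torch.float16)},
    ]
    max_prompt_embedding_table_size = 16  # illustrative engine build limit

    # The longest table defines the common (padded) per-task vocab size.
    task_vocab_size = max(t["table"].size(dim=0) for t in tables)

    vtokens_embeddings = []
    task_vtoken_counts = []
    task_ids = {}
    for tid, entry in enumerate(tables):
        original_table = entry["table"]
        vtoken_count = original_table.size(dim=0)
        # Zero-pad to the longest table while keeping the original dtype.
        padded_table = torch.zeros((task_vocab_size, hidden_size), dtype=original_table.dtype)
        padded_table[:vtoken_count, :] = original_table
        vtokens_embeddings.append(padded_table)
        task_vtoken_counts.append(vtoken_count)
        task_ids[entry["task_name"]] = tid

    # Flatten all padded tables into one prompt table, then validate its size.
    p_table = torch.stack(vtokens_embeddings, dim=0).view(-1, hidden_size)
    assert p_table.dtype == torch.float16  # dtype survived the padding
    if p_table.shape[0] > max_prompt_embedding_table_size:
        raise RuntimeError("combined prompt table exceeds max_prompt_embedding_table_size")
    print(task_vtoken_counts)  # [3, 5]: the counts consulted at query time

Because the stack is row-major, task t occupies rows t * task_vocab_size onward, and only its first vtoken_count rows hold real embeddings. That is why the per-task counts must be recorded here: the query path below should prepend only the real vtokens, never the zero-padded rows.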
diff --git a/nemo/export/trt_llm/tensorrt_llm_run.py b/nemo/export/trt_llm/tensorrt_llm_run.py
index cdc0b78d6c184..f870905b81ba1 100644
--- a/nemo/export/trt_llm/tensorrt_llm_run.py
+++ b/nemo/export/trt_llm/tensorrt_llm_run.py
@@ -481,6 +481,47 @@ def forward(
         raise RuntimeError("Internal error")
 
 
+def prepare_input_tensors(
+    input_texts: List[str],
+    host_context: TensorrtLLMHostContext,
+    prompt_table=None,
+    task_vtoken_counts: List[int] = None,
+    task_ids: List[int] = None,
+):
+    tokenizer = host_context.tokenizer
+
+    if host_context.add_bos:
+        bos_tokens = [tokenizer.bos_token_id]
+    else:
+        bos_tokens = []
+
+    input_tokens = [bos_tokens + tokenizer.encode(t) for t in input_texts]
+
+    # If p-tuning is used, we need to prepend vtokens to each input.
+    if prompt_table is not None:
+
+        # Go over the tokenized prompts and prepend vtokens.
+        # The number of vtokens can be different for each task.
+        for prompt_index in range(len(input_texts)):
+            # Look up how many vtokens to prepend for this prompt's task.
+            task_id = task_ids[prompt_index]
+            num_vtokens = task_vtoken_counts[task_id]
+
+            # Build the list of vtoken ids, e.g. 32000, 32001, 32002, ... when vocab_size=32000.
+            # TRT-LLM converts each vtoken into its corresponding embedding row from the prompt table.
+            vocab_size = tokenizer.vocab_size
+            vtokens = list(range(vocab_size, vocab_size + num_vtokens))
+
+            # Concatenate the vtokens with the real tokens.
+            real_tokens = input_tokens[prompt_index]
+            input_tokens[prompt_index] = vtokens + real_tokens
+
+    # Convert the input token lists to tensors.
+    input_tensors = [torch.IntTensor(token_list) for token_list in input_tokens]
+
+    return input_tensors
+
+
 def generate(
     input_texts: List[str],
     max_output_len: int,
@@ -490,6 +531,7 @@ def generate(
     temperature: float = 1.0,
     prompt_table=None,
     task_vocab_size=None,
+    task_vtoken_counts: List[int] = None,
     task_ids: List[int] = None,
     lora_uids: List[str] = None,
     stop_words_list=None,
@@ -505,11 +547,7 @@ def generate(
     Returns a 2D string list with shape [batch_size, num_beams].
     """
     tokenizer = host_context.tokenizer
-
-    if host_context.add_bos:
-        input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts]
-    else:
-        input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts]
+    input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids)
 
     stop_words_list_tensors = None
     if stop_words_list is not None:
@@ -572,6 +610,7 @@ def generate_streaming(
     temperature: float = 1.0,
     prompt_table=None,
     task_vocab_size=None,
+    task_vtoken_counts: List[int] = None,
     task_ids: List[int] = None,
     lora_uids: List[str] = None,
     stop_words_list=None,
@@ -584,11 +623,7 @@ def generate_streaming(
     Returns a 2D string list with shape [batch_size, num_beams].
     """
     tokenizer = host_context.tokenizer
-
-    if host_context.add_bos:
-        input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts]
-    else:
-        input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts]
+    input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids)
 
     batch_size = len(input_texts)
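To see what `prepare_input_tensors` does end to end, here is a rough, runnable sketch. `ToyTokenizer`, `add_bos`, and the batch values are hypothetical stand-ins for `host_context.tokenizer` and its settings; only the vtoken arithmetic mirrors the function above.

    import torch

    # Hypothetical stand-in for host_context.tokenizer, just to keep the
    # sketch self-contained; it maps every character to one token id.
    class ToyTokenizer:
        vocab_size = 32000
        bos_token_id = 1

        def encode(self, text):
            return [ord(c) % self.vocab_size for c in text]

    tokenizer = ToyTokenizer()
    add_bos = True  # stands in for host_context.add_bos

    task_vtoken_counts = [3, 5]       # per-task counts recorded when the tables were added
    task_ids = [1, 0]                 # task chosen for each query in the batch
    input_texts = ["hello", "world"]

    bos_tokens = [tokenizer.bos_token_id] if add_bos else []
    input_tokens = [bos_tokens + tokenizer.encode(t) for t in input_texts]

    for i, task_id in enumerate(task_ids):
        num_vtokens = task_vtoken_counts[task_id]
        # Ids past the real vocab (32000, 32001, ...) are resolved by TRT-LLM
        # to rows of the prompt embedding table for the sequence's task.
        vtokens = list(range(tokenizer.vocab_size, tokenizer.vocab_size + num_vtokens))
        input_tokens[i] = vtokens + input_tokens[i]

    input_tensors = [torch.IntTensor(t) for t in input_tokens]
    print([t.shape[0] for t in input_tensors])  # [11, 9]: 5 or 3 vtokens + BOS + 5 chars

Since vtoken ids start at tokenizer.vocab_size, they can never collide with real tokens, and each query receives exactly task_vtoken_counts[task_id] virtual tokens rather than the padded maximum, which is the point of the fix.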