P-tuning related fixes:
- Remember the vtoken counts for each p-tuning table when the tables are added;
- Prepend the right number of vtokens to each query based on its task_id;
- Preserve the dtype of the p-tuning table when it is padded;
- Validate that all p-tuning tables fit into the max_prompt_embedding_table_size limit.

Signed-off-by: Alexey Panteleev <alpanteleev@nvidia.com>
apanteleev committed Apr 9, 2024
1 parent 3402caa commit 5119eae
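
As context for the diff below, here is a minimal illustrative sketch (toy sizes, not part of this commit) of the bookkeeping the first file now does: each p-tuning table is padded to the longest one while keeping its dtype, the original row counts are remembered per task, and the padded tables are stacked into one combined prompt table.

import torch

# Toy sizes for illustration only; the real values come from the loaded p-tuning checkpoints.
hidden_size = 8
tables = [torch.randn(16, hidden_size).half(),   # task 0: 16 vtokens
          torch.randn(48, hidden_size).half()]   # task 1: 48 vtokens

task_vocab_size = max(t.size(dim=0) for t in tables)   # longest task: 48 rows
task_vtoken_counts = [t.size(dim=0) for t in tables]   # [16, 48], remembered per task

padded = []
for table in tables:
    p = torch.zeros((task_vocab_size, hidden_size), dtype=table.dtype)  # keep the table's dtype
    p[: table.size(dim=0), :] = table
    padded.append(p)

# Combined table handed to the runner: shape [num_tasks * task_vocab_size, hidden_size].
p_table = torch.stack(padded, dim=0).view(-1, hidden_size)
assert p_table.dtype == torch.float16
assert p_table.shape == (2 * 48, hidden_size)

The new validation in _prep_ptuning_table then compares the row count of this combined table against the engine's max_prompt_embedding_table_size and raises early if it does not fit.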
Showing 2 changed files with 64 additions and 13 deletions.
22 changes: 19 additions & 3 deletions nemo/export/tensorrt_llm.py
@@ -94,6 +94,7 @@ def __init__(self, model_dir: str, lora_ckpt_list: List[str] = None, load_model:
self.ptuning_tables = []
self.p_table = None
self.task_vocab_size = 0
+self.task_vtoken_counts = []
self.task_ids = {}

if load_model:
@@ -337,12 +338,15 @@ def forward(
prompt_embeddings_table, prompt_embeddings_checkpoint_path
)
tv_size = prompt_table.size(dim=0)
+task_vtoken_counts = [tv_size]
elif len(self.ptuning_tables) > 0:
prompt_table = self.p_table
tv_size = self.task_vocab_size
+task_vtoken_counts = self.task_vtoken_counts
else:
prompt_table = None
tv_size = None
+task_vtoken_counts = None

if task_ids is None:
assert prompt_table is None, "There is a prompt embedding table and task_ids cannot be None"
@@ -383,6 +387,7 @@ def forward(
temperature=temperature,
prompt_table=prompt_table,
task_vocab_size=tv_size,
+task_vtoken_counts=task_vtoken_counts,
task_ids=input_task_ids,
lora_uids=lora_uids,
stop_words_list=stop_words_list,
@@ -402,6 +407,7 @@ def forward(
temperature=temperature,
prompt_table=prompt_table,
task_vocab_size=tv_size,
+task_vtoken_counts=task_vtoken_counts,
task_ids=input_task_ids,
lora_uids=lora_uids,
stop_words_list=stop_words_list,
@@ -557,19 +563,29 @@ def _prep_ptuning_table(self):
if self.task_vocab_size < pt["table"].size(dim=0):
self.task_vocab_size = pt["table"].size(dim=0)

-# pad tasks to longest task embedding table
+# pad tasks to longest task embedding table, remember the original task vtoken counts
vtokens_embeddings = []
+self.task_vtoken_counts = []
self.task_ids = {}
tid = 0
for i, ptuning_table in enumerate(self.ptuning_tables):
-padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size))
-padded_table[: ptuning_table["table"].size(dim=0), :] = ptuning_table["table"]
+original_table = ptuning_table["table"]
+vtoken_count = original_table.size(dim=0)
+padded_table = torch.zeros((self.task_vocab_size, self.get_hidden_size), dtype=original_table.dtype)
+padded_table[:vtoken_count, :] = original_table
vtokens_embeddings.append(padded_table)
self.task_ids[ptuning_table["task_name"]] = tid
+self.task_vtoken_counts.append(vtoken_count)
tid = tid + 1

if len(vtokens_embeddings) > 0:
self.p_table = torch.stack(vtokens_embeddings, dim=0).view(-1, self.get_hidden_size)

+max_prompt_embedding_table_size = self.config['builder_config']['max_prompt_embedding_table_size']
+actual_prompt_table_size = self.p_table.shape[0]

+if actual_prompt_table_size > max_prompt_embedding_table_size:
+raise Exception(f"The size of the combined prompt embedding table ({actual_prompt_table_size}) is greater than max_prompt_embedding_table_size ({max_prompt_embedding_table_size}).")
else:
self.p_table = None

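Before the runner-side changes in tensorrt_llm_run.py below, a small sketch (assumed vocabulary size and counts, not taken from the commit) of what prepending task-specific vtokens means: token ids at or above the tokenizer vocabulary size act as references into the prompt embedding table, and each query now receives exactly as many of them as its task's table has rows.

vocab_size = 32000                    # assumed tokenizer vocabulary size
task_vtoken_counts = [16, 64]         # remembered when the p-tuning tables were added
task_ids = [1, 0]                     # task id per query in the batch
queries = [[5, 901, 42], [7, 13]]     # already-tokenized queries (made-up ids)

batched = []
for tokens, task_id in zip(queries, task_ids):
    num_vtokens = task_vtoken_counts[task_id]
    vtokens = list(range(vocab_size, vocab_size + num_vtokens))  # 32000, 32001, ...
    batched.append(vtokens + tokens)  # vtokens first, then the real prompt tokens

assert len(batched[0]) == 64 + 3
assert len(batched[1]) == 16 + 2
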
55 changes: 45 additions & 10 deletions nemo/export/trt_llm/tensorrt_llm_run.py
@@ -481,6 +481,47 @@ def forward(
raise RuntimeError("Internal error")


def prepare_input_tensors(
input_texts: List[str],
host_context: TensorrtLLMHostContext,
prompt_table=None,
task_vtoken_counts: List[int]=None,
task_ids: List[int]=None
):
tokenizer = host_context.tokenizer

if host_context.add_bos:
bos_tokens = [tokenizer.bos_token_id]
else:
bos_tokens = []

input_tokens = [bos_tokens + tokenizer.encode(t) for t in input_texts]

# If p-tuning is used, we need to prepend vtokens to each input.
if prompt_table is not None:

# Go over the tokenized prompts and prepend vtokens.
# The number of vtokens could be different for each task.
for prompt_index in range(len(input_texts)):
# Find out the number of vtokens to generate
task_id = task_ids[prompt_index]
num_vtokens = task_vtoken_counts[task_id]

# Create a tensor with vtokens, e.g. 32000, 32001, 32002... when vocab_size=32000
# TRT-LLM will convert each vtoken into its corresponding embedding row from the prompt table.
vocab_size = tokenizer.vocab_size
vtokens = list(range(vocab_size, vocab_size + num_vtokens))

# Concatenate the vtokens with the real tokens
real_tokens = input_tokens[prompt_index]
input_tokens[prompt_index] = vtokens + real_tokens

# Convert input token lists to tensors
input_tensors = [torch.IntTensor(token_list) for token_list in input_tokens]

return input_tensors


def generate(
input_texts: List[str],
max_output_len: int,
@@ -490,6 +531,7 @@ def generate(
temperature: float = 1.0,
prompt_table=None,
task_vocab_size=None,
+task_vtoken_counts: List[int]=None,
task_ids: List[int] = None,
lora_uids: List[str] = None,
stop_words_list=None,
@@ -505,11 +547,7 @@
Returns a 2D string list with shape [batch_size, num_beams].
"""
tokenizer = host_context.tokenizer

-if host_context.add_bos:
-input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts]
-else:
-input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts]
+input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids)

stop_words_list_tensors = None
if stop_words_list is not None:
@@ -572,6 +610,7 @@ def generate_streaming(
temperature: float = 1.0,
prompt_table=None,
task_vocab_size=None,
+task_vtoken_counts: List[int]=None,
task_ids: List[int] = None,
lora_uids: List[str] = None,
stop_words_list=None,
@@ -584,11 +623,7 @@
Returns a 2D string list with shape [batch_size, num_beams].
"""
tokenizer = host_context.tokenizer

-if host_context.add_bos:
-input_tensors = [torch.IntTensor([tokenizer.bos_token_id] + tokenizer.encode(t)) for t in input_texts]
-else:
-input_tensors = [torch.IntTensor(tokenizer.encode(t)) for t in input_texts]
+input_tensors = prepare_input_tensors(input_texts, host_context, prompt_table, task_vtoken_counts, task_ids)

batch_size = len(input_texts)

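For reference, a rough usage sketch of the new prepare_input_tensors helper in isolation. DummyTokenizer, DummyHostContext, the token ids, and the vtoken counts are invented stand-ins for this example; in the real flow the helper receives a TensorrtLLMHostContext holding the model's actual tokenizer.

import torch

from nemo.export.trt_llm.tensorrt_llm_run import prepare_input_tensors


class DummyTokenizer:
    # Stand-in exposing only the attributes prepare_input_tensors touches.
    vocab_size = 32000
    bos_token_id = 1

    def encode(self, text):
        # Fake tokenization: one made-up id per whitespace-separated word.
        return [100 + i for i, _ in enumerate(text.split())]


class DummyHostContext:
    # Stand-in for TensorrtLLMHostContext with just the fields used here.
    tokenizer = DummyTokenizer()
    add_bos = True


prompt_table = torch.zeros(80, 8)   # padded and stacked table: 2 tasks x 40 rows (toy sizes)
task_vtoken_counts = [16, 40]       # original row counts remembered per task
task_ids = [1, 0]                   # one task id per query

tensors = prepare_input_tensors(
    ["summarize this report", "translate to french"],
    DummyHostContext(),
    prompt_table,
    task_vtoken_counts,
    task_ids,
)

# The first query uses task 1, so it starts with 40 vtoken ids (32000..32039),
# followed by the BOS id and the fake word ids.
assert tensors[0].tolist()[:2] == [32000, 32001]
assert len(tensors[0]) == 40 + 1 + 3
assert len(tensors[1]) == 16 + 1 + 3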
