Skip to content

Commit

Permalink
fix the mpt chatbot (#6957) (#6968)
Browse files: browse the repository at this point in the history.
Signed-off-by: Yi Dong <yidong@nvidia.com>
Co-authored-by: Yi Dong <43824965+yidong72@users.noreply.github.com>
Signed-off-by: Gerald Shen <geshen@nvidia.com>
  • Loading branch information
2 people authored and gshennvm committed Jul 12, 2023
1 parent acbfe08 commit 84c468d
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
3 changes: 3 additions & 0 deletions nemo/collections/nlp/modules/common/megatron_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,9 @@ def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_
response = text_generation(data, port=port)
sentences = response['sentences']
bot_message = sentences[0]
if bot_message.find('<extra_id_0') < 0:
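# HACK: HuggingFace's tokenizer strips the <extra_id_x> tokens out of the returned text, so when
# they are missing from the response, remove them from the prompt too before its length is used
# to slice the prompt off the front of the response.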
prompt = prompt.replace('<extra_id_0>', '').replace('<extra_id_1>', '').replace('<extra_id_2>', '')
bot_message = bot_message[len(prompt) :]
return bot_message

Expand Down
16 changes: 10 additions & 6 deletions nemo/collections/nlp/modules/common/text_generation_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,15 +153,19 @@ def end_of_generation_condition(
else:
tokenizer = self.model.tokenizer
conditions = []
end_tokens = set()
end_tokens.add(eod_id)
for end_string in end_strings:
ids_1 = tokenizer.text_to_ids(f'<extra_id_1>{end_string}')
ids_2 = tokenizer.text_to_ids('<extra_id_1>')
if len(ids_1) <= len(ids_2):
continue
token_id = ids_1[len(ids_2) :][0]
end_tokens.add(token_id)
for p, token_item in zip(prev, tokens):
text = tokenizer.ids_to_text(token_item.tolist())
conditions.append(
any(
[
p.item() == eod_id if end_string == END_OF_SEQ else text.endswith(end_string)
for end_string in end_strings
]
)
any([text.endswith(end_string) for end_string in end_strings] + [p.item() in end_tokens])
)
return torch.tensor(conditions, dtype=torch.bool, device=tokens.device)

Expand Down

0 comments on commit 84c468d

Please sign in to comment.