diff --git a/nemo/collections/nlp/modules/common/megatron_web_server.py b/nemo/collections/nlp/modules/common/megatron_web_server.py index 648bca024ba0..7c04ef201927 100644 --- a/nemo/collections/nlp/modules/common/megatron_web_server.py +++ b/nemo/collections/nlp/modules/common/megatron_web_server.py @@ -90,6 +90,9 @@ def get_generation(prompt, greedy, add_BOS, token_to_gen, min_tokens, temp, top_ response = text_generation(data, port=port) sentences = response['sentences'] bot_message = sentences[0] + if bot_message.find(' token + prompt = prompt.replace('', '').replace('', '').replace('', '') bot_message = bot_message[len(prompt) :] return bot_message diff --git a/nemo/collections/nlp/modules/common/text_generation_strategy.py b/nemo/collections/nlp/modules/common/text_generation_strategy.py index 8608c0c9a680..573bdc80735e 100644 --- a/nemo/collections/nlp/modules/common/text_generation_strategy.py +++ b/nemo/collections/nlp/modules/common/text_generation_strategy.py @@ -153,15 +153,19 @@ def end_of_generation_condition( else: tokenizer = self.model.tokenizer conditions = [] + end_tokens = set() + end_tokens.add(eod_id) + for end_string in end_strings: + ids_1 = tokenizer.text_to_ids(f'{end_string}') + ids_2 = tokenizer.text_to_ids('') + if len(ids_1) <= len(ids_2): + continue + token_id = ids_1[len(ids_2) :][0] + end_tokens.add(token_id) for p, token_item in zip(prev, tokens): text = tokenizer.ids_to_text(token_item.tolist()) conditions.append( - any( - [ - p.item() == eod_id if end_string == END_OF_SEQ else text.endswith(end_string) - for end_string in end_strings - ] - ) + any([text.endswith(end_string) for end_string in end_strings] + [p.item() in end_tokens]) ) return torch.tensor(conditions, dtype=torch.bool, device=tokens.device)