diff --git a/caikit_nlp/toolkit/text_generation/model_run_utils.py b/caikit_nlp/toolkit/text_generation/model_run_utils.py
index b4a7dc77..eb45064f 100644
--- a/caikit_nlp/toolkit/text_generation/model_run_utils.py
+++ b/caikit_nlp/toolkit/text_generation/model_run_utils.py
@@ -19,7 +19,7 @@
 
 # Third Party
 from peft.peft_model import PeftModel
-from transformers import StoppingCriteria, TextStreamer
+from transformers import AutoModel, AutoTokenizer, StoppingCriteria, TextStreamer
 import numpy as np
 import torch
 
@@ -131,10 +131,10 @@ def __iter__(self):
 
 
 def generate_text_func(
-    model,
-    tokenizer,
+    model: "Union[PeftModel, AutoModel]",
+    tokenizer: "AutoTokenizer",
     producer_id: ProducerId,
-    eos_token: str,
+    eos_token: Optional[str],
     text: str,
     max_new_tokens: Optional[int] = 20,
     min_new_tokens: Optional[int] = 0,
@@ -234,12 +234,20 @@ def generate_text_func(
         for g in generate_ids
     ]
 
-    if generate_ids[0][-1].item() == eos_token:
+    if (eos_token and tokenizer.decode(generate_ids[0, -1].item()) == eos_token) or (
+        generate_ids[0, -1] == tokenizer.eos_token_id
+    ):
         finish_reason = "EOS_TOKEN"
-    elif generate_ids.size(1) - 1 == max_new_tokens:
-        finish_reason = "MAX_TOKENS"
+    elif ("stopping_criteria" in gen_optional_params) and (
+        gen_optional_params["stopping_criteria"](
+            generate_ids,
+            None,  # scores, unused by SequenceStoppingCriteria
+        )
+    ):
+        finish_reason = "STOP_SEQUENCE"
     else:
-        finish_reason = "OTHER"
+        finish_reason = "MAX_TOKENS"
+
     return GeneratedTextResult(
         generated_tokens=token_count,
         generated_text=preds[0],
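
For reference, a minimal, self-contained sketch (not part of the patch) of the finish-reason precedence the last hunk introduces: EOS token first, then a configured stop sequence, then token-budget exhaustion as the fallback. Here `resolve_finish_reason`, `decode`, and `stop_sequence_hit` are hypothetical stand-ins for `tokenizer.decode`, `tokenizer.eos_token_id`, and the `SequenceStoppingCriteria` stored in `gen_optional_params` in the actual code.

# Standard
from typing import Callable, Optional


def resolve_finish_reason(
    last_token_id: int,
    eos_token_id: int,
    eos_token: Optional[str],
    decode: Callable[[int], str],
    stop_sequence_hit: Optional[Callable[[], bool]] = None,
) -> str:
    # EOS wins if the caller-supplied eos_token string matches the decoded
    # last token, or if the raw id equals the tokenizer's EOS id.
    if (eos_token and decode(last_token_id) == eos_token) or (
        last_token_id == eos_token_id
    ):
        return "EOS_TOKEN"
    # Otherwise, check whether a configured stop sequence fired
    # (SequenceStoppingCriteria in the actual patch).
    if stop_sequence_hit is not None and stop_sequence_hit():
        return "STOP_SEQUENCE"
    # Generation stopped for no other reason: the token budget ran out.
    return "MAX_TOKENS"


# Toy usage with stand-in values:
assert resolve_finish_reason(2, 2, None, str) == "EOS_TOKEN"
assert resolve_finish_reason(7, 2, None, str, lambda: True) == "STOP_SEQUENCE"
assert resolve_finish_reason(7, 2, None, str, lambda: False) == "MAX_TOKENS"

Note the design choice this encodes: the old code reported "OTHER" when neither EOS nor the exact max-token count matched, whereas the patched logic treats MAX_TOKENS as the default, since a generation that hit neither EOS nor a stop sequence can only have ended by exhausting its budget.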