return correct finish reasons in generate_text_func #210
cba775e to 7e31e6a
Tagging @Xaenalt for visibility
A small sidenote here: caikit-nlp/caikit_nlp/toolkit/text_generation/model_run_utils.py Lines 236 to 241 in ef283e4
If Would it make sense to add it there?
I think it does make sense to add it there, especially if it's being referenced here. @gabe-l-hart is this intended?
Hm, that does seem troubling. @gkumbhat I'll defer to you on this, but I think we probably do need to add
Just left a small comment, but otherwise looks good. Thanks for catching this and contributing the fix.
@@ -36,6 +36,11 @@
# Local
from ...data_model import ExponentialDecayLengthPenalty

if TYPE_CHECKING:
    # Third Party
    from transformers import AutoModel, AutoTokenizer
Let's include this as a dependency, since all these modules need to work with transformers anyway.
Not sure what you mean here: get rid of if TYPE_CHECKING? transformers is already a dependency in pyproject.toml
@dtrifiro Yep, that's exactly what I was suggesting.
Done
Actually
Thinking about it some more, the only other finish reason that should be returned by this function is I think the overall objective would be to get rid of EDIT: I ended up adding some code to handle the
bad7bc6 to 4d15fda
@gkumbhat Another fix which was required in order for CI to pass with the new exception instead of
@@ -237,7 +237,9 @@ def generate_text_func(
    generate_ids[0, -1] == tokenizer.eos_token_id
):
    finish_reason = "EOS_TOKEN"
elif generate_ids.size(1) - 1 == max_new_tokens:
elif (generate_ids.size(1) - 1 == max_new_tokens) or (
    generate_ids.size(1) - inputs["input_ids"].size(1) == max_new_tokens
Hmm, technically the input will only be in the output for causal-lm type models, so there can be some side effects in this calculation 🤔. Maybe we would need to check whether the input is actually in the output and then calculate max_new_tokens accordingly.
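The check suggested above could look roughly like the following. This is only a sketch on plain Python lists rather than tensors; the helper name `count_new_tokens` is hypothetical and not part of caikit-nlp, and it deliberately ignores any leading special token (such as a decoder start token) that an encoder-decoder model may emit.

```python
from typing import List


def count_new_tokens(input_ids: List[int], output_ids: List[int]) -> int:
    """Return how many tokens in output_ids were newly generated.

    Assumption: if the output begins with the full prompt (causal-LM style),
    the prompt is excluded from the count; otherwise (encoder-decoder style)
    every output token is counted as new.
    """
    prompt_len = len(input_ids)
    if prompt_len and output_ids[:prompt_len] == input_ids:
        return len(output_ids) - prompt_len
    return len(output_ids)


prompt = [11, 12, 13]
causal_output = [11, 12, 13, 40, 41]  # prompt echoed, then two new tokens
seq2seq_output = [0, 40, 41, 42]      # prompt not echoed

print(count_new_tokens(prompt, causal_output))
print(count_new_tokens(prompt, seq2seq_output))
```

The result could then be compared against max_new_tokens instead of relying on a fixed offset.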
@gkumbhat I agree. I gave it some thought and ended up with the logic in the latest push. The idea is the following:
- Check if the last token is eos_token; if so, finish_reason = "EOS_TOKEN"
- Check if the stop sequence is in the generated tokens; if so, finish_reason = "STOP_SEQUENCE"
- If neither of the above conditions is true, then finish_reason = "MAX_TOKENS"
My reasoning for always returning MAX_TOKENS comes from looking at caikit.interfaces.nlp.data_model.text_generation.FinishReason. Here's the code with some comments I added:
@dataobject(package=NLP_PACKAGE)
class FinishReason(Enum):
    NOT_FINISHED = 0  # should only be returned by streaming implementations
    MAX_TOKENS = 1
    EOS_TOKEN = 2  # matches rule #1 above
    CANCELLED = 3  # should not be set by generate_text_func
    TIME_LIMIT = 4  # should not be set by generate_text_func
    STOP_SEQUENCE = 5  # matches rule #2 above
    TOKEN_LIMIT = 6  # not sure about this
    ERROR = 7  # fairly generic,
I'm unsure about TOKEN_LIMIT (how is this different from MAX_TOKENS?), but I think none of the other enum values apply to generate_text_func.
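For reference, the three rules listed above can be sketched as a small standalone function. This is not the code from the PR: the `FinishReason` class here is a local, illustrative subset of the enum quoted above, and `pick_finish_reason` and its parameters are hypothetical names that operate on plain Python values instead of tensors.

```python
from enum import Enum
from typing import List, Optional, Sequence


class FinishReason(Enum):
    # Illustrative subset; the values mirror the enum quoted above.
    MAX_TOKENS = 1
    EOS_TOKEN = 2
    STOP_SEQUENCE = 5


def pick_finish_reason(
    generated_text: str,
    generated_ids: Sequence[int],
    eos_token_id: Optional[int],
    stop_sequences: Optional[List[str]] = None,
) -> FinishReason:
    # Rule 1: the last generated token is the EOS token.
    if generated_ids and eos_token_id is not None and generated_ids[-1] == eos_token_id:
        return FinishReason.EOS_TOKEN
    # Rule 2: one of the (non-empty) stop sequences appears in the generated text.
    if any(seq and seq in generated_text for seq in (stop_sequences or [])):
        return FinishReason.STOP_SEQUENCE
    # Rule 3: otherwise assume generation ran out of tokens.
    return FinishReason.MAX_TOKENS


print(pick_finish_reason("Hello there", [5, 6, 2], eos_token_id=2))
print(pick_finish_reason("Hello STOP", [5, 6, 7], eos_token_id=2, stop_sequences=["STOP"]))
print(pick_finish_reason("Hello", [5, 6, 7], eos_token_id=2, stop_sequences=["STOP"]))
```

Checking EOS before stop sequences follows the order of the list above; the fallback to MAX_TOKENS is the assumption being discussed in this thread.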
TOKEN_LIMIT refers to the maximum number of tokens defined by the model, whereas MAX_TOKENS refers to the maximum number defined by the user. So one can reach TOKEN_LIMIT before MAX_TOKENS.
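To make the distinction concrete, here is a small arithmetic sketch. It assumes the model's limit covers prompt plus generated tokens, which may not hold for every architecture; `first_limit_reached` and `model_max_length` are hypothetical names, and how to obtain the model's limit in practice is left open.

```python
def first_limit_reached(prompt_len: int, max_new_tokens: int, model_max_length: int) -> str:
    """Return which limit generation would hit first if no EOS or stop sequence occurs.

    Assumption: the model's limit counts prompt tokens plus generated tokens.
    """
    room_left = max(model_max_length - prompt_len, 0)
    # If the model has less room than the user asked for, its limit is hit first.
    return "TOKEN_LIMIT" if room_left < max_new_tokens else "MAX_TOKENS"


# Short prompt: the user-defined cap of 20 new tokens is reached first.
print(first_limit_reached(prompt_len=10, max_new_tokens=20, model_max_length=2048))
# Long prompt: only 8 tokens of room remain, so the model's limit is reached first.
print(first_limit_reached(prompt_len=2040, max_new_tokens=20, model_max_length=2048))
```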
@gkumbhat OK, so that means that we also need to check whether we reached the token limit for the model, although I'm not sure how we can get that information. Can we open an issue for adding support for TOKEN_LIMIT so that we can work on that in another PR?
Created #253
Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
c35fd62 to 86271dc
…text_func "OTHER" is an invalid value for caikit.interfaces.nlp.data_model.text_generation.FinishReason, resulting in failed serialization of responses when querying the text generation endpoint. For `generate_text_func`, it is reasonable to assume that if the finish reason is not `EOS_TOKEN` or `STOP_SEQUENCE`, it must be `MAX_TOKENS`. Signed-off-by: Daniele Trifirò <dtrifiro@redhat.com>
86271dc to 1b45817
LGTM
Hey @gkumbhat, is there anything I can do to get this merged?
The finish reason in generate_text_func is currently broken (it always returns OTHER). The reason is that a comparison is being made between a tensor element (int) and a string (eos_token) without converting it using the tokenizer.
This PR fixes the tokenization issue and adds a few type hints.
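As a minimal illustration of the mismatch described above (not the original reproduction script), the following uses a hypothetical `StubTokenizer` in place of a real tokenizer and a plain int in place of the tensor element. The token string and id are made-up values.

```python
class StubTokenizer:
    """Hypothetical stand-in exposing only the two attributes discussed above."""

    eos_token = "</s>"   # string form of the EOS token (made-up value)
    eos_token_id = 2     # integer id of the EOS token (made-up value)


tokenizer = StubTokenizer()
last_token_id = 2  # plays the role of generate_ids[0, -1]

# Comparing an int with a string is never equal, so EOS is never detected.
buggy_match = last_token_id == tokenizer.eos_token
# Comparing against the token id works as intended.
fixed_match = last_token_id == tokenizer.eos_token_id

print(buggy_match, fixed_match)
```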
Full reproduction script below
Reproduction script