-
Notifications
You must be signed in to change notification settings - Fork 558
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add Beam Search sampler #618
Conversation
137740a
to
a06f4b5
Compare
16dc89e
to
f4c1eeb
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Included a few questions and a documentation fix.
Nice to see how straightforward a sampler implementation can be when accompanied with a well designed SequenceGenerator
.
docs/reference/samplers.md
Outdated
from outlines import models, generate, samplers | ||
|
||
|
||
model = models.transformers("mistralai/Mistral-7B-0.1") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Typo, it should be mistralai/Mistral-7B-v0.1
However I think we should be recommending mistralai/Mistral-7B-Instruct-v0.2
I overlooked something: some Beam Search implementations clone each beam K times and then down sample them to preserve some kind of sample diversity, see this implementation for instance. This can be done in another PR. |
d23d3fd
to
c269c0b
Compare
Closes #258.
In order to implement Beam Search I had to make a few changes to the samplers, they now:
token_ids
's ancestors, i.e. the sequence to which they need to be added. While trivial for greedy and multinomial sampling, this is what will allow us to update "beams" in beam search.token_ids
,attention_masks
,kv_cache
,fsm
andfsm_states
are now updated using theancestors
information.GenerationState
contains the sequence's weights and ancestors so we can inspect the sampling process.I also simplified the
get_generated_token_ids
method ofSequenceGenerator
. We should soon do a cleaning pass on this class: remove deprecated init arguments, make some methods independant functions and test them.