Add left aligned cache support. #133
… no return issue;
# fill mask first
mask = decode_state.mask.at[:, decode_state.current_position].set(0)
if self.env.ring_buffer:
  input_indexes = jnp.full((1,), pos)
What if we change current_position to shape [batch_size, 1]? Could we then use the same mask logic for both the ring buffer and non ring buffer cases?
Not really. In the non ring buffer case, a single current position value indicates the decoding position for every batch. In the ring buffer case, every batch has a different position, so we cannot use current_position here.
Yes. I mean that if we change current_position to [batch_size, 1], different slots can have different current_position values. In the non ring buffer case, current_position should be the same as input_pos.
It will cause a performance regression. Please check jax_experiments.py/test7: inserting with batching plus a position array takes much longer, roughly 4x to 5x.
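For readers following the thread, here is a minimal sketch of the two mask updates under discussion: one shared scalar position for all rows versus one position per batch slot. The sizes, the position values, and the additive mask convention (0 means visible, -inf means hidden) are assumptions for illustration and are not taken from the PR; the sketch does not measure the performance difference mentioned above.

```python
import jax.numpy as jnp

# Hypothetical sizes, not taken from the PR.
batch_size, cache_len = 3, 8

# Assumed convention: additive mask, -inf hides a cache slot, 0 exposes it.
mask = jnp.full((batch_size, cache_len), -jnp.inf)

# Case 1: one scalar position shared by every row of the batch.
current_position = 5
shared = mask.at[:, current_position].set(0)

# Case 2: each batch slot has its own position, shape [batch_size].
input_pos = jnp.array([2, 5, 7])
rows = jnp.arange(batch_size)
per_row = mask.at[rows, input_pos].set(0)

print(shared)
print(per_row)
```

The first form writes a single column, while the second is a scatter with one index pair per row.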
input_pos = jnp.where(
    decode_state.input_pos == 0,
    0,
    decode_state.input_pos + 1 % self.env.cache_len,
In the non ring buffer case, can input_pos be larger than the cache length?
No.
If not, I feel we don't need the % since it never reaches the cache length.
We don't have control over this. Generate() will keep running if no new prefill results are inserted.
Thanks for sharing the details!
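One detail in the quoted line `decode_state.input_pos + 1 % self.env.cache_len` may be worth a second look: in Python, `%` binds tighter than `+`, so the expression parses as `input_pos + (1 % cache_len)` and does not wrap at the cache length. The sketch below uses plain integers with hypothetical values; whether the parenthesized form is what the PR intends is an assumption.

```python
cache_len = 8  # hypothetical cache length
input_pos = 7  # last valid index in a cache of that length

# As written in the quoted diff: % is applied to 1 first, then added.
as_written = input_pos + 1 % cache_len

# With explicit parentheses the position wraps back to the start.
wrapped = (input_pos + 1) % cache_len

print(as_written, wrapped)
```

The same precedence applies when `input_pos` is a JAX array, since the operators are parsed before the array types come into play.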