Add left aligned cache support. #133

Merged: 6 commits merged into main on Jun 28, 2024
Conversation

wang2yn84 (Collaborator):

No description provided.

@wang2yn84 wang2yn84 requested review from qihqi and FanhaiLu1 June 21, 2024 23:04
The first review thread is on this excerpt:

```python
# fill mask first
mask = decode_state.mask.at[:, decode_state.current_position].set(0)
if self.env.ring_buffer:
  input_indexes = jnp.full((1,), pos)
```
Collaborator:
What if we change current_position to [batch_size, 1]? Could we then use the same masking logic for both the ring_buffer and non ring buffer cases?

Collaborator Author (wang2yn84):
Not really. In the non ring buffer case, there is a single current_position value that indicates the decoding position for all the batch entries. But for the ring buffer, every batch entry has a different position, so we cannot use current_position here.
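
For illustration only (this is not code from the PR): a minimal sketch of the difference between the two masking paths. The names batch_size, cache_len, and per_batch_pos, and the concrete shapes, are assumptions.

```python
import jax.numpy as jnp

batch_size, cache_len = 4, 8

# Additive attention mask: -inf everywhere, 0 at positions that may be attended.
mask = jnp.full((batch_size, cache_len), float("-inf"))

# Left-aligned (non ring buffer): one scalar position shared by every batch
# entry, so a single column update is enough.
current_position = 3
mask_left_aligned = mask.at[:, current_position].set(0)

# Ring buffer: each batch entry decodes at its own position, so the update has
# to be indexed per row rather than by one shared scalar.
per_batch_pos = jnp.array([3, 5, 0, 7])
mask_ring = mask.at[jnp.arange(batch_size), per_batch_pos].set(0)
```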

Collaborator:
Yes. I mean that if we change current_position to [batch_size, 1], each slot can have its own current_position. In the non ring buffer case, current_position should then be the same as input_pos.

Collaborator Author (wang2yn84):
It would cause a performance regression. Please check test7 in jax_experiments.py: inserting with batching plus a position array takes much longer, roughly 4x to 5x.
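
jax_experiments.py is not shown in this thread, so the following is only a rough, hypothetical stand-in for the kind of comparison test7 makes; every name and shape here is an assumption.

```python
import jax
import jax.numpy as jnp

batch, heads, cache_len, dim = 8, 8, 1024, 128
cache = jnp.zeros((batch, heads, cache_len, dim))
new_entry = jnp.ones((batch, heads, 1, dim))

@jax.jit
def insert_shared_pos(cache, new_entry, pos):
  # One position shared by all batch entries: a single dynamic_update_slice.
  return jax.lax.dynamic_update_slice(cache, new_entry, (0, 0, pos, 0))

@jax.jit
def insert_per_batch_pos(cache, new_entry, pos_array):
  # A different position per batch entry: a scatter driven by a position array.
  return cache.at[jnp.arange(batch), :, pos_array, :].set(new_entry[:, :, 0, :])
```

Timing both paths (for example by calling block_until_ready() on the results) is the kind of comparison test7 makes; the per-batch position-array variant is the one reported above as roughly 4x to 5x slower.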

The next review thread is on this excerpt:

```python
input_pos = jnp.where(
    decode_state.input_pos == 0,
    0,
    decode_state.input_pos + 1 % self.env.cache_len,
)
```
Collaborator:
In the non ring buffer case, can input_pos be larger than the cache length?

Collaborator Author (wang2yn84):
No.

Collaborator:
If not, I feel we don't need the % since input_pos never reaches the cache length.

Collaborator Author (wang2yn84):
We don't have control over this: generate() will keep running even if no new prefill results are inserted, so input_pos can keep advancing.
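
As a toy illustration (not the PR's code) of why the wrap is kept: if generate() keeps stepping a slot without a fresh prefill insert, the position would otherwise run past cache_len. The names and the "0 means empty slot" convention are assumptions, and the increment is parenthesized here for readability.

```python
import jax.numpy as jnp

cache_len = 4
# Per-slot positions; 0 marks a slot with no active request (assumed convention).
input_pos = jnp.array([0, 1, 2, 3])

# Empty slots stay at 0; active slots advance and wrap at cache_len.
next_pos = jnp.where(input_pos == 0, 0, (input_pos + 1) % cache_len)
# next_pos -> [0, 2, 3, 0]: the last slot wraps instead of running past cache_len.
```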

Collaborator:
Thanks for sharing the details!

jetstream_pt/engine.py (review thread resolved)
@qihqi qihqi merged commit 175d956 into main Jun 28, 2024
4 checks passed