Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for batched input_pos to model #1700

Merged
merged 3 commits into from
Aug 29, 2024
Merged

add support for batched input_pos to model #1700

merged 3 commits into from
Aug 29, 2024

Conversation

t-vi
Copy link
Contributor

@t-vi t-vi commented Aug 29, 2024

This allows to feed both x, input_pos of shape (batch, T) to the model.
I included a test that compares to running the batch items separately.

To add some commentary around the for loop:

  • it's not terribly cool to use the for loop, but to reduce to index_copy_, we would need to fold the batch and indexed dimensions, which we cannot for KV-Cache,
  • if we can use cudagraphs, it might not be too bad.

@t-vi t-vi force-pushed the tom/simple_batch branch from 1abfa02 to 6fb8465 Compare August 29, 2024 10:59
@t-vi t-vi merged commit fdf6a12 into main Aug 29, 2024
9 checks passed
@t-vi t-vi deleted the tom/simple_batch branch August 29, 2024 16:51
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants