Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Model][Jamba] Mamba cache single buffer #6739

Merged

Conversation

mzusman
Copy link
Contributor

@mzusman mzusman commented Jul 24, 2024

By carefully allocating the Mamba cache at the first "n" slots in the mamba cache before FWD pass ,
We can now remove the redundant CG Mamba buffer.
This PR saves memory, simplifies the Jamba inner state management code and accelerates latency (by removing redundant data copies).

This PR is also applicable to #6484 @tlrmchlsmth .

Copy link

👋 Hi! Thank you for contributing to the vLLM project.
Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run fastcheck CI which consists a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of default ones by unblocking the steps in your fast-check build on Buildkite UI.

Once the PR is approved and ready to go, please make sure to run full CI as it is required to merge (or just use auto-merge).

To run full CI, you can do one of these:

  • Comment /ready on the PR
  • Add ready label to the PR
  • Enable auto-merge.

🚀

@mzusman
Copy link
Contributor Author

mzusman commented Jul 24, 2024

/ready

@github-actions github-actions bot added the ready ONLY add when PR is ready to merge/full CI is needed label Jul 24, 2024
@mzusman mzusman force-pushed the mamba_cache_single_buffer_upstream branch from dc9bf07 to d57ccb6 Compare July 28, 2024 07:41
@mzusman
Copy link
Contributor Author

mzusman commented Jul 28, 2024

PR is ready, CI failures are not related to this PR.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just starting to read through now. At a high level, the approach makes sense to me. Do you anticipate any cases where we'll end up shuffling a lot of data in and out of the first N slots?

And do you have any end-to-end performance numbers you can share?

You mentioned an added test for parallel sampling, but it's not present in this PR. Did you mean to remove it? I noticed that the added test was there previously

@mzusman
Copy link
Contributor Author

mzusman commented Aug 4, 2024

Thank you for the review! Sorry for the long delay.

Most shuffling occurs during the transition from prefill steps to decoding steps. However, shuffling between sequential decoding steps ( which populate the majority of the steps distribution under a regular load ) doesn't happen very often since the cache is already in place ( previous implementation copied the mamba cache from buffer to buffer in each and every decode step).

And regarding end-to-end perf - so yeah, we benchmark prefill and decoding forward passes independently. We've seen 1-2 ms speed up in decoding, and no change in prefill steps. However, the major purpose of this PR is to reduce the memory usage.

image

Red line is the previous implementation, blue line is this PR implementation.

RE - Parallel sampling test. Yeah, I've intended to add it but the tiny Jamba model we use for unittest behaves differently on different devices. So I've left it out for now until we have a trained tiny model for tests.

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I left a few comments in-line. Generally, I think the approach makes sense and don't see any specific problems, but I think we should get somebody working on multi-step scheduling to review in case any conflicts might arise there. @alexm-neuralmagic could you look into that and suggest other reviewers as well?

I think the functions that manage the mamba cache might be better organized if they were factored out and encapsulated in their own class. I was thinking we could try to make it behave similarly to the BlockManager in terms of interface. Goal would be to incrementally make the mamba cache fit into vLLM's native systems. Doesn't have to be in this PR but curious to hear your thoughts on this.

One last question: A lot of this would be simpler if the two mamba cache update functions took a list of indices rather than requiring contiguous tensors. Have you looked into this at all? To me it looks like it wouldn't be too technically difficult to do, but would require a pair of PRs on https://github.com/Dao-AILab/causal-conv1d and https://github.com/state-spaces/mamba. Might be worth it just to avoid the state management.

vllm/model_executor/models/jamba.py Show resolved Hide resolved
vllm/model_executor/models/jamba.py Outdated Show resolved Hide resolved
vllm/model_executor/models/jamba.py Outdated Show resolved Hide resolved
@mzusman
Copy link
Contributor Author

mzusman commented Aug 8, 2024

I left a few comments in-line. Generally, I think the approach makes sense and don't see any specific problems, but I think we should get somebody working on multi-step scheduling to review in case any conflicts might arise there. @alexm-neuralmagic could you look into that and suggest other reviewers as well?

I think the functions that manage the mamba cache might be better organized if they were factored out and encapsulated in their own class. I was thinking we could try to make it behave similarly to the BlockManager in terms of interface. Goal would be to incrementally make the mamba cache fit into vLLM's native systems. Doesn't have to be in this PR but curious to hear your thoughts on this.

One last question: A lot of this would be simpler if the two mamba cache update functions took a list of indices rather than requiring contiguous tensors. Have you looked into this at all? To me it looks like it wouldn't be too technically difficult to do, but would require a pair of PRs on https://github.com/Dao-AILab/causal-conv1d and https://github.com/state-spaces/mamba. Might be worth it just to avoid the state management.

Thank you for the review! @alexm-neuralmagic would love to hear your opinion.
RE I totally agree that an ideal solution would be a BlockManager subtype that would also be able to deal space states cache.
This would definitely will require us to add a functionality to the mamba/casual-conv1d kernels to make them take list of indices in addition to the mamba cache. However, We haven't had any work in this area yet, mostly since it doesn't improve performance.
We would love to see future improvements that would allow us to use vLLM's native systems.

@alexm-neuralmagic
Copy link
Collaborator

@mzusman @tlrmchlsmth Did a quick pass over the PR and I see that the changes are inside the forward() function of the model itself. The multi-step logic is "above" this function, so I don't think it should interfere with the changes here. Btw, nice optimization!

@tlrmchlsmth
Copy link
Collaborator

tlrmchlsmth commented Aug 8, 2024

@mzusman FYI I am working on modifying the kernels to take a tensor of indices for the batch coordinates. I think this branch gives us the interface we'd need to avoid all of the state copying for causal_conv1d_update:
Dao-AILab/causal-conv1d@main...neuralmagic:causal-conv1d:tms/list_causal_conv1d_update

Going to try to do the same thing to selective_state_update as well. I think it would make sense to go for this approach instead if we can make the kernel updates usable quickly enough, but landing this PR is obviously more expedient. Curious to hear your urgency on getting this improvement landed.

@mzusman
Copy link
Contributor Author

mzusman commented Aug 8, 2024

That's really great! Cache management will be easier to handle.

That's right, landing this PR is quite urgent for us at the moment and does not block future improvements. I think it would be better to split those improvements/PRs.
I suggest we land this PR first, then add the adaptations from the Mamba kernel adjustments you will add in the next PR.

@tlrmchlsmth
Copy link
Collaborator

FYI I just restarted the failed jobs

Copy link
Collaborator

@tlrmchlsmth tlrmchlsmth left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does it make sense to add unit tests for the utils that maintain the cache? Seems like they're complicated enough to want additional testing. Beyond that, LGTM if green

@mzusman
Copy link
Contributor Author

mzusman commented Aug 9, 2024

I think it makes sense to add unittests that test the cache management utils,
I guess we will add them in the future in a following PR ,
since at the moment we've already got unittests that test the cache mechanism indirectly, like this one that verifies that generations are correct ,and that batching is also correct . and this one that verifies state cleanup and more

RE CI - I'll rebase, maybe it will help, failures doesn't seems to relate to this PR.

mzusman and others added 15 commits August 9, 2024 11:40
* WIP - working on swaping indices

* WIP

* Save changes

* Orginize indices during assigment, working and passing tests!

* Add TODOs

* Remove diff

* Format

* Remove TODOs

* Remove unused code

* Cleanup

* Cleanup

* Cleanup the redundant 10 blocks

* Small changes

* Simplify code and add comments

* Renaming and simplify

* Remove return

* Clean up

* Cleanup

* Renaming

* Another clean up

* Clean up

* Clean up and simplify more

* Add n > 1 test

* Format

* cosmetics

* Add functionality to find first free

* Raise exception if could not find spot

* Typos

* Add 2 slots as precaution

---------

Co-authored-by: Mor Zusman <morz@ai21.com>
This reverts commit 381c2aa.
@mzusman mzusman force-pushed the mamba_cache_single_buffer_upstream branch from 2f5293b to 3eeeeb7 Compare August 9, 2024 08:42
@tlrmchlsmth
Copy link
Collaborator

Going to merge this one, and then try to simplify with updated kernels :)

Thanks!

@tlrmchlsmth tlrmchlsmth merged commit 07ab160 into vllm-project:main Aug 9, 2024
48 checks passed
sfc-gh-mkeralapura pushed a commit to sfc-gh-mkeralapura/vllm that referenced this pull request Aug 12, 2024
Co-authored-by: Mor Zusman <morz@ai21.com>
kylesayrs pushed a commit to neuralmagic/vllm that referenced this pull request Aug 17, 2024
Co-authored-by: Mor Zusman <morz@ai21.com>
fialhocoelho pushed a commit to opendatahub-io/vllm that referenced this pull request Aug 22, 2024
Co-authored-by: Mor Zusman <morz@ai21.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
ready ONLY add when PR is ready to merge/full CI is needed
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants