
add feature / fix bug: I fixed the kNRows feature in forward #161

Closed
wants to merge 10 commits

Conversation


@MzeroMiko MzeroMiko commented Feb 4, 2024

Thank you for sharing this splendid work!

I found that kNRows is always 1 in the original selective_scan, and I observed that the greater the kNRows used in the selective scan, the faster the code runs. This is consistent with the phenomenon in mamba.py, where time consumption stays flat as d_state grows. Though the two cases are not strictly equivalent, increasing the burden on each thread while reducing the number of blocks (since SMs are limited) really does help in most cases.
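As a rough illustration of that tradeoff (a hypothetical sketch only, not the actual CUDA launch logic of selective_scan; `launch_config` and the dimension numbers are invented for the example): raising kNRows means each block processes more rows, so the grid shrinks while per-thread work grows.

```python
# Hypothetical sketch: how a kNRows-style knob trades grid size for
# per-thread work. Not the real kernel's launch computation.
def launch_config(batch, dim, kNRows, threads_per_block=128):
    # Total "rows" of work; each block covers kNRows of them,
    # so a larger kNRows yields fewer blocks.
    rows = batch * dim
    assert rows % kNRows == 0, "kNRows must divide the row count"
    blocks = rows // kNRows
    return blocks, threads_per_block

print(launch_config(8, 768, 1))  # (6144, 128)
print(launch_config(8, 768, 4))  # (1536, 128)
```

With few SMs relative to the grid, fewer-but-heavier blocks can finish sooner; the benchmarks later in this thread show the effect is far from universal.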

So I have re-enabled that feature, which may have been deprecated in the original selective_scan, and fixed some bugs related to it.
I have tested with pytest tests/ops/test_selective_scan_.py (which you may delete later), and all tests pass.

Note that I have only fixed the forward procedure, so in the backward pass, nrows is still 1.

Before merging: I found that when I uncomment all the alternative parameters, the tests do not all pass. However, mamba_ssm-1.1.3.post1+cu122torch2.2cxx11abiFALSE-cp310-cp310-linux_x86_64.whl behaves the same way.
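For context on what those tests compare: the selective scan is, at heart, a linear recurrence over the sequence. A deliberately simplified single-channel, single-state Python sketch (not the repo's selective_scan_ref; the Euler-style discretization here is a stand-in assumption) looks like:

```python
import math

def selective_scan_sketch(u, delta, A, B, C):
    """Toy scalar selective scan.

    Discretization (assumed, simplified): hbar_t = exp(delta_t * A),
    Bbar_t = delta_t * B_t; recurrence: h_t = hbar_t * h_{t-1} + Bbar_t * u_t,
    output: y_t = C_t * h_t. The real kernel vectorizes this over
    batch, channels, and d_state.
    """
    h = 0.0
    ys = []
    for u_t, d_t, B_t, C_t in zip(u, delta, B, C):
        h = math.exp(d_t * A) * h + d_t * B_t * u_t
        ys.append(C_t * h)
    return ys

# With A=0, delta=B=C=1 the scan degenerates to a cumulative sum:
print(selective_scan_sketch([1, 2, 3], [1, 1, 1], 0.0, [1, 1, 1], [1, 1, 1]))
# [1.0, 3.0, 6.0]
```

Because each step depends on the previous h, the loop over seqlen is sequential per state, which is why the kernel parallelizes over the other dimensions.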


apoorv2904 commented Feb 9, 2024

@MzeroMiko could you please share your benchmark numbers and platform? I see that computation slows down quite a bit when using nrows>1. Am I missing something? These are the times I get on an A100 80GB GPU:
[image: table of benchmark timings]

My benchmark code can be found here: #27 (comment)

Thanks!

@MzeroMiko
Author

MzeroMiko commented Feb 17, 2024

Thank you very much, @apoorv2904.
You are right, and I nearly failed to reproduce the results I had observed before.
These days I have been working on it (my environment is a 4090 24G, with py310+cu121+torch2.2):

  1. I added the nrows feature in backward to better compare different nrows settings.
  2. I compared my code (selective_scan_test here, or selective_scan_core in VMamba) against mamba_ssm rather than selective_scan_ref, and found no difference (all tests pass with the test file).
  3. I realised that the issue proves nothing here, since raising d_state only influences the FLOPs in the SSM (nearly equal to the selective scan itself), while raising d_model or seqlen influences the whole Mamba model. As the SSM is fast compared to the whole model plus data loading, the speed difference is small and hard to observe (which is one possible explanation for that issue).
  4. I used my newly written simple benchmark and found results consistent with yours: raising nrows only made the code slower, until I finally realised that ***the test which showed that raising nrows raises the speed was done on 7x7 feature maps, which means seqlen is 49, extremely small!*** I then added seqlen=64 to the tests and found that with some fwdnrows+bwdnrows combinations the speed is indeed faster; see the log for details. Though I still do not know how the bwd code influences the fwd procedure.
  5. I modified your benchmark, and the results are consistent with test_selective_scan_speed; see the log for details.
    To conclude: with a short seqlen, a bigger nrows may lead to faster speed, but the reason remains unknown.
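The seqlen sweep described above can be sketched with a CPU-only toy stand-in for the scan's inner loop (`toy_scan` is hypothetical; real measurements of course require the CUDA kernels on a GPU, with proper warm-up and synchronization):

```python
import timeit

def toy_scan(u, a=0.9):
    # Minimal linear recurrence standing in for selective scan's
    # sequential inner loop over seqlen.
    h, ys = 0.0, []
    for x in u:
        h = a * h + x
        ys.append(h)
    return ys

# Sweep seqlen the way the thread suggests (49 from 7x7 maps, then larger),
# timing many iterations and reporting microseconds per call.
for seqlen in (49, 64, 1024):
    u = [1.0] * seqlen
    t = timeit.timeit(lambda: toy_scan(u), number=1000)
    print(f"seqlen={seqlen:5d}  {t * 1e6 / 1000:8.1f} us/iter")
```

On a GPU the analogous sweep would bracket each timed region with torch.cuda.synchronize(); the point of the sketch is only the harness shape: fix everything but seqlen (or nrows) and compare per-iteration times.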
