About the speed test #8

Closed
MzeroMiko opened this issue Jan 30, 2024 · 9 comments

Comments

@MzeroMiko
Thank you for sharing your fantastic work.
We noticed from your figure that, as the dimension d_state rises, Mamba's time cost does not rise.
However, we found this in the code (selective_scan_fwd_kernel.cuh#163):

for (int state_idx = 0; state_idx < params.dstate; ++state_idx) {
    ...
    if constexpr (kIsVariableB) {
        load_weight<Ktraits>(Bvar + state_idx * params.B_dstate_stride, B_vals,
            smem_load_weight, (params.seqlen - chunk * kChunkSize) * (!kIsComplex ? 1 : 2));
    }
}

which shows a for loop over state_idx that reads from HBM into shared memory.

Then I tested the speed again and found that as d_state rises, Mamba's time cost rises linearly, which is consistent with the code above.

import time
import torch
# selective_scan_fn comes from the selective-scan CUDA package under test
# (e.g. mamba_ssm's selective_scan_interface, or VMamba's variant of it).

device = torch.device("cuda")
dtype = torch.float32
B, L, G, D, N, R = 3, 4096, 4, 192, 16, 192 // 16
xi = torch.randn((B, G * D, L), device=device, dtype=dtype)
Ai = torch.randn((G * D, N), device=device, dtype=dtype)
Di = torch.randn((G * D), device=device, dtype=dtype)
dti = torch.randn((B, G * D, L), device=device, dtype=dtype)
Bi = torch.randn((B, G, N, L), device=device, dtype=dtype)
Ci = torch.randn((B, G, N, L), device=device, dtype=dtype)
tpb = torch.randn((G * D), device=device, dtype=dtype)

# Same setup, but with d_state four times larger.
Ai2 = torch.randn((G * D, 4 * N), device=device, dtype=dtype)
Bi2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)
Ci2 = torch.randn((B, G, 4 * N, L), device=device, dtype=dtype)

tim0 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True)
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim1 = time.time()
for _ in range(1000):
    y = selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True)
torch.cuda.synchronize()
torch.cuda.empty_cache()
tim2 = time.time()
print(tim1 - tim0, tim2 - tim1, torch.cuda.max_memory_allocated())  # 0.7172577381134033 2.400775194168091 185063424
time.sleep(100000)

So what did I miss?
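For reference, a tighter way to isolate the kernel time than time.time() is CUDA events. A minimal sketch, assuming the tensors and selective_scan_fn from the snippet above are in scope (the cuda_time_ms helper is just an illustration, not part of the repo):

import torch

def cuda_time_ms(fn, iters=100, warmup=10):
    # Warm up first so one-time costs don't pollute the measurement.
    for _ in range(warmup):
        fn()
    torch.cuda.synchronize()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters  # average ms per call

t_n = cuda_time_ms(lambda: selective_scan_fn(xi, dti, Ai, Bi, Ci, Di, tpb, True))
t_4n = cuda_time_ms(lambda: selective_scan_fn(xi, dti, Ai2, Bi2, Ci2, Di, tpb, True))
print(f"d_state=N: {t_n:.3f} ms, d_state=4N: {t_4n:.3f} ms")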

@alxndrTL
Owner

Hello, thank you for taking an interest in my work!

I just re-ran the script I used to generate the graph of training speed over d_state, and I still get the same results: for a fixed d_model and input, increasing d_state does not significantly increase the time of one pass (fwd and bwd) of a Mamba block. I also tried using the same B, L, d_model as you:

B=3, L=4096, d_model=192, d_state=16, time (ms)=13.451464843750001
B=3, L=4096, d_model=192, d_state=64, time (ms)=15.76099319458008
B=3, L=4096, d_model=192, d_state=128, time (ms)=13.820765972137451

(A100 80GB)

Here is the code I use for the benchmark:

import torch
import time

from mamba_ssm import Mamba

batch, length, dim = 1, 4000, 1024
d_state = 128

torch.manual_seed(1)

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

optim = torch.optim.AdamW(model.parameters(), lr=3e-3)

start_time = time.time()
N = 500
for _ in range(N):
    x = torch.randn(batch, length, dim).to("cuda")
    y_cuda = model(x)
    loss = torch.norm(y_cuda)

    optim.zero_grad()
    loss.backward()
    optim.step()

end_time = time.time()

res = (end_time-start_time)/N
print(f"B={batch}, L={length}, d_model={dim}, d_state={model.d_state}, time (ms)={res*1000}")

The main difference between your benchmark and mine is that I benchmark a whole Mamba block, while you benchmark only a part of it, the selective scan. There are other computations taking place in the Mamba block, like nn.Linear layers and a convolution. Maybe the time of the operations that depend on d_state (the loop you pointed out in selective_scan_fwd_kernel.cuh#163) is small compared to the other operations in the block, and that's why we don't see any major difference?
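One way to check that guess would be to profile a single training step and see how much of the GPU time the selective-scan kernels actually account for. A rough sketch with torch.profiler, not the script used for the numbers above:

import torch
from torch.profiler import profile, ProfilerActivity
from mamba_ssm import Mamba

model = Mamba(d_model=1024, d_state=128, d_conv=4, expand=2).to("cuda")
x = torch.randn(1, 4000, 1024, device="cuda")

# Profile one forward + backward pass and rank ops by GPU time, to see how
# the scan kernels compare to the linear / convolution layers in the block.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    loss = torch.norm(model(x))
    loss.backward()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=15))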

@MzeroMiko
Author

Thank you for your quick reply. In my experiments (VMamba), selective_scan_fwd and bwd take 80%+ of the training time, so your observations are very valuable to me.
Thank you for sharing the code; I will let you know as soon as I figure out why.

@MzeroMiko
Author

MzeroMiko commented Jan 31, 2024

I think I have part of the answer:

# test_mamba_a(16) # 34.6ms in 4090
# test_mamba_a(64) # 34ms in 4090
# test_mamba_a(128) # 45.3ms in 4090

# test_mamba_(16) # 17.4ms in 4090
# test_mamba_(64) # 28ms in 4090
# test_mamba_(128) # 41.8ms in 4090

test_mamba_a is the code you provided, while test_mamba_ changes one line:

# x = torch.randn(batch, length, dim).to("cuda")
x = torch.randn(batch, length, dim, device=torch.device("cuda"))

Transferring data from host memory to the device is very time-consuming, so I think the time gap between different d_state values may have been hidden by that operation in your experiments.
Also, note that I ran those 6 configurations one by one from the terminal command line on a 4090 24G, rather than sequentially in one script, to avoid a possible warm start.

Looking forward to your experiments on the A100!

@alxndrTL
Owner

Hello, indeed, when creating x directly on the device, I get:

B=1, L=4000, d_model=1024, d_state=16, time (ms)=10.95
B=1, L=4000, d_model=1024, d_state=64, time (ms)=17.22
B=1, L=4000, d_model=1024, d_state=128, time (ms)=26.23

That is much more consistent with what happens in the code; thank you for pointing it out!

However, in practice it is usual to move data from host to device before feeding it to the model, isn't it?
For example, in classification/main.py:237 from your VMamba work, we find:

samples = samples.cuda(non_blocking=True)
targets = targets.cuda(non_blocking=True)

So in practice, I guess raising d_state impacts Mamba's running time much less than it does "theoretically"?
For instance, with B=16, d_model=192:

  • no data movement: 80ms, 86ms, 88ms
  • with data movement: 14ms, 39ms, 76ms
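One caveat on the data-movement side: non_blocking=True can only overlap the host-to-device copy with compute when the host tensor is in pinned memory; with a plain pageable tensor the copy still blocks. A minimal sketch (generic PyTorch, not taken from either benchmark):

import torch

batch, length, dim = 16, 4096, 192

# Pageable host tensor: the copy behaves synchronously even with non_blocking=True.
x_pageable = torch.randn(batch, length, dim)
x_gpu = x_pageable.to("cuda", non_blocking=True)

# Pinned host tensor: the copy can actually overlap with GPU work
# (this is what DataLoader's pin_memory=True provides).
x_pinned = torch.randn(batch, length, dim).pin_memory()
x_gpu = x_pinned.to("cuda", non_blocking=True)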

@MzeroMiko
Author

MzeroMiko commented Jan 31, 2024

Yes, you are right.
So I have this version (simulating the case where data preparation does not happen in the main thread):

    ...
    start_time = time.time()
    N = 500
    data = torch.randn(batch, length, dim)
    for _ in range(N):
        x = data.to('cuda', non_blocking=True)
        y_cuda = model(x)
        loss = torch.norm(y_cuda)
        ...
    ...

results in:

# test_mamba2_(16) # 18.54ms in 4090
# test_mamba2_(64) # 30ms in 4090
# test_mamba2_(128) # 44.4ms in 4090

and for a bigger batch size (bs=48, resulting in a bigger gridDim on a limited number of Streaming Multiprocessors):

    ...
    start_time = time.time()
    # N = 100
    data = torch.randn(batch, length, dim) # batch=48
    for _ in range(N):
        x = data.to('cuda', non_blocking=True)
        y_cuda = model(x)
        loss = torch.norm(y_cuda)
        ...
    ...

results in:

# test_mamba2_(16, 48, 100) # 348ms in 4090; 21G
# test_mamba2_(64, 48, 100) # 638ms in 4090;  21G
# test_mamba2_(128, 48, 100) # 1028ms in 4090; 22G

I think the results are more obvious now😂

@alxndrTL
Owner

Yes, it's becoming clearer!
Note that on the A100 80GB, even with non_blocking=True, for batch=1 I get identical times for d_state=16, 64, 128:

B=1, L=4000, d_model=1024, d_state=16, time (ms)=33.5
B=1, L=4000, d_model=1024, d_state=64, time (ms)=34.5
B=1, L=4000, d_model=1024, d_state=128, time (ms)=33.9

which is different from what you observe!

But increasing the batch size makes the times differ for the different d_state values.
(the more the batch size increases, the bigger the differences)

I also tried increasing the batch size while lowering other parameters (like length and dim), and once again the times were the same across d_state values. I guess the difference in time between d_state values only shows up when the GPU is near its full capacity (memory-wise or compute-wise or both, I'm not sure).
That would explain why:

  • increasing the batch size makes the times differ
  • on the A100, for batch=1, I get the same times, while you don't on the 4090.
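A small sweep over batch size and d_state would make that easy to check directly. A sketch below (not the script behind the numbers above), timing full forward + backward steps with explicit synchronization:

import time
import torch
from mamba_ssm import Mamba

for batch in (1, 4, 16):
    for d_state in (16, 64, 128):
        model = Mamba(d_model=1024, d_state=d_state, d_conv=4, expand=2).to("cuda")
        x = torch.randn(batch, 4000, 1024, device="cuda")
        for _ in range(5):  # warmup
            torch.norm(model(x)).backward()
        torch.cuda.synchronize()
        t0 = time.time()
        N = 50
        for _ in range(N):
            torch.norm(model(x)).backward()
        torch.cuda.synchronize()
        print(f"batch={batch}, d_state={d_state}: {(time.time() - t0) / N * 1000:.1f} ms/step")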

I will update the Performances section in the repo. Thank you!

@smallscientist1

(quoting @MzeroMiko above)
"[...] Transferring data from host memory to the device is very time-consuming, so I think the time gap between different d_state values may have been hidden by that operation in your experiments. [...]"

I profiled the script with NVIDIA Nsight Systems. From my observation, copying data from host to device only takes a small amount of time; it seems that torch.randn() on the CPU is what takes long.
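(For anyone wanting to reproduce this kind of breakdown: one option is to wrap each stage in NVTX ranges and run the script under Nsight Systems. A rough sketch, assuming the model/batch/length/dim setup from the benchmark script earlier in the thread; the exact nsys flags may differ across versions.)

# Launch with, e.g.:  nsys profile -o mamba_trace python benchmark.py
import torch

for _ in range(100):
    torch.cuda.nvtx.range_push("randn_on_cpu")
    x_cpu = torch.randn(batch, length, dim)
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("copy_h2d")
    x = x_cpu.to("cuda")
    torch.cuda.nvtx.range_pop()

    torch.cuda.nvtx.range_push("forward_backward")
    loss = torch.norm(model(x))
    loss.backward()
    torch.cuda.nvtx.range_pop()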
I modified the script to

x = torch.ones(batch, length, dim).to("cuda")
# x = torch.randn(batch, length, dim).to("cuda")

and

x = torch.ones(batch, length, dim, device="cuda")
# x = torch.randn(batch, length, dim, device="cuda")

Results on an A100:

# x = torch.ones(batch, length, dim).to("cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=4.19
B=3, L=4096, d_model=192, d_state=64, time (ms)=6.51
B=3, L=4096, d_model=192, d_state=128, time (ms)=10.46

# x = torch.ones(batch, length, dim, device="cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=3.15
B=3, L=4096, d_model=192, d_state=64, time (ms)=5.77
B=3, L=4096, d_model=192, d_state=128, time (ms)=9.16

# x = torch.randn(batch, length, dim).to("cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=13.2
B=3, L=4096, d_model=192, d_state=64, time (ms)=13.2
B=3, L=4096, d_model=192, d_state=128, time (ms)=13.2

# x = torch.randn(batch, length, dim, device="cuda")
B=3, L=4096, d_model=192, d_state=16, time (ms)=3.16
B=3, L=4096, d_model=192, d_state=64, time (ms)=5.74
B=3, L=4096, d_model=192, d_state=128, time (ms)=9.15

@MzeroMiko
Author

MzeroMiko commented Feb 18, 2024

@smallscientist1 Thank you for your observation. It seems that generating randn on the CPU is quite slow (3.16 vs 13.2 ms), and .to("cuda") also costs time (3.16 vs 4.19 ms).

That may explain why the speed for different d_state values looks the same: the GPU and CPU work in parallel. The GPU computes while the CPU generates randn, and in that case the CPU keeps working even after the GPU finishes, so the CPU generation dominates. In the other cases, the CPU finishes generating before the GPU does, which is why we can see the difference.

To confirm the hypothesis above, we need to record GPU time and CPU time separately.
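A minimal sketch of that separation, using CUDA events for the GPU side and a wall clock for the CPU side (assumes the same model/batch/length/dim setup as the scripts above):

import time
import torch

start_evt = torch.cuda.Event(enable_timing=True)
end_evt = torch.cuda.Event(enable_timing=True)
cpu_ms, gpu_ms, N = 0.0, 0.0, 100

for _ in range(N):
    # CPU-side cost of generating the batch.
    t0 = time.perf_counter()
    x_cpu = torch.randn(batch, length, dim)
    cpu_ms += (time.perf_counter() - t0) * 1000

    # GPU-side cost of the copy + forward + backward, measured with events.
    start_evt.record()
    loss = torch.norm(model(x_cpu.to("cuda")))
    loss.backward()
    end_evt.record()
    torch.cuda.synchronize()
    gpu_ms += start_evt.elapsed_time(end_evt)

print(f"CPU randn: {cpu_ms / N:.2f} ms/iter, GPU copy+model: {gpu_ms / N:.2f} ms/iter")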

By the way, have you ever compared the following 2 ways?

1.  x = torch.randn(...); benchmark(x = x.to('cuda'));
2.  benchmark(x = torch.randn(device=torch.device('cuda')));

@smallscientist1

smallscientist1 commented Feb 18, 2024

(quoting @MzeroMiko above)
By the way, have you ever compared the following 2 ways?

1.  x = torch.randn(...); benchmark(x = x.to('cuda'));
2.  benchmark(x = torch.randn(device=torch.device('cuda')));

@MzeroMiko Thanks for your suggestions. I used two scripts to make the conclusion clearer.

Script 1 (computes torch.randn only once):

import torch
import time

from mamba_ssm import Mamba

batch, length, dim = 3, 4096, 192
d_state = 16 

torch.manual_seed(1)

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

optim = torch.optim.AdamW(model.parameters(), lr=3e-3)

start_time = time.time()
N = 500

x_ = torch.randn(batch, length, dim)
for _ in range(N):
    x = x_.to("cuda")
    y_cuda = model(x)
    loss = torch.norm(y_cuda)

    optim.zero_grad()
    loss.backward()
    optim.step()

end_time = time.time()

res = (end_time-start_time)/N
print(f"B={batch}, L={length}, d_model={dim}, d_state={model.d_state}, time (ms)={res*1000}")

Script 2 (computes torch.randn in the loop):

import torch
import time

from mamba_ssm import Mamba

batch, length, dim = 3, 4096, 192
d_state = 16

torch.manual_seed(1)

model = Mamba(
    d_model=dim, # Model dimension d_model
    d_state=d_state,  # SSM state expansion factor
    d_conv=4,    # Local convolution width
    expand=2,    # Block expansion factor
).to("cuda")

optim = torch.optim.AdamW(model.parameters(), lr=3e-3)

start_time = time.time()
N = 500
for _ in range(N):
    x = torch.randn(batch, length, dim).to("cuda")
    y_cuda = model(x)
    loss = torch.norm(y_cuda)

    optim.zero_grad()
    loss.backward()
    optim.step()

end_time = time.time()

res = (end_time-start_time)/N
print(f"B={batch}, L={length}, d_model={dim}, d_state={model.d_state}, time (ms)={res*1000}")

Result:

# script 1
B=3, L=4096, d_model=192, d_state=16, time (ms)=3.7
# script 2
B=3, L=4096, d_model=192, d_state=16, time (ms)=13.1

Script 1 computes torch.randn only once and benchmarks the memcpy plus the Mamba kernels, while script 2 also benchmarks torch.randn on the CPU in every iteration, in addition to the host-to-device memcpy and the Mamba kernels. The result is 3.7 vs 13.1 ms.
This makes the conclusion clearer: torch.randn(device="cpu") spends a lot of time on the CPU. When d_state is small, torch.randn() takes most of the time, longer than the Mamba kernels on the GPU.
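A quick way to sanity-check that in isolation, without the model at all (a sketch; the absolute numbers will of course depend on the CPU):

import time
import torch

batch, length, dim = 3, 4096, 192
N = 500

# Cost of generating the random batch on the CPU.
t0 = time.perf_counter()
for _ in range(N):
    x_cpu = torch.randn(batch, length, dim)
print(f"torch.randn on CPU: {(time.perf_counter() - t0) / N * 1000:.2f} ms")

# Cost of the host-to-device copy alone (same pageable tensor every time).
x_cpu = torch.randn(batch, length, dim)
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(N):
    x = x_cpu.to("cuda")
torch.cuda.synchronize()
print(f"pageable H2D copy: {(time.perf_counter() - t0) / N * 1000:.2f} ms")

# Cost of generating directly on the GPU.
torch.cuda.synchronize()
t0 = time.perf_counter()
for _ in range(N):
    x = torch.randn(batch, length, dim, device="cuda")
torch.cuda.synchronize()
print(f"torch.randn on GPU: {(time.perf_counter() - t0) / N * 1000:.2f} ms")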
