how to compare mamba with flashattention2 #27

Open · xiayuqing0622 opened this issue Dec 7, 2023 · 12 comments
@xiayuqing0622 commented Dec 7, 2023

In your paper, you mention that the Mamba scan is faster than FlashAttention-2.
Does that mean comparing

`class SelectiveScanFn(torch.autograd.Function):`

with https://github.com/Dao-AILab/flash-attention/blob/9356a1c0389660d7e231ff3163c1ac17d9e3824a/flash_attn/flash_attn_interface.py#L432 ?

The inputs of these two modules are different, so is this comparison fair? Or does the preprocessing (computing q, k, v in FlashAttention; computing A, B, C, D, delta in the Mamba scan) need to be taken into account?

@albertfgu (Contributor)

We decided to leave those linear projections out because they are orthogonal to the main "sequence mixing mechanism" (attention vs. scan) that we are interested in benchmarking. You're right that the comparison becomes slightly harder to control (e.g. which model dimension is fair to use?), but we chose a setting that seemed reasonable to us. No matter what, the timings will only be off by a small constant factor under any other "reasonable" choice of dimensions, which is dwarfed by the linear vs. quadratic complexity difference.

@tridao (Collaborator) commented Dec 7, 2023

We compared attention time (softmax(QK^T)V) vs. scan time, without the linear projections.
The dimensions are different: e.g. in a 1.3B model, a Transformer would typically have Q, K, V of hidden dimension 2048, while the input to the scan would have dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and the scan with input of dimension 4096, which is fair / favorable to attention.
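
For concreteness, a minimal sketch of the tensor shapes this implies (batch size, sequence length, head dim 64, and d_state=16 are taken from the scripts later in this thread and are illustrative, not the paper's exact benchmark script):

```python
import torch

# Assumed setting for a 1.3B-scale model (illustrative only)
batch, seqlen = 1, 2048
d_model = 2048                     # Transformer hidden dimension
head_dim, n_heads = 64, 2048 // 64
d_inner = 2 * d_model              # scan input dimension (expand = 2)
d_state = 16                       # SSM state size N

# Attention core is timed on Q, K, V of hidden dimension 2048
q = torch.randn(batch, seqlen, n_heads, head_dim, device="cuda", dtype=torch.bfloat16)
k, v = torch.randn_like(q), torch.randn_like(q)

# The selective scan is timed on an input of dimension 4096
u = torch.randn(batch, d_inner, seqlen, device="cuda", dtype=torch.bfloat16)
```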

@xiayuqing0622 (Author) commented Dec 8, 2023

> We compared attention time (softmax(QK^T)V) vs. scan time, without the linear projections. The dimensions are different: e.g. in a 1.3B model, a Transformer would typically have Q, K, V of hidden dimension 2048, while the input to the scan would have dimension 4096 (due to expand=2). So we measured attention with Q, K, V of dim 2048 and the scan with input of dimension 4096, which is fair / favorable to attention.

And what datatype did you use? When I try to run the scan in fp16, it always raises this error:

```
Traceback (most recent call last):
  File "/home/yuqing/mamba/run.py", line 29, in <module>
    y = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, True)
RuntimeError: Expected weight_type == at::ScalarType::Float || weight_type == at::ScalarType::ComplexFloat to be true, but got false. (Could this error message be improved? If so, please report an enhancement request to PyTorch.)
```

@tridao (Collaborator) commented Dec 8, 2023

Q, K, V are bf16 for attention.
u, delta, B, C, z are bf16, A and D are fp32 for scan.
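
A minimal sketch of that dtype layout for the scan inputs (the shapes are illustrative, reusing the dimensions above):

```python
import torch

batch, d_inner, length, d_state = 1, 4096, 2048, 16

# bf16 for the sequence inputs to the scan kernel
u = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)
delta = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)
z = torch.randn(batch, d_inner, length, device="cuda", dtype=torch.bfloat16)
B = torch.randn(batch, 1, d_state, length, device="cuda", dtype=torch.bfloat16)
C = torch.randn(batch, 1, d_state, length, device="cuda", dtype=torch.bfloat16)

# A and D stay in fp32, consistent with the dtype check in the error above
A = -torch.rand(d_inner, d_state, device="cuda")
D = torch.ones(d_inner, device="cuda")
```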

@xiayuqing0622 (Author)

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

It works now, thank you!

@xiayuqing0622 (Author) commented Dec 8, 2023

> Q, K, V are bf16 for attention. u, delta, B, C, z are bf16, A and D are fp32 for scan.

I wrote a simple script to compare these two components (the scan and causal FlashAttention-2) and tested it on an A100. As instructed, the input dim of the scan is 4096 and the input dim of FlashAttention is 2048 (32 heads × 64 head dim). However, the scan is much slower than FlashAttention-2 (fwd: scan 0.25 ms vs. flash2 0.14 ms; fwd+bwd: scan 1.25 ms vs. flash2 0.59 ms). Did I get any settings wrong?

```python
import torch
import time

test_bwd = False
batch, length, dim, d_state = 1, 2048, 2048, 16

# Selective scan: input dimension is dim * 2 = 4096 (expand = 2)
from mamba_ssm.ops.selective_scan_interface import SelectiveScanFn

u = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
A = torch.randn(dim * 2, d_state).to("cuda").requires_grad_(True)  # fp32
B = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
C = torch.randn(batch, 1, d_state, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
D = torch.randn(dim * 2).to("cuda").requires_grad_(True)  # fp32
z = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16).requires_grad_(True)
delta_bias = torch.randn(dim * 2).to("cuda").requires_grad_(True)
doutssm = torch.randn(batch, dim * 2, length).to("cuda").to(torch.bfloat16)
ssm = SelectiveScanFn.apply

# Warmup, then time 1000 iterations of the scan
for i in range(10):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = ssm(u, delta, A, B, C, D, z, delta_bias, True)
    if test_bwd:
        y.backward(doutssm)
torch.cuda.synchronize()
print(time.time() - start)

# FlashAttention-2 (causal): Q, K, V of hidden dimension 2048 (32 heads x 64 head dim)
from flash_attn import flash_attn_func

dim_head = 64
n_heads = dim // dim_head
q = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
k = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
v = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16).requires_grad_(True)
dout = torch.randn(batch, length, n_heads, dim_head).to("cuda").to(torch.bfloat16)

# Warmup, then time 1000 iterations of causal attention
for i in range(10):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
start = time.time()
for i in range(1000):
    y = flash_attn_func(q, k, v, causal=True)
    if test_bwd:
        y.backward(dout)
torch.cuda.synchronize()
print(time.time() - start)
```
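
As a side note, an alternative to a hand-rolled timing loop is `triton.testing.do_bench`, which handles warmup and CUDA synchronization itself; a minimal sketch for the scan call above, assuming Triton is installed and the tensors are defined as in the script:

```python
import triton

# Time (in ms) for the scan forward pass, reusing the tensors defined above
ms_scan = triton.testing.do_bench(
    lambda: ssm(u, delta, A, B, C, D, z, delta_bias, True),
    warmup=100,
)
print(f"scan fwd: {ms_scan:.3f} ms")
```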

@albertfgu (Contributor)

Please format your code with triple backticks followed by "python": ``` python

The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double-check our script to see if anything is different from yours.

@xiayuqing0622 (Author)

> Please format your code with triple backticks followed by "python": ``` python
>
> The appendix of the paper says that the dimension D is actually 1024, not 2048. We'll have to double-check our script to see if anything is different from yours.

Sorry for the formatting issue; I've re-edited the code above. I also tested inputs with D=1024: for fwd it's scan 0.13 ms vs. flash 0.08 ms, and for fwd+bwd it's scan 0.71 ms vs. flash 0.35 ms.

@apoorv2904
Hi @tridao and @albertfgu, first of all, thank you for releasing both FlashAttention (v1 and v2) and the Mamba source code, including the CUDA kernels!

I too had this issue of not being able to reproduce the benchmarks, in particular against FlashAttention-2. I tried several settings (D = 768, 1024, 2048), and for N/d_state = 16, FlashAttention was significantly faster than the scan. Only at N = 4 do I start to see the curves reported in the paper. In particular, for N = 16 the scan is about 2x slower.

Below are the times in ms that I see.

[image: table of benchmark timings]

It would be immensely useful if you could spare some time to review the Mamba benchmark below or provide a few more details for reproducing the benchmark. Thanks @xiayuqing0622 for the starting code.

Environment:
- A100 80 GB
- PyTorch 2.1 / CUDA 11.8

```python
import torch
import triton


def benchmark_mamba(batch, head, length, dim_head, d_state):
    from mamba_ssm.ops.selective_scan_interface import selective_scan_cuda
    from einops import repeat

    d_model = dim_head * head
    expand = 2
    d_inner = d_model * expand
    device = "cuda"

    # S4D real initialization
    A = repeat(
        torch.arange(1, d_state + 1, dtype=torch.float32, device=device),
        "n -> d n",
        d=d_inner,
    ).contiguous()
    A_log = torch.log(A)  # Keep A_log in fp32

    x = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    z = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta = torch.rand(
        (batch, d_inner, length), device=device, dtype=torch.bfloat16
    ).requires_grad_(True)
    delta_bias = torch.randn(d_inner).to("cuda").requires_grad_(True)
    A = -torch.exp(A_log.float())  # (d_inner, d_state), fp32
    B = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    C = (
        torch.randn(batch, 1, d_state, length)
        .to("cuda")
        .to(torch.bfloat16)
        .requires_grad_(True)
    )
    D = torch.ones(d_inner, device=device)  # Keep D in fp32
    delta_softplus = True

    # Time only the forward CUDA kernel
    ms = triton.testing.do_bench(
        lambda: selective_scan_cuda.fwd(
            x, delta, A, B, C, D, z, delta_bias, delta_softplus
        ),
        warmup=100,
    )
    return ms
```

The full code is below, but please feel free to ignore the rest:

```python
import itertools
from math import sqrt

import pandas
import torch
from tqdm import tqdm
import triton

from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn import flash_attn_qkvpacked_func, flash_attn_varlen_qkvpacked_func


def get_inputs(B, H, L, E=64, ret_padding_mask=False, dtype=torch.float32):
    q = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    k = torch.rand((B, H, L, E), device="cuda", dtype=dtype)
    v = torch.rand((B, H, L, E), device="cuda", dtype=dtype)

    input_lengths = torch.randint(1, L, (B,), device=q.device).long()
    input_lengths[-1] = L
    padding_mask = torch.zeros((B, L), dtype=q.dtype, device=q.device)
    padding_mask[
        (
            torch.arange(padding_mask.shape[0], device=padding_mask.device),
            input_lengths - 1,
        )
    ] = 1
    padding_mask = (1 - padding_mask.flip([-1]).cumsum(-1).flip([-1])).bool()
    if not ret_padding_mask:
        padding_mask = None
    return (q, k, v), padding_mask
    
def flash_attn_forward(queries, keys, values, padding_mask=None):
    qkv = torch.stack([queries, keys, values], dim=2)
    qkv = qkv.permute(0, 3, 2, 1, 4)
    B, T, _, H, D = qkv.shape
    scale = 1.0 / sqrt(D)

    if padding_mask is not None:
        # unpad_input expects True to correspond to valid indices and False to invalid
        qkv, indices, cu_q_lens, max_s = unpad_input(qkv, ~padding_mask)
        packed_res = flash_attn_varlen_qkvpacked_func(
            qkv,
            cu_q_lens,
            max_s,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = pad_input(packed_res, indices, B, T)
        res = res.transpose(1, 2)
    else:
        res = flash_attn_qkvpacked_func(
            qkv,
            dropout_p=0.0,
            softmax_scale=scale,
            causal=False,
            alibi_slopes=None,
            deterministic=False,
        )
        res = res.transpose(1, 2)  # B x T x H x D -> B x H x T x D
    return res

    
def benchmark_flash(q, k, v, padding_mask):
    dim_E = q.shape[-1]
    H = q.shape[1]
    E = dim_E * H
    ms = triton.testing.do_bench(
        lambda: flash_attn_forward(q, k, v, padding_mask=padding_mask), warmup=100
    )
    return ms


if __name__ == "__main__":
    batch_sizes = [16]
    heads = [12, 16, 32]
    time_steps = [1000, 1600, 3200, 6400]
    get_padding_masks = [True, False]
    d_states = [2, 4, 8, 16]
    dtypes = [torch.bfloat16]
    E = 64

    results = []

    for B, H, L, pm, dtype in tqdm(
        itertools.product(batch_sizes, heads, time_steps, get_padding_masks, dtypes)
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )
        ms = benchmark_flash(q, k, v, padding_mask)
        results.append(
            {
                "name": "flash",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )

    for B, H, L, pm, d_state, dtype in tqdm(
        itertools.product(
            batch_sizes, heads, time_steps, get_padding_masks, d_states, dtypes
        )
    ):
        (q, k, v), padding_mask = get_inputs(
            B, H, L, E=64, ret_padding_mask=pm, dtype=dtype
        )

        ms = benchmark_mamba(B, H, L, E, d_state)
        results.append(
            {
                "name": f"mamba-{d_state}",
                "batch_size": B,
                "nheads": H,
                "seq_len": L,
                "dim": H * E,
                "padding": pm,
                "dtype": dtype,
                "ms": ms,
            }
        )

    df = pandas.DataFrame(results)
    piv = df.pivot(
        columns="name",
        values="ms",
        index=["dtype", "padding", "batch_size", "nheads", "seq_len"],
    )
    print(piv.sort_index().round(3))
```

@tridao (Collaborator) commented Feb 7, 2024

Try `selective_scan_fn(u, delta, A, B, C, D)` (no z, delta_bias, delta_softplus) to see if that makes a difference?
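
For reference, that would change the benchmarked call to something like the following (a sketch reusing the tensors from the script above; `selective_scan_fn` is the public wrapper in `mamba_ssm.ops.selective_scan_interface`):

```python
import triton
from mamba_ssm.ops.selective_scan_interface import selective_scan_fn

# u, delta, A, B, C, D as constructed in the benchmark above,
# but without z, delta_bias, or delta_softplus
ms = triton.testing.do_bench(
    lambda: selective_scan_fn(u, delta, A, B, C, D),
    warmup=100,
)
```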

@apoorv2904 commented Feb 7, 2024

@tridao `selective_scan_fn(u, delta, A, B, C, D)` resulted in a speedup, but it's still significantly slower for N=16.

[image: updated benchmark timings]

@llmexperiment

> @tridao `selective_scan_fn(u, delta, A, B, C, D)` resulted in a speedup, but it's still significantly slower for N=16.

Hi @apoorv2904, were you able to reproduce the results? If so, could you please share how you reproduced them?
