Expose out in python API (#2)
Yard1 authored May 22, 2024
1 parent eee8e47 commit b16c279
Showing 1 changed file with 40 additions and 4 deletions.
vllm_flash_attn/flash_attn_interface.py: 40 additions, 4 deletions
@@ -44,15 +44,15 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.fwd(
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,14 +80,16 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
     out, q, k, v, out_padded, softmax_lse, S_dmask, rng_state = flash_attn_cuda.varlen_fwd(
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -505,6 +517,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@ def forward(
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@ def flash_attn_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@ def flash_attn_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@ def flash_attn_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@ def flash_attn_varlen_qkvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@ def flash_attn_varlen_kvpacked_func(
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )


@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@ def flash_attn_varlen_func(
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )


@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
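
The commit threads a keyword-only out argument through every public entry point, so a caller can hand the CUDA kernels a preallocated output buffer instead of having a fresh tensor allocated on each call. Below is a minimal usage sketch with flash_attn_func; it is illustrative and not part of the commit. The import path follows the file shown above, and the tensor shapes follow the existing docstrings.

    import torch
    from vllm_flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 1024, 16, 128
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    v = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)

    # Preallocate the output buffer once (same shape and dtype as q) and reuse it
    # across calls; the forward kernel writes into it instead of allocating.
    out = torch.empty_like(q)

    result = flash_attn_func(q, k, v, causal=True, out=out)
    print(result.data_ptr() == out.data_ptr())  # expected True: result aliases the provided buffer

A caller-provided buffer is mainly useful when the caller manages its own memory, for example under CUDA graph capture, where per-call allocations are undesirable.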
