From b16c2794580aff815dfe0bb1b183a1169cabe41c Mon Sep 17 00:00:00 2001
From: Antoni Baum
Date: Wed, 22 May 2024 13:42:34 -0700
Subject: [PATCH] Expose out in python API (#2)

---
 vllm_flash_attn/flash_attn_interface.py | 44 ++++++++++++++++++++++---
 1 file changed, 40 insertions(+), 4 deletions(-)

diff --git a/vllm_flash_attn/flash_attn_interface.py b/vllm_flash_attn/flash_attn_interface.py
index 9d78beb2a..6a5f16cf7 100644
--- a/vllm_flash_attn/flash_attn_interface.py
+++ b/vllm_flash_attn/flash_attn_interface.py
@@ -44,7 +44,7 @@ def _get_block_size_n(device, head_dim, is_dropout, is_causal):
 
 
 def _flash_attn_forward(
-    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax
+    q, k, v, dropout_p, softmax_scale, causal, window_size, alibi_slopes, return_softmax, *, out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -52,7 +52,7 @@
         q,
         k,
         v,
-        None,
+        out,
         alibi_slopes,
         dropout_p,
         softmax_scale,
@@ -80,6 +80,8 @@ def _flash_attn_varlen_forward(
     alibi_slopes,
     return_softmax,
     block_table,
+    *,
+    out=None
 ):
     maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x
     q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
@@ -87,7 +89,7 @@
         q,
         k,
         v,
-        None,
+        out,
         cu_seqlens_q,
         cu_seqlens_k,
         None,
@@ -220,6 +222,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -233,6 +237,7 @@
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -284,6 +289,8 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        *,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = qkv.shape[-1] ** (-0.5)
@@ -302,6 +309,7 @@
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, cu_seqlens, rng_state)
         ctx.dropout_p = dropout_p
@@ -357,6 +365,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -370,6 +379,7 @@
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -426,6 +436,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -444,6 +455,7 @@
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=None,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
         )
@@ -505,6 +517,7 @@ def forward(
         alibi_slopes,
         deterministic,
         return_softmax,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -518,6 +531,7 @@
             window_size=window_size,
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
+            out=out,
         )
         ctx.save_for_backward(q, k, v, out_padded, softmax_lse, rng_state)
         ctx.dropout_p = dropout_p
@@ -575,6 +589,7 @@ def forward(
         deterministic,
         return_softmax,
         block_table,
+        out=None,
     ):
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
@@ -593,6 +608,7 @@ def forward(
             alibi_slopes=alibi_slopes,
             return_softmax=return_softmax and dropout_p > 0,
             block_table=block_table,
+            out=out,
         )
         ctx.save_for_backward(
             q, k, v, out_padded, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state
         )
@@ -648,6 +664,8 @@ def flash_attn_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -691,6 +709,7 @@
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -704,6 +723,8 @@ def flash_attn_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -765,6 +786,7 @@
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -779,6 +801,8 @@ def flash_attn_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads
@@ -839,6 +863,7 @@
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -853,6 +878,8 @@ def flash_attn_varlen_qkvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If Q, K, V are already stacked into 1 tensor, this function will be faster than
@@ -901,6 +928,7 @@
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -918,6 +946,8 @@ def flash_attn_varlen_kvpacked_func(
     alibi_slopes=None,
     deterministic=False,
     return_attn_probs=False,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     If K, V are already stacked into 1 tensor, this function will be faster than
@@ -989,6 +1019,7 @@
         alibi_slopes,
         deterministic,
         return_attn_probs,
+        out=out,
     )
 
 
@@ -1008,6 +1039,8 @@ def flash_attn_varlen_func(
     deterministic=False,
     return_attn_probs=False,
     block_table=None,
+    *,
+    out=None,
 ):
     """dropout_p should be set to 0.0 during evaluation
     Supports multi-query and grouped-query attention (MQA/GQA) by passing in K, V with fewer heads
@@ -1079,6 +1112,7 @@
         deterministic,
         return_attn_probs,
         block_table,
+        out=out,
     )
 
 
@@ -1099,6 +1133,8 @@ def flash_attn_with_kvcache(
     rotary_interleaved=True,
     alibi_slopes=None,
     num_splits=0,
+    *,
+    out=None,
 ):
     """
     If k and v are not None, k_cache and v_cache will be updated *inplace* with the new values from
@@ -1206,7 +1242,7 @@ def flash_attn_with_kvcache(
         cache_batch_idx,
         block_table,
         alibi_slopes,
-        None,
+        out,
         softmax_scale,
         causal,
         window_size[0],
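
Below is a minimal usage sketch of the `out=` keyword this patch exposes; it is not part of the patch itself. The import path matches the file being patched, but the tensor shapes, dtype, and `causal=True` setting are illustrative assumptions, and a CUDA device with a working vllm_flash_attn build is assumed.

    # Sketch only: preallocate the attention output and have the kernel write
    # into it via the new `out` argument instead of allocating a fresh tensor
    # on every call.
    import torch
    from vllm_flash_attn.flash_attn_interface import flash_attn_func

    batch, seqlen, nheads, headdim = 2, 128, 8, 64  # illustrative sizes
    q = torch.randn(batch, seqlen, nheads, headdim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    out = torch.empty_like(q)  # reusable buffer, e.g. held across decode steps
    flash_attn_func(q, k, v, causal=True, out=out)  # result is written into `out`

Note that the public entry points take `out` as keyword-only (the `*,` marker), so callers cannot pass the buffer positionally into these already long signatures.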