fMHA dispatch: Fix dispatch to Flash when using BlockDiag attn mask
ghstack-source-id: f8a54005dc163203a6fa54495be02f6edcedfa3c
Pull Request resolved: https://github.com/fairinternal/xformers/pull/461

__original_commit__ = fairinternal/xformers@471967c9ffc6cfc3205c72237d3013453ace6c93
danthe3rd authored and xFormers Bot committed Feb 9, 2023
1 parent 1637b24 commit 44bc216
Showing 1 changed file with 3 additions and 0 deletions.
xformers/ops/fmha/dispatch.py (+3, -0)
@@ -8,12 +8,15 @@
 from typing import List, Type, TypeVar
 
 from . import cutlass, flash, small_k, triton
+from .attn_bias import BlockDiagonalMask
 from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
 
 
 def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
     # Very small batch sizes - if batch size specified
     batch_size, q_len, num_heads, k = inp.query.shape
     if isinstance(inp.attn_bias, BlockDiagonalMask):
+        batch_size *= len(inp.attn_bias.k_seqinfo.cu_seqlen_py)
     if batch_size > 0:
         threads_flash = batch_size * num_heads
         threads_cutlass = threads_flash * (q_len // 64)
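For context, a brief usage sketch of the workload this dispatch fix targets. It assumes the public xformers API referenced by the diff (BlockDiagonalMask.from_seqlens and xformers.ops.memory_efficient_attention); the shapes and sequence lengths below are illustrative and not taken from the commit.

# Hypothetical sketch (not part of the commit): several variable-length
# sequences packed into a single batch-1 tensor via BlockDiagonalMask.
import torch
from xformers.ops import memory_efficient_attention
from xformers.ops.fmha.attn_bias import BlockDiagonalMask

num_heads, head_dim = 8, 64
seqlens = [128, 256, 64, 512]  # four independent sequences

# Packed layout: all sequences are concatenated along the sequence axis,
# so the apparent batch size is 1 and the query length is sum(seqlens).
total = sum(seqlens)
q = torch.randn(1, total, num_heads, head_dim, device="cuda", dtype=torch.half)
k = torch.randn_like(q)
v = torch.randn_like(q)

attn_bias = BlockDiagonalMask.from_seqlens(seqlens)

# Before this commit the dispatcher saw batch_size == 1 here and could judge
# the CUTLASS forward faster; with the fix the effective batch size is scaled
# by the length of the mask's cumulative sequence-length list, so the Flash
# kernel is more likely to be selected for this kind of packed input.
out = memory_efficient_attention(q, k, v, attn_bias=attn_bias)

Because the mask's cumulative sequence-length list (k_seqinfo.cu_seqlen_py in the diff) grows with the number of packed sequences, the patched heuristic credits Flash with the parallelism of all packed sequences instead of treating the input as a single tiny batch.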
