From 44bc21633392d0ff099a306ad968370489411d33 Mon Sep 17 00:00:00 2001
From: danthe3rd
Date: Thu, 9 Feb 2023 13:12:29 +0000
Subject: [PATCH] fMHA dispatch: Fix dispatch to Flash when using BlockDiag attn mask

ghstack-source-id: f8a54005dc163203a6fa54495be02f6edcedfa3c
Pull Request resolved: https://github.com/fairinternal/xformers/pull/461

__original_commit__ = fairinternal/xformers@471967c9ffc6cfc3205c72237d3013453ace6c93
---
 xformers/ops/fmha/dispatch.py | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/xformers/ops/fmha/dispatch.py b/xformers/ops/fmha/dispatch.py
index 5f4b641481..a25c74b346 100644
--- a/xformers/ops/fmha/dispatch.py
+++ b/xformers/ops/fmha/dispatch.py
@@ -8,12 +8,15 @@
 from typing import List, Type, TypeVar
 
 from . import cutlass, flash, small_k, triton
+from .attn_bias import BlockDiagonalMask
 from .common import AttentionBwOpBase, AttentionFwOpBase, Inputs
 
 
 def _is_cutlass_fwd_faster_than_flash(inp: Inputs) -> bool:
     # Very small batch sizes - if batch size specified
     batch_size, q_len, num_heads, k = inp.query.shape
+    if isinstance(inp.attn_bias, BlockDiagonalMask):
+        batch_size *= len(inp.attn_bias.k_seqinfo.cu_seqlen_py)
     if batch_size > 0:
         threads_flash = batch_size * num_heads
         threads_cutlass = threads_flash * (q_len // 64)
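
For context, here is a minimal standalone sketch of the heuristic this patch adjusts; it is not the library's code. With a block-diagonal attention bias, the packed sequences all live in a single batch entry, so `inp.query.shape[0]` (typically 1) under-counts the parallelism available to the kernels; scaling it by the number of entries in `attn_bias.k_seqinfo.cu_seqlen_py` gives the dispatcher a more realistic effective batch size. The function name `effective_batch_size` and the parameter `seqstart_k` below are illustrative stand-ins, not part of the xformers API.

```python
from typing import Optional, Sequence, Tuple


def effective_batch_size(
    query_shape: Tuple[int, int, int, int],
    seqstart_k: Optional[Sequence[int]],
) -> int:
    """Batch size the dispatch heuristic should reason about.

    `query_shape` is (batch, seq_len, num_heads, head_dim).  `seqstart_k`
    stands in for `attn_bias.k_seqinfo.cu_seqlen_py`: a list of per-sequence
    offsets whose length is used here only as a proxy for how many sequences
    are packed into the batch entry (exact contents are an assumption).
    """
    batch_size = query_shape[0]
    if seqstart_k is not None:
        # Mirror the patch: multiply by the number of entries in the
        # cumulative-sequence-length list for the key/value side.
        batch_size *= len(seqstart_k)
    return batch_size


# Example: three sequences packed into one batch entry with illustrative
# offsets [0, 5, 12, 16]; only the list's length feeds the heuristic.
print(effective_batch_size((1, 16, 8, 64), [0, 5, 12, 16]))  # 1 * 4 = 4
```

Without this scaling, a packed (block-diagonal) input always reports a batch size of 1, which biased the `_is_cutlass_fwd_faster_than_flash` comparison toward the cutlass kernel even in cases where, per the patch subject, Flash should be selected.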