Skip to content

Commit

Permalink
[Pallas] Allow setting FlashAttention's causal mask (#6792)
Browse files Browse the repository at this point in the history
Summary:
This pull request channels the causal mask to our wrapper.

Test Plan:
PJRT_DEVICE=TPU python test/test_pallas.py -v -k test_flash_attention_wrapper_causal
  • Loading branch information
alanwaketan authored Mar 21, 2024
1 parent 8d579f9 commit de42834
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 4 deletions.
24 changes: 24 additions & 0 deletions test/test_pallas.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,30 @@ def attention(q, k, v):
self.assertTrue(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)

@unittest.skipIf(xr.device_type() != 'TPU' or tpu.version() < 3,
"This test only works on TPUv3+.")
@unittest.mock.patch.dict(os.environ, {"XLA_TPU_LAYOUT": "0"})
def test_flash_attention_wrapper_causal(self):
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.HIGHEST)
from torch_xla.experimental.custom_kernel import flash_attention

def attention(q, k, v):
attn_weight = q @ k.transpose(-2, -1)
attn_weight = nn.functional.softmax(attn_weight, dim=-1)
attn_output = attn_weight @ v
return attn_output

q = torch.randn(3, 2, 128, 4).to("xla")
k = torch.randn(3, 2, 128, 4).to("xla")
v = torch.randn(3, 2, 128, 4).to("xla")

# The causal mask is turned on by default in the wrapper.
# It masks out the top right triangle of the attention matrix, therefore it speeds up the compute but also changes the output.
o = flash_attention(q, k, v, causal=True)
expected_o = attention(q, k, v)
self.assertFalse(torch.allclose(o.cpu(), expected_o.cpu()))
jax.config.update('jax_default_matmul_precision', jax.lax.Precision.DEFAULT)


if __name__ == '__main__':
logging.getLogger().setLevel(logging.INFO)
Expand Down
10 changes: 6 additions & 4 deletions torch_xla/experimental/custom_kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,13 @@ def wrapped_kernel(kernel: Callable,


# This is a simplified wrapper on top of https://github.com/google/jax/blob/b2058d72b7e1693a41303d5411572aabf99b7981/jax/experimental/pallas/ops/tpu/flash_attention.py#L139
# where we only takes q, k, v, and segment_ids as input and set causal and block_sizes for the users.
# where we only takes q, k, v, segment_ids and causal as input and set block_sizes for the users.
def flash_attention(
q, # [batch_size, num_heads, q_seq_len, d_model]
k, # [batch_size, num_heads, kv_seq_len, d_model]
v, # [batch_size, num_heads, kv_seq_len, d_model]
segment_ids=None, # q of [batch_size, q_seq_len] and kv of [batch_size, kv_seq_len]
causal=False,
):
# Import JAX within the function such that we don't need to call the jax_import_guard()
# in the global scope which could cause problems for xmp.spawn.
Expand All @@ -140,7 +141,7 @@ def flash_attention(
import jax.numpy as jnp
import jax.experimental.pallas.ops.tpu.flash_attention as tpu_flash_attention

# TODO: Support segment_ids and causal.
# TODO: Support segment_ids.
flash_attention_kernel = make_kernel_from_pallas(
tpu_flash_attention.flash_attention, lambda q, k, v: (q.shape, q.dtype))

Expand All @@ -150,7 +151,7 @@ def flash_attention(
q,
k,
v,
static_argnames=["block_sizes"],
static_argnames=["block_sizes", "causal"],
block_sizes=tpu_flash_attention.BlockSizes(
block_q=min(512, q.shape[2]),
block_k_major=min(512, k.shape[2]),
Expand All @@ -163,4 +164,5 @@ def flash_attention(
block_q_dq=min(1024, q.shape[2]),
block_k_dq=min(256, k.shape[2]),
block_k_major_dq=min(512, k.shape[2]),
))
),
causal=causal)

0 comments on commit de42834

Please sign in to comment.