From 5a26cc82c42654cccccba6f47d6fc16e2ad4c46e Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Sun, 27 Aug 2023 18:41:05 +0800 Subject: [PATCH 01/19] fix --- jax/experimental/pallas/ops/attention.py | 2 +- tests/pallas/pallas_test.py | 16 +++++++++------- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 2cf5cdddb98e..83c1f1c02069 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -77,7 +77,7 @@ def body(start_k, carry): acc = acc + pl.dot(p.astype(v.dtype), v) return acc, m_curr, l_curr if causal: - upper_bound = lax.div(block_q * start_q, block_k) + 1 + upper_bound = pl.cdiv(block_q * (start_q + 1), block_k) else: upper_bound = pl.cdiv(seq_len, block_k) # type: ignore acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index c7c679bcdedf..bf280fb54e94 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1413,13 +1413,15 @@ class FusedAttentionTest(PallasTest): @parameterized.named_parameters(*[ (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_{use_fwd=}", batch_size, seq_len, num_heads, head_dim, causal, use_fwd) - for batch_size, seq_len, num_heads, head_dim, causal, use_fwd in [ - (1, 384, 1, 64, False, False), - (2, 384, 2, 64, False, False), - (1, 384, 1, 64, True, False), - (2, 384, 2, 64, True, False), - (1, 384, 8, 64, True, True), - (2, 384, 8, 64, True, True), + for batch_size, seq_len, num_heads, head_dim, causal, use_fwd, block_q in [ + (1, 384, 1, 64, False, False, 128), + (2, 384, 2, 64, False, False, 128), + (1, 384, 1, 64, True, False, 128), + (2, 384, 2, 64, True, False, 128), + (1, 384, 8, 64, True, True, 128), + (2, 384, 8, 64, True, True, 128), + (2, 384, 8, 64, True, True, 128), + (1, 384, 8, 64, True, True, 256), ] ]) def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, From 23392e7804925ba7c6fe7134919c3473a97e5298 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Sun, 27 Aug 2023 18:43:42 +0800 Subject: [PATCH 02/19] fix --- tests/pallas/pallas_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index bf280fb54e94..e66b0647745d 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1438,10 +1438,10 @@ def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, if use_fwd: @jax.jit def impl(q, k, v): - v, _ = jax.vjp(functools.partial(attention.mha, causal=causal), q, k, v) + v, _ = jax.vjp(functools.partial(attention.mha, causal=causal, block_q=block_q), q, k, v) return v else: - impl = functools.partial(attention.mha, causal=causal) + impl = functools.partial(attention.mha, causal=causal, block_q=block_q) o = impl(q, k, v) o_ref = attention.mha_reference(q, k, v, causal=causal) np.testing.assert_allclose(o, o_ref, atol=0.05) From 22545be68eeb49f637df3d0563f5e8018b2308be Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Sun, 27 Aug 2023 19:06:11 +0800 Subject: [PATCH 03/19] improve --- tests/pallas/pallas_test.py | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index e66b0647745d..9c28975eaeb8 100644 --- 
a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1413,19 +1413,20 @@ class FusedAttentionTest(PallasTest): @parameterized.named_parameters(*[ (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_{use_fwd=}", batch_size, seq_len, num_heads, head_dim, causal, use_fwd) - for batch_size, seq_len, num_heads, head_dim, causal, use_fwd, block_q in [ - (1, 384, 1, 64, False, False, 128), - (2, 384, 2, 64, False, False, 128), - (1, 384, 1, 64, True, False, 128), - (2, 384, 2, 64, True, False, 128), - (1, 384, 8, 64, True, True, 128), - (2, 384, 8, 64, True, True, 128), - (2, 384, 8, 64, True, True, 128), - (1, 384, 8, 64, True, True, 256), + for batch_size, seq_len, num_heads, head_dim, causal, use_fwd, kwargs in [ + (1, 384, 1, 64, False, False, {}), + (2, 384, 2, 64, False, False, {}), + (1, 384, 1, 64, True, False, {}), + (2, 384, 2, 64, True, False, {}), + (1, 384, 8, 64, True, True, {}), + (2, 384, 8, 64, True, True, {}), + (2, 384, 8, 64, True, True, {}), + # regression test: https://github.com/google/jax/pull/17314 + (1, 384, 8, 64, True, True, {'block_q'=256, 'block_k'=128}), ] ]) def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, - causal, use_fwd): + causal, use_fwd, kwargs): if plgpu.get_compute_capability(0) < 80: raise unittest.SkipTest( "Fused attention only works on GPUs with capability >= sm80") @@ -1438,10 +1439,10 @@ def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, if use_fwd: @jax.jit def impl(q, k, v): - v, _ = jax.vjp(functools.partial(attention.mha, causal=causal, block_q=block_q), q, k, v) + v, _ = jax.vjp(functools.partial(attention.mha, causal=causal, **kwargs), q, k, v) return v else: - impl = functools.partial(attention.mha, causal=causal, block_q=block_q) + impl = functools.partial(attention.mha, causal=causal, **kwargs) o = impl(q, k, v) o_ref = attention.mha_reference(q, k, v, causal=causal) np.testing.assert_allclose(o, o_ref, atol=0.05) From a4f56655bf774f28b0a1fc0ed14505ca35343ff3 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Sun, 27 Aug 2023 23:48:12 +0800 Subject: [PATCH 04/19] fix --- jax/experimental/pallas/ops/attention.py | 3 ++- tests/pallas/pallas_test.py | 7 +++---- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 83c1f1c02069..98c2c16a8ab0 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -77,7 +77,8 @@ def body(start_k, carry): acc = acc + pl.dot(p.astype(v.dtype), v) return acc, m_curr, l_curr if causal: - upper_bound = pl.cdiv(block_q * (start_q + 1), block_k) + # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) + upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) else: upper_bound = pl.cdiv(seq_len, block_k) # type: ignore acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 9c28975eaeb8..44855c15d4df 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1411,8 +1411,8 @@ def body(x_ref): class FusedAttentionTest(PallasTest): @parameterized.named_parameters(*[ - (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_{use_fwd=}", - batch_size, seq_len, num_heads, head_dim, causal, use_fwd) + (f"{batch_size=}_{seq_len=}_{num_heads=}_{head_dim=}_{causal=}_{use_fwd=}_{kwargs=}", + batch_size, seq_len, num_heads, head_dim, causal, 
use_fwd, kwargs) for batch_size, seq_len, num_heads, head_dim, causal, use_fwd, kwargs in [ (1, 384, 1, 64, False, False, {}), (2, 384, 2, 64, False, False, {}), @@ -1420,9 +1420,8 @@ class FusedAttentionTest(PallasTest): (2, 384, 2, 64, True, False, {}), (1, 384, 8, 64, True, True, {}), (2, 384, 8, 64, True, True, {}), - (2, 384, 8, 64, True, True, {}), # regression test: https://github.com/google/jax/pull/17314 - (1, 384, 8, 64, True, True, {'block_q'=256, 'block_k'=128}), + (1, 384, 8, 64, True, True, {'block_q': 128, 'block_k': 64}), ] ]) def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, From cefe8b87db4a9f86de1278118a2b45e7959896c6 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:27:59 +0800 Subject: [PATCH 05/19] done --- examples/benchmark_pallas_attention.py | 289 +++++++++++++++++++++++ jax/experimental/pallas/ops/attention.py | 69 ++++++ tests/pallas/pallas_test.py | 2 +- 3 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 examples/benchmark_pallas_attention.py diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py new file mode 100644 index 000000000000..c93c8c805fcd --- /dev/null +++ b/examples/benchmark_pallas_attention.py @@ -0,0 +1,289 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example of benchmarking a Pallas kernel (fused attention). + +Since a common use-case of Pallas may be to write performant tiled kernels with +certain amount of low-level control, users will often want to benchmark their +Pallas implementation against pure JAX implementations (e.g. lowered by XLA) and +implementations in external libraries. + +Here, we show an example benchmarking fused attention for the following: +1. Pallas implementation +2. Pure-JAX implementation +3. Triton implementation (with PyTorch tensor infra) +4. flash_attn implementation + +TODO: +1. cuDNN +2. 
xformers + +We choose the settings to be similar to those in +https://tridao.me/publications/flash2/flash2.pdf +""" + +import functools +import time + +import matplotlib.pyplot as plt +import triton +import triton.language as tl +from triton import cdiv +import torch + +from jax import random +import jax +import jax.numpy as jnp +from jax.experimental.pallas.ops import attention + + +DIM = 2048 +D_HEAD = 64 +N_HEADS = DIM // D_HEAD +BATCH, SEQ_LEN = 8, 2048 +SEQ_LENS = [128, 256, 512, 1024, 2048, 4096] +NUM_RUNS = 10 + +def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax"): + block_qk_grid = [(64, 32), (128, 32), (128, 64)] + k1, k2, k3 = random.split(random.PRNGKey(0), 3) + q = random.normal(k1, (batch, seq_len, heads, d_model), dtype=jnp.float16) + k = random.normal(k2, (batch, seq_len, heads, d_model), dtype=jnp.float16) + v = random.normal(k3, (batch, seq_len, heads, d_model), dtype=jnp.float16) + + functools.partial(attention.mha, causal=causal) + + min_ms = float("inf") + + # Perform a grid search and choose the best timing + for block_q, block_k in block_qk_grid: + if mode == "pallas": + impl = functools.partial( + attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4) + elif mode == "jax": + if seq_len >= 4096: # Handle OOM + return None + impl = attention.mha_reference + else: + raise ValueError("Invalid JAX benchmark mode") + + # Warm up + impl(q, k, v).block_until_ready() + impl(q, k, v).block_until_ready() + + t1 = time.time() + for _ in range(NUM_RUNS): + impl(q, k, v).block_until_ready() + estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS + min_ms = min(estimate_ms, min_ms) + print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") + return min_ms + +# Mode is one of {"triton", "flash_attn"} +def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="triton"): + import torch + dtype = torch.float16 + q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + if mode == "triton": + # Currently broken: `RuntimeError: CUDA error: an illegal memory access was encountered` + fn = lambda: triton_attention(q, k, v, causal, 1.0) + elif mode == "flash_attn": + from flash_attn import flash_attn_func + # Currently broken: `RuntimeError: CUDA error: an illegal memory access was encountered` + fn = lambda: flash_attn_func(q, k, v, causal=causal) + else: + raise ValueError("Invalid JAX benchmark mode") + + # Warmup + fn() + fn() + torch.cuda.synchronize() + t1 = time.time() + num_runs = 100 + for _ in range(num_runs): + fn() + torch.cuda.synchronize() + estimate_ms = 1000 * (time.time() - t1) / num_runs + return estimate_ms + +def benchmark(causal=True): + y_pallas, y_jax, y_triton, y_flash_attn = [], [], [], [] + + for s in SEQ_LENS: + y_pallas.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="pallas")) + y_jax.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="jax")) + y_triton.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="triton")) + y_flash_attn.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="flash_attn")) + + for name, 
y_vals in [ + ("pallas", y_pallas), + ("jax", y_jax), + ("triton", y_triton), + ("flash_attn", y_flash_attn) + ]: + + plt.plot(SEQ_LENS, y_vals, label=name) + for a, b in zip(SEQ_LENS, y_vals): + if b is not None: + plt.text(a, b, str(round(b, 2))) + # plt.plot(SEQ_LENS, y_jax_triton, label='jax+triton') + # plt.plot(SEQ_LENS, y_trit, label='triton') + plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') + plt.ylabel('time (ms)') + plt.xlabel('Sequence Length') + plt.legend() + plt.show() + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + L, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. 
Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k, allow_tf32=True) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + # write back l and m + acc = acc / l_i[:, None] + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, k, v, sm_scale, + L, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=causal, + num_warps=num_warps, + num_stages=4) + + ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + +triton_attention = _attention.apply + +if __name__ == '__main__': + benchmark() diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 98c2c16a8ab0..6af2cebdf1cc 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -92,6 +92,75 @@ def body(start_k, carry): acc = acc.astype(o_ref.dtype) pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) +def _mha_forward_kernel( + q_ref, k_ref, v_ref, # Input 
arrays + o_ref, # Output + *residual_refs, # Residual outputs + sm_scale: float, causal: bool, + block_q: int, block_d: int, block_k: int): + seq_len = q_ref.shape[0] + start_q = pl.program_id(0) + + # acc is the buffer where we accumulate the output on sram. + # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. + m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') + l_i = jnp.zeros(block_q, dtype=jnp.float32) + # acc is the buffer where we accumulate the output on sram. + acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) + + # Load q: it will stay in L1 throughout. Indices form a matrix because we + # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. + # q tile has shape [block_q, block_d], block_d == head_dim. + q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size + # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). + # Here we only loop over blocks of kv to process entire seq_len, the loop over + # blocks of q is carried out by the grid. + def body(start_k, carry): + acc, m_prev, l_prev = carry + + k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) + qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) + qk += pl.dot(q, k.T) # [block_q, block_k] + if sm_scale != 1.: + qk *= sm_scale # [block_q, block_k] + + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) + # Bring closer to XLA:GPU numerics. + qk = qk.astype(q_ref.dtype) + qk = qk.astype(jnp.float32) + m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) + alpha = jnp.exp(m_prev - m_curr) + p = jnp.exp(qk - m_curr[:, None]) + l_curr = jnp.sum(p, axis=1) + alpha * l_prev + + acc *= alpha[:, None] + p = p.astype(jnp.float16) + + v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) + acc = acc + pl.dot(p.astype(v.dtype), v) + return acc, m_curr, l_curr + if causal: + # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) + upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) + else: + upper_bound = pl.cdiv(seq_len, block_k) # type: ignore + acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, + (acc, m_i, l_i)) + + acc = acc / l_i[:, None] + + if residual_refs: + l_ref, m_ref = residual_refs + pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i) + pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i) + # Write output to dram. 
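The causal branch above bounds the k-loop with an integer ceiling division, lax.div(block_q * (start_q + 1) + block_k - 1, block_k), because with a causal mask the query block start_q only attends to keys below row block_q * (start_q + 1), and (as the comment in the patch notes) pl.cdiv and // are not usable on the traced start_q here. A standalone sketch of that identity, in plain Python with illustrative block sizes that are not taken from the patch:

import math

def ceil_div(a: int, b: int) -> int:
  # Same trick as lax.div(a + b - 1, b) in the kernel: ceiling division using only integer ops.
  return (a + b - 1) // b

block_q, block_k = 128, 64  # illustrative sizes only
for start_q in range(4):
  q_end = block_q * (start_q + 1)  # one past the last query row of this q-block
  assert ceil_div(q_end, block_k) == math.ceil(q_end / block_k)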
+ acc = acc.astype(o_ref.dtype) + pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) + @functools.partial(jax.custom_vjp, nondiff_argnums=[3, 4, 5, 6, 7, 8, 9, 10, 11, 12]) @functools.partial(jax.jit, static_argnames=["sm_scale", "causal", "block_q", "block_k", "backward_pass_impl", diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 44855c15d4df..0f7c51a0c159 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1421,7 +1421,7 @@ class FusedAttentionTest(PallasTest): (1, 384, 8, 64, True, True, {}), (2, 384, 8, 64, True, True, {}), # regression test: https://github.com/google/jax/pull/17314 - (1, 384, 8, 64, True, True, {'block_q': 128, 'block_k': 64}), + (1, 384, 8, 64, True, False, {'block_q': 128, 'block_k': 64}), ] ]) def test_fused_attention_fwd(self, batch_size, seq_len, num_heads, head_dim, From c721df97a528091a42142bd3310fce618ca5ca5c Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 28 Aug 2023 18:38:36 +0800 Subject: [PATCH 06/19] minor --- examples/benchmark_pallas_attention.py | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index c93c8c805fcd..f980e66b4a80 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -98,11 +98,15 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) if mode == "triton": - # Currently broken: `RuntimeError: CUDA error: an illegal memory access was encountered` + """ + Triton implementation broken in dep of jax-triton: + `RuntimeError: CUDA error: an illegal memory access was encountered` + """ + # from triton.ops import attention as triton_attention + # Use a jitted function from triton nightly 28/08/23 as defined below. 
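The Triton path below calls triton_attention(q, k, v, causal, 1.0) with sm_scale = 1.0; inside _fwd_kernel that scale is multiplied by the constant 1.44269504 (log2 e) so the inner loop can use exp2 instead of exp, and a later patch in this series applies the same rewrite to the Pallas kernel via jnp.exp2. A quick numerical check of the underlying identity, in NumPy with arbitrary example values:

import numpy as np

x = np.linspace(-4.0, 4.0, 9)       # arbitrary attention scores
sm_scale = 0.125                    # arbitrary softmax scale
log2e = 1.4426950408889634          # log2(e), the 1.44269504... constant in the kernels
np.testing.assert_allclose(np.exp(x * sm_scale),
                           np.exp2(x * sm_scale * log2e), rtol=1e-12)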
fn = lambda: triton_attention(q, k, v, causal, 1.0) elif mode == "flash_attn": from flash_attn import flash_attn_func - # Currently broken: `RuntimeError: CUDA error: an illegal memory access was encountered` fn = lambda: flash_attn_func(q, k, v, causal=causal) else: raise ValueError("Invalid JAX benchmark mode") @@ -147,6 +151,13 @@ def benchmark(causal=True): plt.legend() plt.show() +if __name__ == '__main__': + benchmark() + +""" +Appendix +""" + @triton.jit def _fwd_kernel( Q, K, V, sm_scale, @@ -285,5 +296,3 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): triton_attention = _attention.apply -if __name__ == '__main__': - benchmark() From 961957f4b6373451ac97ce391d86b7f8313731f4 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 07:51:10 +0800 Subject: [PATCH 07/19] reorder --- examples/benchmark_pallas_attention.py | 198 +++++++++++++------------ 1 file changed, 100 insertions(+), 98 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index f980e66b4a80..ab712a397359 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -55,104 +55,6 @@ SEQ_LENS = [128, 256, 512, 1024, 2048, 4096] NUM_RUNS = 10 -def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax"): - block_qk_grid = [(64, 32), (128, 32), (128, 64)] - k1, k2, k3 = random.split(random.PRNGKey(0), 3) - q = random.normal(k1, (batch, seq_len, heads, d_model), dtype=jnp.float16) - k = random.normal(k2, (batch, seq_len, heads, d_model), dtype=jnp.float16) - v = random.normal(k3, (batch, seq_len, heads, d_model), dtype=jnp.float16) - - functools.partial(attention.mha, causal=causal) - - min_ms = float("inf") - - # Perform a grid search and choose the best timing - for block_q, block_k in block_qk_grid: - if mode == "pallas": - impl = functools.partial( - attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4) - elif mode == "jax": - if seq_len >= 4096: # Handle OOM - return None - impl = attention.mha_reference - else: - raise ValueError("Invalid JAX benchmark mode") - - # Warm up - impl(q, k, v).block_until_ready() - impl(q, k, v).block_until_ready() - - t1 = time.time() - for _ in range(NUM_RUNS): - impl(q, k, v).block_until_ready() - estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS - min_ms = min(estimate_ms, min_ms) - print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") - return min_ms - -# Mode is one of {"triton", "flash_attn"} -def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="triton"): - import torch - dtype = torch.float16 - q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - if mode == "triton": - """ - Triton implementation broken in dep of jax-triton: - `RuntimeError: CUDA error: an illegal memory access was encountered` - """ - # from triton.ops import attention as triton_attention - # Use a jitted function from triton nightly 28/08/23 as defined below. 
- fn = lambda: triton_attention(q, k, v, causal, 1.0) - elif mode == "flash_attn": - from flash_attn import flash_attn_func - fn = lambda: flash_attn_func(q, k, v, causal=causal) - else: - raise ValueError("Invalid JAX benchmark mode") - - # Warmup - fn() - fn() - torch.cuda.synchronize() - t1 = time.time() - num_runs = 100 - for _ in range(num_runs): - fn() - torch.cuda.synchronize() - estimate_ms = 1000 * (time.time() - t1) / num_runs - return estimate_ms - -def benchmark(causal=True): - y_pallas, y_jax, y_triton, y_flash_attn = [], [], [], [] - - for s in SEQ_LENS: - y_pallas.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="pallas")) - y_jax.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="jax")) - y_triton.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="triton")) - y_flash_attn.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="flash_attn")) - - for name, y_vals in [ - ("pallas", y_pallas), - ("jax", y_jax), - ("triton", y_triton), - ("flash_attn", y_flash_attn) - ]: - - plt.plot(SEQ_LENS, y_vals, label=name) - for a, b in zip(SEQ_LENS, y_vals): - if b is not None: - plt.text(a, b, str(round(b, 2))) - # plt.plot(SEQ_LENS, y_jax_triton, label='jax+triton') - # plt.plot(SEQ_LENS, y_trit, label='triton') - plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') - plt.ylabel('time (ms)') - plt.xlabel('Sequence Length') - plt.legend() - plt.show() - -if __name__ == '__main__': - benchmark() """ Appendix @@ -296,3 +198,103 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): triton_attention = _attention.apply +def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax"): + block_qk_grid = [(64, 32), (128, 32), (128, 64)] + k1, k2, k3 = random.split(random.PRNGKey(0), 3) + q = random.normal(k1, (batch, seq_len, heads, d_model), dtype=jnp.float16) + k = random.normal(k2, (batch, seq_len, heads, d_model), dtype=jnp.float16) + v = random.normal(k3, (batch, seq_len, heads, d_model), dtype=jnp.float16) + + functools.partial(attention.mha, causal=causal) + + min_ms = float("inf") + + # Perform a grid search and choose the best timing + for block_q, block_k in block_qk_grid: + if mode == "pallas": + impl = functools.partial( + attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4) + elif mode == "jax": + if seq_len >= 4096: # Handle OOM + return None + impl = attention.mha_reference + else: + raise ValueError("Invalid JAX benchmark mode") + + # Warm up + impl(q, k, v).block_until_ready() + impl(q, k, v).block_until_ready() + + t1 = time.time() + for _ in range(NUM_RUNS): + impl(q, k, v).block_until_ready() + estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS + min_ms = min(estimate_ms, min_ms) + print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") + return min_ms + +# Mode is one of {"triton", "flash_attn"} +def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="triton"): + import torch + dtype = torch.float16 + q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + if mode == "triton": + """ + 
Triton implementation broken in dep of jax-triton: + `RuntimeError: CUDA error: an illegal memory access was encountered` + """ + # from triton.ops import attention as triton_attention + # Use a jitted function from triton nightly 28/08/23 as defined below. + fn = lambda: triton_attention(q, k, v, causal, 1.0) + elif mode == "flash_attn": + from flash_attn import flash_attn_func + fn = lambda: flash_attn_func(q, k, v, causal=causal) + else: + raise ValueError("Invalid JAX benchmark mode") + + # Warmup + fn() + fn() + torch.cuda.synchronize() + t1 = time.time() + num_runs = 100 + for _ in range(num_runs): + fn() + torch.cuda.synchronize() + estimate_ms = 1000 * (time.time() - t1) / num_runs + return estimate_ms + +def benchmark(causal=True): + y_pallas, y_jax, y_triton, y_flash_attn = [], [], [], [] + + for s in SEQ_LENS: + y_pallas.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="pallas")) + y_jax.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="jax")) + y_triton.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="triton")) + y_flash_attn.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="flash_attn")) + + for name, y_vals in [ + ("pallas", y_pallas), + ("jax", y_jax), + ("triton", y_triton), + ("flash_attn", y_flash_attn) + ]: + + plt.plot(SEQ_LENS, y_vals, label=name) + for a, b in zip(SEQ_LENS, y_vals): + if b is not None: + plt.text(a, b, str(round(b, 2))) + # plt.plot(SEQ_LENS, y_jax_triton, label='jax+triton') + # plt.plot(SEQ_LENS, y_trit, label='triton') + plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') + plt.ylabel('time (ms)') + plt.xlabel('Sequence Length') + plt.legend() + plt.show() + +if __name__ == '__main__': + benchmark() + + From 16f7b75c7532c0c69f707ba2b824a112d439c062 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 13:27:26 +0800 Subject: [PATCH 08/19] logscale --- examples/benchmark_pallas_attention.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index ab712a397359..885eb8fd078d 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -276,8 +276,8 @@ def benchmark(causal=True): y_flash_attn.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="flash_attn")) for name, y_vals in [ - ("pallas", y_pallas), ("jax", y_jax), + ("pallas", y_pallas), ("triton", y_triton), ("flash_attn", y_flash_attn) ]: @@ -291,6 +291,7 @@ def benchmark(causal=True): plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') plt.ylabel('time (ms)') plt.xlabel('Sequence Length') + plt.yscale("log") plt.legend() plt.show() From ae768fc36de95940aecfecaf2c243057b2b84f4c Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 14:09:37 +0800 Subject: [PATCH 09/19] delayed softmax - track until pallas ~ triton --- examples/benchmark_pallas_attention.py | 40 ++++++++--- jax/experimental/pallas/ops/attention.py | 90 +++++------------------- 2 files changed, 49 insertions(+), 81 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 885eb8fd078d..24fd3dacc017 100644 --- a/examples/benchmark_pallas_attention.py +++ 
b/examples/benchmark_pallas_attention.py @@ -73,6 +73,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, + DELAYED_ONLINE_SOFTMAX: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -131,18 +132,35 @@ def _fwd_kernel( m_i_new = tl.maximum(m_i, tl.max(qk, 1)) alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) - # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] - acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) - # -- update m_i and l_i -- - l_i = l_i * alpha + tl.sum(p, 1) - m_i = m_i_new + + if DELAYED_ONLINE_SOFTMAX: + # -- scale and update acc -- + acc_scale = l_i * 0 + alpha # workaround some compiler bug + acc *= acc_scale[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + else: + l_i = l_i * alpha + l_i_new = l_i + tl.sum(p, 1) + l_rcp = 1. / l_i_new + p *= l_rcp[:, None] + # -- scale and update acc -- + acc *= (l_i * l_rcp)[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + # -- update m_i and l_i -- + l_i = l_i_new + m_i = m_i_new + # update pointers K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + if DELAYED_ONLINE_SOFTMAX: + acc = acc / l_i[:, None] + # write back l and m - acc = acc / l_i[:, None] l_ptrs = L + off_hz * N_CTX + offs_m tl.store(l_ptrs, m_i + tl.math.log2(l_i)) # write back O @@ -185,6 +203,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, IS_CAUSAL=causal, + DELAYED_ONLINE_SOFTMAX=False, num_warps=num_warps, num_stages=4) @@ -266,6 +285,10 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau estimate_ms = 1000 * (time.time() - t1) / num_runs return estimate_ms +# TODO: implement this +def test_allclose(): + pass + def benchmark(causal=True): y_pallas, y_jax, y_triton, y_flash_attn = [], [], [], [] @@ -296,6 +319,7 @@ def benchmark(causal=True): plt.show() if __name__ == '__main__': + test_allclose() benchmark() diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 6af2cebdf1cc..a4fb9da682e6 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -23,6 +23,9 @@ from jax.experimental import pallas as pl +# Currently slower on Pallas but faster on Triton +DELAYED_ONLINE_SOFTMAX = True + def mha_forward_kernel( q_ref, k_ref, v_ref, # Input arrays o_ref, # Output @@ -64,81 +67,22 @@ def body(start_k, carry): qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) - l_prev *= jnp.exp(m_prev - m_curr) + alpha = jnp.exp(m_prev - m_curr) p = jnp.exp(qk - m_curr[:, None]) - l_curr = jnp.sum(p, axis=1) + l_prev - l_rcp = 1. 
/ l_curr - p = p * l_rcp[:, None] - acc *= (l_prev * l_rcp)[:, None] - p = p.astype(jnp.float16) + if DELAYED_ONLINE_SOFTMAX: + l_curr = jnp.sum(p, axis=1) + alpha * l_prev - v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) - acc = acc + pl.dot(p.astype(v.dtype), v) - return acc, m_curr, l_curr - if causal: - # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) - upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) - else: - upper_bound = pl.cdiv(seq_len, block_k) # type: ignore - acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, - (acc, m_i, l_i)) - - if residual_refs: - l_ref, m_ref = residual_refs - pl.store(l_ref, (pl.ds(start_q * block_q, block_q),), l_i) - pl.store(m_ref, (pl.ds(start_q * block_q, block_q),), m_i) - # Write output to dram. - acc = acc.astype(o_ref.dtype) - pl.store(o_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None)), acc) - -def _mha_forward_kernel( - q_ref, k_ref, v_ref, # Input arrays - o_ref, # Output - *residual_refs, # Residual outputs - sm_scale: float, causal: bool, - block_q: int, block_d: int, block_k: int): - seq_len = q_ref.shape[0] - start_q = pl.program_id(0) - - # acc is the buffer where we accumulate the output on sram. - # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. - m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') - l_i = jnp.zeros(block_q, dtype=jnp.float32) - # acc is the buffer where we accumulate the output on sram. - acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) - - # Load q: it will stay in L1 throughout. Indices form a matrix because we - # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. - # q tile has shape [block_q, block_d], block_d == head_dim. - q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) - # In FlashAttention algorithm 1 there are 2 loops: slow over tiles of kv (size - # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). - # Here we only loop over blocks of kv to process entire seq_len, the loop over - # blocks of q is carried out by the grid. - def body(start_k, carry): - acc, m_prev, l_prev = carry - - k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) - qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) - qk += pl.dot(q, k.T) # [block_q, block_k] - if sm_scale != 1.: - qk *= sm_scale # [block_q, block_k] - - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - span_k = start_k * block_k + jnp.arange(block_k) - qk = jnp.where(span_q[:, None] >= span_k[None, :], qk, float('-inf')) - # Bring closer to XLA:GPU numerics. - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) - alpha = jnp.exp(m_prev - m_curr) - p = jnp.exp(qk - m_curr[:, None]) - l_curr = jnp.sum(p, axis=1) + alpha * l_prev + acc *= alpha[:, None] + p = p.astype(jnp.float16) + else: + l_prev *= alpha + l_curr = jnp.sum(p, axis=1) + l_prev - acc *= alpha[:, None] - p = p.astype(jnp.float16) + l_rcp = 1. 
/ l_curr + p = p * l_rcp[:, None] + acc *= (l_prev * l_rcp)[:, None] + p = p.astype(jnp.float16) v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) acc = acc + pl.dot(p.astype(v.dtype), v) @@ -150,8 +94,8 @@ def body(start_k, carry): upper_bound = pl.cdiv(seq_len, block_k) # type: ignore acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i)) - - acc = acc / l_i[:, None] + if DELAYED_ONLINE_SOFTMAX: + acc = acc / l_i[:, None] if residual_refs: l_ref, m_ref = residual_refs From b92b0810ca6169849dca9b3abe29c4f3bc4dbff3 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 14:24:16 +0800 Subject: [PATCH 10/19] Triton compiler bug? --- examples/benchmark_pallas_attention.py | 11 +++++------ jax/experimental/pallas/ops/attention.py | 3 ++- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 24fd3dacc017..37f5e9d02748 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -53,8 +53,8 @@ N_HEADS = DIM // D_HEAD BATCH, SEQ_LEN = 8, 2048 SEQ_LENS = [128, 256, 512, 1024, 2048, 4096] -NUM_RUNS = 10 - +NUM_RUNS = 30 +DELAYED_ONLINE_SOFTMAX = True """ Appendix @@ -203,7 +203,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, IS_CAUSAL=causal, - DELAYED_ONLINE_SOFTMAX=False, + DELAYED_ONLINE_SOFTMAX=DELAYED_ONLINE_SOFTMAX, num_warps=num_warps, num_stages=4) @@ -278,11 +278,10 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau fn() torch.cuda.synchronize() t1 = time.time() - num_runs = 100 - for _ in range(num_runs): + for _ in range(NUM_RUNS): fn() torch.cuda.synchronize() - estimate_ms = 1000 * (time.time() - t1) / num_runs + estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS return estimate_ms # TODO: implement this diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index a4fb9da682e6..b52e873cbb2a 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -73,7 +73,8 @@ def body(start_k, carry): if DELAYED_ONLINE_SOFTMAX: l_curr = jnp.sum(p, axis=1) + alpha * l_prev - acc *= alpha[:, None] + # Adding 0 * l_prev is due to a weird compiler bug in Triton + acc *= (0 * l_prev + alpha)[:, None] p = p.astype(jnp.float16) else: l_prev *= alpha From c5d256836b96f69b9d12d06ace823a263c5be504 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 21:32:44 +0800 Subject: [PATCH 11/19] minor --- examples/benchmark_pallas_attention.py | 2 +- jax/experimental/pallas/ops/attention.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 37f5e9d02748..d6606d2eb46f 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -52,7 +52,7 @@ D_HEAD = 64 N_HEADS = DIM // D_HEAD BATCH, SEQ_LEN = 8, 2048 -SEQ_LENS = [128, 256, 512, 1024, 2048, 4096] +SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192] NUM_RUNS = 30 DELAYED_ONLINE_SOFTMAX = True diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index b52e873cbb2a..24f44d8ae054 100644 --- a/jax/experimental/pallas/ops/attention.py +++ 
b/jax/experimental/pallas/ops/attention.py @@ -73,7 +73,7 @@ def body(start_k, carry): if DELAYED_ONLINE_SOFTMAX: l_curr = jnp.sum(p, axis=1) + alpha * l_prev - # Adding 0 * l_prev is due to a weird compiler bug in Triton + # `0 * l_prev` is to handle weird compiler perf bug in Triton acc *= (0 * l_prev + alpha)[:, None] p = p.astype(jnp.float16) else: From 9ad147849bf2576783308d516515aa415a15b703 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:09:36 +0800 Subject: [PATCH 12/19] update --- examples/benchmark_pallas_attention.py | 77 +++++++++++++++--------- jax/experimental/pallas/ops/attention.py | 7 +-- 2 files changed, 50 insertions(+), 34 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index d6606d2eb46f..736660781e5c 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -35,6 +35,7 @@ import functools import time +import math import matplotlib.pyplot as plt import triton @@ -52,9 +53,11 @@ D_HEAD = 64 N_HEADS = DIM // D_HEAD BATCH, SEQ_LEN = 8, 2048 -SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192] +SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] NUM_RUNS = 30 -DELAYED_ONLINE_SOFTMAX = True +BETWEEN_RUN_SLEEP_TIME_MS = 200 + +DELAYED_SOFTMAX_NORMALIZE = True """ Appendix @@ -73,7 +76,7 @@ def _fwd_kernel( BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, IS_CAUSAL: tl.constexpr, - DELAYED_ONLINE_SOFTMAX: tl.constexpr, + DELAYED_SOFTMAX_NORMALIZE: tl.constexpr, ): start_m = tl.program_id(0) off_hz = tl.program_id(1) @@ -133,7 +136,7 @@ def _fwd_kernel( alpha = tl.math.exp2(m_i - m_i_new) p = tl.math.exp2(qk - m_i_new[:, None]) - if DELAYED_ONLINE_SOFTMAX: + if DELAYED_SOFTMAX_NORMALIZE: # -- scale and update acc -- acc_scale = l_i * 0 + alpha # workaround some compiler bug acc *= acc_scale[:, None] @@ -157,7 +160,7 @@ def _fwd_kernel( K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) - if DELAYED_ONLINE_SOFTMAX: + if DELAYED_SOFTMAX_NORMALIZE: acc = acc / l_i[:, None] # write back l and m @@ -203,7 +206,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): q.shape[0], q.shape[1], q.shape[2], BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, IS_CAUSAL=causal, - DELAYED_ONLINE_SOFTMAX=DELAYED_ONLINE_SOFTMAX, + DELAYED_SOFTMAX_NORMALIZE=DELAYED_SOFTMAX_NORMALIZE, num_warps=num_warps, num_stages=4) @@ -218,7 +221,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): triton_attention = _attention.apply def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax"): - block_qk_grid = [(64, 32), (128, 32), (128, 64)] + block_qk_grid = [(64, 32), (128, 32), (64, 64), (128, 64)] k1, k2, k3 = random.split(random.PRNGKey(0), 3) q = random.normal(k1, (batch, seq_len, heads, d_model), dtype=jnp.float16) k = random.normal(k2, (batch, seq_len, heads, d_model), dtype=jnp.float16) @@ -234,7 +237,7 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c impl = functools.partial( attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4) elif mode == "jax": - if seq_len >= 4096: # Handle OOM + if seq_len >= 2048: # Handle OOM return None impl = attention.mha_reference else: @@ -256,9 +259,6 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c def bench_torch(batch=BATCH, heads=N_HEADS, 
seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="triton"): import torch dtype = torch.float16 - q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) if mode == "triton": """ Triton implementation broken in dep of jax-triton: @@ -267,9 +267,15 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau # from triton.ops import attention as triton_attention # Use a jitted function from triton nightly 28/08/23 as defined below. fn = lambda: triton_attention(q, k, v, causal, 1.0) + q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) elif mode == "flash_attn": from flash_attn import flash_attn_func fn = lambda: flash_attn_func(q, k, v, causal=causal) + q = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) else: raise ValueError("Invalid JAX benchmark mode") @@ -280,7 +286,7 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau t1 = time.time() for _ in range(NUM_RUNS): fn() - torch.cuda.synchronize() + torch.cuda.synchronize() estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS return estimate_ms @@ -289,30 +295,41 @@ def test_allclose(): pass def benchmark(causal=True): - y_pallas, y_jax, y_triton, y_flash_attn = [], [], [], [] + configs = [ + {'name': name, 'timings': [], 'tokens': tokens } + for name in ["jax", "pallas", "triton", "flash_attn"] + for tokens in [32768] + ] + bench_fns = { + 'jax': functools.partial(benchmark_jax, mode="jax"), + 'pallas': functools.partial(benchmark_jax, mode="pallas"), + 'flash_attn': functools.partial(bench_torch, mode="flash_attn"), + 'triton': functools.partial(bench_torch, mode="triton"), + } + fig, ax = plt.subplots() + ax.ticklabel_format(useOffset=False) for s in SEQ_LENS: - y_pallas.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="pallas")) - y_jax.append(benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="jax")) - y_triton.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="triton")) - y_flash_attn.append(bench_torch(batch=BATCH, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal, mode="flash_attn")) - - for name, y_vals in [ - ("jax", y_jax), - ("pallas", y_pallas), - ("triton", y_triton), - ("flash_attn", y_flash_attn) - ]: - - plt.plot(SEQ_LENS, y_vals, label=name) - for a, b in zip(SEQ_LENS, y_vals): + for config in configs: + config['timings'].append( + bench_fns[config['name']]( + batch=config['tokens'] // s, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal) + ) + time.sleep(BETWEEN_RUN_SLEEP_TIME_MS / 1000) + + for config in configs: + ax.plot(SEQ_LENS, config['timings'], label=config['name'] + '_' + str(config['tokens'] // 1000) + 'K_tokens' ) + for seq_len, b in zip(SEQ_LENS, config['timings']): if b is not None: - plt.text(a, 
b, str(round(b, 2))) - # plt.plot(SEQ_LENS, y_jax_triton, label='jax+triton') - # plt.plot(SEQ_LENS, y_trit, label='triton') + b_height = b * 1.05 if (config['name'] == "triton")\ + else b * 0.95 if (config['name'] == "flash_attn")\ + else b + tflops = (4 * seq_len ** 2 * D_HEAD * N_HEADS * config['tokens'] / seq_len) / (b * 10 ** 9) + plt.text(seq_len, b_height, f'Time={round(b, 2)}ms,TFLOPs={round(tflops, 2)}') plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') plt.ylabel('time (ms)') plt.xlabel('Sequence Length') + ax.set_xticks(SEQ_LENS, SEQ_LENS) plt.yscale("log") plt.legend() plt.show() diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 24f44d8ae054..ce758877631c 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -23,8 +23,7 @@ from jax.experimental import pallas as pl -# Currently slower on Pallas but faster on Triton -DELAYED_ONLINE_SOFTMAX = True +DELAYED_SOFTMAX_NORMALIZE = True def mha_forward_kernel( q_ref, k_ref, v_ref, # Input arrays @@ -70,7 +69,7 @@ def body(start_k, carry): alpha = jnp.exp(m_prev - m_curr) p = jnp.exp(qk - m_curr[:, None]) - if DELAYED_ONLINE_SOFTMAX: + if DELAYED_SOFTMAX_NORMALIZE: l_curr = jnp.sum(p, axis=1) + alpha * l_prev # `0 * l_prev` is to handle weird compiler perf bug in Triton @@ -95,7 +94,7 @@ def body(start_k, carry): upper_bound = pl.cdiv(seq_len, block_k) # type: ignore acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, (acc, m_i, l_i)) - if DELAYED_ONLINE_SOFTMAX: + if DELAYED_SOFTMAX_NORMALIZE: acc = acc / l_i[:, None] if residual_refs: From 979fd78b06500b85621229d0b0cb31ea88b2c880 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:20:53 +0800 Subject: [PATCH 13/19] merge master --- examples/benchmark_pallas_attention.py | 4 ++-- jax/experimental/pallas/ops/attention.py | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 736660781e5c..8503330e3a7b 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -235,11 +235,11 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c for block_q, block_k in block_qk_grid: if mode == "pallas": impl = functools.partial( - attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4) + attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4, segment_ids=None) elif mode == "jax": if seq_len >= 2048: # Handle OOM return None - impl = attention.mha_reference + impl = functools.partial(attention.mha_reference, segment_ids=None) else: raise ValueError("Invalid JAX benchmark mode") diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index ad0278b2f922..749ad4887038 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -26,7 +26,6 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) - DELAYED_SOFTMAX_NORMALIZE = True def mha_forward_kernel( From f4ce448c7d612dcd38a4643bf1cbe7ea1a682ec0 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 22:54:16 +0800 Subject: [PATCH 14/19] apply exp2 --- examples/benchmark_pallas_attention.py | 2 +- jax/_src/pallas/triton/lowering.py | 7 +++++++ jax/experimental/pallas/ops/attention.py | 20 +++++++++++--------- 3 
files changed, 19 insertions(+), 10 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 8503330e3a7b..e6ea53d9f76e 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -186,7 +186,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): if capability[0] < 8: raise RuntimeError("Flash attention currently only supported for compute capability >= 80") BLOCK_M = 128 - BLOCK_N = 64 + BLOCK_N = 32 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 0403333da574..69d37f6073a3 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -391,6 +391,13 @@ def _exp_lowering_rule(ctx: TritonLoweringRuleContext, a): triton_lowering_rules[lax.exp_p] = _exp_lowering_rule +def _exp2_lowering_rule(ctx: TritonLoweringRuleContext, a): + return tl.math.exp2(a, _builder=ctx.builder) + + +triton_lowering_rules[lax.exp2_p] = _exp2_lowering_rule + + def _log_lowering_rule(ctx: TritonLoweringRuleContext, a): return tl.log(a, _builder=ctx.builder) diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 749ad4887038..0e99fe975f3d 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -75,12 +75,13 @@ def body(start_k, carry): ) qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) qk += pl.dot(q, k.T) # [block_q, block_k] - if sm_scale != 1.: - qk *= sm_scale # [block_q, block_k] # Bring closer to XLA:GPU numerics. qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) + qk_scale = sm_scale * 1.44269504089 + if qk_scale != 1.: + qk *= qk_scale # [block_q, block_k] if causal or segment_ids_ref is not None: mask = None @@ -96,15 +97,14 @@ def body(start_k, carry): # Apply mask to qk. qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) - alpha = jnp.exp(m_prev - m_curr) - p = jnp.exp(qk - m_curr[:, None]) + alpha = jnp.exp2(m_prev - m_curr) + p = jnp.exp2(qk - m_curr[:, None]) if DELAYED_SOFTMAX_NORMALIZE: l_curr = jnp.sum(p, axis=1) + alpha * l_prev # `0 * l_prev` is to handle weird compiler perf bug in Triton acc *= (0 * l_prev + alpha)[:, None] - p = p.astype(jnp.float16) else: l_prev *= alpha l_curr = jnp.sum(p, axis=1) + l_prev @@ -112,9 +112,9 @@ def body(start_k, carry): l_rcp = 1. 
/ l_curr p = p * l_rcp[:, None] acc *= (l_prev * l_rcp)[:, None] - p = p.astype(jnp.float16) v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) + p = p.astype(jnp.float16) acc = acc + pl.dot(p.astype(v.dtype), v) return acc, m_curr, l_curr if causal: @@ -400,8 +400,10 @@ def inner_loop(start_q, carry): qk = pl.dot(q, k.T) qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale + + qk_scale = sm_scale * 1.44269504089 + if qk_scale != 1.0: + qk *= qk_scale q_segment_ids = ( None @@ -425,7 +427,7 @@ def inner_loop(start_q, carry): qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) - p = jnp.exp(qk - m[:, None]) + p = jnp.exp2(qk - m[:, None]) do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) dv = dv + pl.dot(p.astype(do.dtype).T, do) di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) From bea4e44cf1ba2a82baaaa589697fd477eb4cf1ba Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Wed, 30 Aug 2023 23:06:09 +0800 Subject: [PATCH 15/19] minor --- examples/benchmark_pallas_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index e6ea53d9f76e..8503330e3a7b 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -186,7 +186,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): if capability[0] < 8: raise RuntimeError("Flash attention currently only supported for compute capability >= 80") BLOCK_M = 128 - BLOCK_N = 32 + BLOCK_N = 64 # shape constraints Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] assert Lq == Lk and Lk == Lv From c59e7a70402731f3ab1bcfe2b20fdd700227d692 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:04:40 +0800 Subject: [PATCH 16/19] update --- examples/benchmark_pallas_attention.py | 115 ++++++++++++------- jax/experimental/pallas/ops/attention.py | 139 +++++++++-------------- tests/pallas/pallas_test.py | 2 +- 3 files changed, 134 insertions(+), 122 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 8503330e3a7b..1e1cd346c1aa 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -44,6 +44,7 @@ import torch from jax import random +import random as pyrand import jax import jax.numpy as jnp from jax.experimental.pallas.ops import attention @@ -54,10 +55,15 @@ N_HEADS = DIM // D_HEAD BATCH, SEQ_LEN = 8, 2048 SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] -NUM_RUNS = 30 + +TOTAL_RUNS = 150 +NUM_OUTER_RUNS = 5 # For randomization of benchmark order +NUM_INNER_RUNS = (TOTAL_RUNS + NUM_OUTER_RUNS - 1) // NUM_OUTER_RUNS + BETWEEN_RUN_SLEEP_TIME_MS = 200 DELAYED_SOFTMAX_NORMALIZE = True +SEPARATE_ON_GRAPH = True """ Appendix @@ -138,8 +144,7 @@ def _fwd_kernel( if DELAYED_SOFTMAX_NORMALIZE: # -- scale and update acc -- - acc_scale = l_i * 0 + alpha # workaround some compiler bug - acc *= acc_scale[:, None] + acc *= alpha[:, None] acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) # -- update m_i and l_i -- l_i = l_i * alpha + tl.sum(p, 1) @@ -195,7 +200,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) L = torch.empty((q.shape[0] * q.shape[1], 
q.shape[2]), device=q.device, dtype=torch.float32) num_warps = 4 if Lk <= 64 else 8 - _fwd_kernel[grid]( + compiled = _fwd_kernel[grid]( q, k, v, sm_scale, L, o, @@ -209,6 +214,11 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): DELAYED_SOFTMAX_NORMALIZE=DELAYED_SOFTMAX_NORMALIZE, num_warps=num_warps, num_stages=4) + + from triton.compiler import compiler as tc + # print("IR", compiled.asm['ttir']) + # print("TTGIR", compiled.asm['ttgir']) + # print("IR", compiled.asm['ptx']) ctx.save_for_backward(q, k, v, o, L) ctx.grid = grid @@ -220,14 +230,17 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): triton_attention = _attention.apply -def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax"): - block_qk_grid = [(64, 32), (128, 32), (64, 64), (128, 64)] +def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax", swap_seq_axis=False): + block_qk_grid = [(128, 64)] #[(64, 32), (128, 32), (64, 64)if mode == "pallas" else [(None, None)] k1, k2, k3 = random.split(random.PRNGKey(0), 3) - q = random.normal(k1, (batch, seq_len, heads, d_model), dtype=jnp.float16) - k = random.normal(k2, (batch, seq_len, heads, d_model), dtype=jnp.float16) - v = random.normal(k3, (batch, seq_len, heads, d_model), dtype=jnp.float16) + if swap_seq_axis: + shape = (batch, heads, seq_len, d_model) + else: + shape = (batch, seq_len, heads, d_model) + q = random.normal(k1, shape, dtype=jnp.float16) + k = random.normal(k2, shape, dtype=jnp.float16) + v = random.normal(k3, shape, dtype=jnp.float16) - functools.partial(attention.mha, causal=causal) min_ms = float("inf") @@ -235,7 +248,7 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c for block_q, block_k in block_qk_grid: if mode == "pallas": impl = functools.partial( - attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4, segment_ids=None) + attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4, segment_ids=None, debug=False, swap_seq_axis=swap_seq_axis) elif mode == "jax": if seq_len >= 2048: # Handle OOM return None @@ -248,9 +261,9 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c impl(q, k, v).block_until_ready() t1 = time.time() - for _ in range(NUM_RUNS): + for _ in range(NUM_INNER_RUNS): impl(q, k, v).block_until_ready() - estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS + estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS min_ms = min(estimate_ms, min_ms) print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") return min_ms @@ -267,65 +280,91 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau # from triton.ops import attention as triton_attention # Use a jitted function from triton nightly 28/08/23 as defined below. 
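Note on the measurement loop used in benchmark_jax above: because JAX dispatch is asynchronous, block_until_ready() is what actually fences the GPU work before the clock is read, and the warmup calls absorb compilation. A condensed sketch of that pattern (the helper name and default counts are illustrative, not part of the patch; assumes fn returns a single JAX array):

import time

def time_jax_fn(fn, *args, warmup=2, iters=30):
  # Warmup calls absorb compilation/autotuning outside the timed region.
  for _ in range(warmup):
    fn(*args).block_until_ready()
  t0 = time.time()
  for _ in range(iters):
    # Dispatch is asynchronous; block_until_ready() fences the GPU work.
    fn(*args).block_until_ready()
  return 1000 * (time.time() - t0) / iters  # mean milliseconds per call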
fn = lambda: triton_attention(q, k, v, causal, 1.0) - q = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((batch, heads, seq_len, d_model), dtype=dtype, device="cuda", requires_grad=True) + shape = (batch, heads, seq_len, d_model) elif mode == "flash_attn": from flash_attn import flash_attn_func fn = lambda: flash_attn_func(q, k, v, causal=causal) - q = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) - k = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) - v = torch.randn((batch, seq_len, heads, d_model), dtype=dtype, device="cuda", requires_grad=True) + shape = (batch, seq_len, heads, d_model) else: raise ValueError("Invalid JAX benchmark mode") + q = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) # Warmup fn() fn() torch.cuda.synchronize() t1 = time.time() - for _ in range(NUM_RUNS): + for _ in range(NUM_INNER_RUNS): fn() torch.cuda.synchronize() - estimate_ms = 1000 * (time.time() - t1) / NUM_RUNS + estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS return estimate_ms # TODO: implement this def test_allclose(): pass +def tflops_from_ms(timing, seq_len, tokens): + return (4 * seq_len ** 2 * D_HEAD * N_HEADS * tokens / seq_len) / (timing * 10 ** 9) + +def is_zero(a): + return math.isclose(a, 0.0, abs_tol=0.00001) + def benchmark(causal=True): configs = [ - {'name': name, 'timings': [], 'tokens': tokens } - for name in ["jax", "pallas", "triton", "flash_attn"] + {'name': name, 'timings': [0.0 for _ in range(len(SEQ_LENS))], 'tokens': tokens } + for name in ["pallas", "triton", "flash_attn", "jax"] #, "triton", "flash_attn"] #, "triton", "flash_attn"]#["jax", "pallas", "triton", "flash_attn"] for tokens in [32768] ] + bench_fns = { 'jax': functools.partial(benchmark_jax, mode="jax"), 'pallas': functools.partial(benchmark_jax, mode="pallas"), + # 'pallas_swap': functools.partial(benchmark_jax, mode="pallas", swap_seq_axis=True), 'flash_attn': functools.partial(bench_torch, mode="flash_attn"), 'triton': functools.partial(bench_torch, mode="triton"), } fig, ax = plt.subplots() ax.ticklabel_format(useOffset=False) - for s in SEQ_LENS: - for config in configs: - config['timings'].append( - bench_fns[config['name']]( - batch=config['tokens'] // s, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal) - ) - time.sleep(BETWEEN_RUN_SLEEP_TIME_MS / 1000) - + for _ in range(NUM_OUTER_RUNS): + shuffled = configs[:] + pyrand.shuffle(shuffled) + print("ORDERING", [cfg['name'] for cfg in shuffled]) + for config in shuffled: + for s_idx, s in enumerate(SEQ_LENS): + # Randomize order of configs as the order been shown to matter (esp for small runs) + res = bench_fns[config['name']]( + batch=config['tokens'] // s, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal) + if res is not None and config['timings'][s_idx] is not None: + config['timings'][s_idx] += res + + time.sleep(BETWEEN_RUN_SLEEP_TIME_MS / 1000) + + # preprocess for config in configs: + config['timings'] = [None if is_zero(t) else t / NUM_OUTER_RUNS for t in config['timings']] + + len_configs = float(len(configs)) + min_timings = [min([config['timings'][pos] for config in configs if config['timings'][pos] is not None]) 
for pos in range(len(SEQ_LENS))] + + configs.sort(key=lambda c: ([t for t in c['timings'] if t is not None] or [float.max])[-1], reverse=True) + + for config_idx, config in enumerate(configs): + config_pos = (float(len_configs - config_idx) - len_configs / 2) / len_configs ax.plot(SEQ_LENS, config['timings'], label=config['name'] + '_' + str(config['tokens'] // 1000) + 'K_tokens' ) - for seq_len, b in zip(SEQ_LENS, config['timings']): - if b is not None: - b_height = b * 1.05 if (config['name'] == "triton")\ - else b * 0.95 if (config['name'] == "flash_attn")\ - else b - tflops = (4 * seq_len ** 2 * D_HEAD * N_HEADS * config['tokens'] / seq_len) / (b * 10 ** 9) - plt.text(seq_len, b_height, f'Time={round(b, 2)}ms,TFLOPs={round(tflops, 2)}') + for seq_len, timing, min_timing in zip(SEQ_LENS, config['timings'], min_timings): + if timing is not None: + if SEPARATE_ON_GRAPH: + timing_height = timing * (1. + 0.1 * config_pos) + else: + timing_height = timing + tflops = tflops_from_ms(timing, seq_len, config['tokens']) + max_tflops = tflops_from_ms(min_timing, seq_len, config['tokens']) + percentage_of_max = tflops / max_tflops + plt.text(seq_len, timing_height, f"{config['name']}: TFLOPs={round(tflops, 1)} ({round(percentage_of_max * 100, 2)}% max) - {round(timing, 1)}ms") plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') plt.ylabel('time (ms)') plt.xlabel('Sequence Length') diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 0e99fe975f3d..eaa60dd1a72c 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -55,6 +55,10 @@ def mha_forward_kernel( # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + q_scale = sm_scale * 1.44269504089 + if q_scale != 1.: + q *= q_scale + q_segment_ids = ( None if segment_ids_ref is None @@ -68,20 +72,16 @@ def body(start_k, carry): acc, m_prev, l_prev = carry k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) + v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) kv_segment_ids = ( None if segment_ids_ref is None else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) ) - qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) - qk += pl.dot(q, k.T) # [block_q, block_k] - + qk = pl.dot(q, k.T) # [block_q, block_k] # Bring closer to XLA:GPU numerics. 
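The constant 1.44269504089 that these kernels fold into sm_scale is log2(e): with it, the softmax can use exp2 (which typically lowers to the cheaper hardware base-2 exponential) without changing the result. A quick standalone check of the identity, not part of the patch:

import numpy as np

LOG2_E = 1.44269504089  # log2(e), the constant hard-coded in the kernels

x = np.random.randn(128).astype(np.float32)
sm_scale = 1.0 / np.sqrt(64.0)  # e.g. 1/sqrt(head_dim)
# exp(sm_scale * x) == 2 ** (sm_scale * log2(e) * x), so softmax is unchanged.
np.testing.assert_allclose(np.exp(sm_scale * x),
                           np.exp2(sm_scale * LOG2_E * x), rtol=1e-5)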
qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) - qk_scale = sm_scale * 1.44269504089 - if qk_scale != 1.: - qk *= qk_scale # [block_q, block_k] if causal or segment_ids_ref is not None: mask = None @@ -103,7 +103,6 @@ def body(start_k, carry): if DELAYED_SOFTMAX_NORMALIZE: l_curr = jnp.sum(p, axis=1) + alpha * l_prev - # `0 * l_prev` is to handle weird compiler perf bug in Triton acc *= (0 * l_prev + alpha)[:, None] else: l_prev *= alpha @@ -113,7 +112,6 @@ def body(start_k, carry): p = p * l_rcp[:, None] acc *= (l_prev * l_rcp)[:, None] - v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) p = p.astype(jnp.float16) acc = acc + pl.dot(p.astype(v.dtype), v) return acc, m_curr, l_curr @@ -151,7 +149,7 @@ def segment_mask( @functools.partial( - jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] ) @functools.partial( jax.jit, @@ -164,6 +162,7 @@ def segment_mask( "num_warps", "num_stages", "grid", + "swap_seq_axis", "interpret", "debug", ], @@ -178,6 +177,7 @@ def mha( block_q: int = 128, block_k: int = 128, backward_pass_impl: str = "triton", + swap_seq_axis: bool = False, num_warps: Optional[int] = None, num_stages: int = 2, grid=None, @@ -185,7 +185,13 @@ def mha( debug: bool = False, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) # Heuristics. @@ -200,18 +206,7 @@ def mha( block_q=block_q, block_k=block_k, block_d=head_dim, causal=causal) - - in_specs = [ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - ] + in_specs = [qkv_block_spec for _ in range(3)] in_specs.append( None # type: ignore[arg-type] if segment_ids is None @@ -222,9 +217,7 @@ def mha( kernel, grid=grid_, in_specs=in_specs, - out_specs=pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + out_specs=qkv_block_spec, num_warps=num_warps_, num_stages=num_stages, out_shape=out_shape, @@ -244,6 +237,7 @@ def _mha_forward( block_q: int, block_k: int, backward_pass_impl: str, + swap_seq_axis: bool, num_warps: Optional[int], num_stages: int, grid: Any, @@ -251,7 +245,13 @@ def _mha_forward( debug: bool, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) # Heuristics. 
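The swap_seq_axis branches above all follow the same pattern; a condensed sketch of that logic (the helper name is illustrative), assuming the forward grid of (q blocks, batch, heads) used by mha and the BlockSpec(index_map, block_shape) signature the patch itself relies on:

from jax.experimental import pallas as pl

def make_qkv_block_spec(seq_len, head_dim, swap_seq_axis):
  # The q-block axis (`_`) is dropped from the index map because every
  # program reads the full sequence of k/v for its (batch, head) slice.
  if swap_seq_axis:
    # Inputs laid out as (batch, heads, seq_len, head_dim).
    return pl.BlockSpec(lambda _, j, k: (j, k, 0, 0),
                        (None, None, seq_len, head_dim))
  # Default layout: (batch, seq_len, heads, head_dim).
  return pl.BlockSpec(lambda _, j, k: (j, 0, k, 0),
                      (None, seq_len, None, head_dim))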
@@ -272,17 +272,7 @@ def _mha_forward( jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m dtype=jnp.float32) ] - in_specs = [ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - ] + in_specs = [qkv_block_spec for _ in range(3)] in_specs.append( None # type: ignore[arg-type] if segment_ids is None @@ -293,9 +283,7 @@ def _mha_forward( grid=grid_, in_specs=in_specs, out_specs=[ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + qkv_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], @@ -327,9 +315,14 @@ def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, do.astype(new_dout_ref.dtype)) pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) -def _preprocess_backward(out, do, l, block_q: int, +def _preprocess_backward(out, do, l, block_q: int, swap_seq_axis: bool, debug: bool, interpret: bool): - batch_size, seq_len, num_heads, head_dim = out.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = out.shape + out_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = out.shape + out_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) out_shape = [ jax.ShapeDtypeStruct(do.shape, do.dtype), jax.ShapeDtypeStruct(l.shape, l.dtype), @@ -338,12 +331,12 @@ def _preprocess_backward(out, do, l, block_q: int, functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + out_block_spec, + out_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + out_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], num_warps=4, @@ -397,14 +390,14 @@ def outer_loop(start_k, _): def inner_loop(start_q, carry): dv, dk = carry q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - qk = pl.dot(q, k.T) + + q_scale = sm_scale * 1.44269504089 + if q_scale != 1.0: + q_scaled = q * q_scale + qk = pl.dot(q_scaled, k.T) qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) - qk_scale = sm_scale * 1.44269504089 - if qk_scale != 1.0: - qk *= qk_scale - q_segment_ids = ( None if segment_ids_ref is None @@ -457,18 +450,25 @@ def inner_loop(start_q, carry): def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, - backward_pass_impl: str, num_warps: Optional[int], + backward_pass_impl: str, swap_seq_axis: bool, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid q, k, v, segment_ids, out, l, m = res - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + 
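On the pre-scaling rewrite in the backward kernel above: multiplying q once by sm_scale * log2(e) before the q @ k.T dot replaces a per-k-block multiply of the [block_q, block_k] qk tile with a single multiply of the smaller [block_q, head_dim] q tile. A minimal sketch of the pattern (names are illustrative), written so an unscaled q path still exists when the combined scale is exactly 1:

LOG2_E = 1.44269504089

def prescale_q(q, sm_scale):
  # Fold the softmax scale and log2(e) into the [block_q, head_dim] q tile
  # once, instead of rescaling the [block_q, block_k] qk tile every k block.
  q_scale = sm_scale * LOG2_E
  return q * q_scale if q_scale != 1.0 else q

# e.g. qk = pl.dot(prescale_q(q_tile, sm_scale), k_tile.T)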
block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) - do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) + do_scaled, delta = _preprocess_backward(out, do, l, block_q, swap_seq_axis, debug, interpret) if backward_pass_impl == "xla": + # TODO(jon-chuang): Handle the `swap_seq_axis=True` case for "xla" return jax.vjp( functools.partial(mha_reference, sm_scale=sm_scale, causal=causal), q, @@ -485,28 +485,11 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, jax.ShapeDtypeStruct(v.shape, v.dtype), ] - in_specs = [ - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + in_specs = [qkv_block_spec for _ in range(5)] + [ pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + qkv_block_spec, ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] @@ -529,17 +512,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, grid=grid, out_shape=out_shapes, in_specs=in_specs, - out_specs=[ - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - ], + out_specs=[qkv_block_spec for _ in range(3)], name="mha_backward", debug=debug, interpret=interpret, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 00360b264320..129a72a7d86d 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1561,7 +1561,7 @@ def f_ref(q, k, v): dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v) np.testing.assert_allclose(dq, dq_ref, atol=0.1) - np.testing.assert_allclose(dk, dk_ref, atol=0.08) + np.testing.assert_allclose(dk, dk_ref, atol=0.1) np.testing.assert_allclose(dv, dv_ref, atol=0.05) From 06d0bc2c1a7dc89107ddb17e42ed932323f8cbc9 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:09:00 +0800 Subject: [PATCH 17/19] minor --- examples/benchmark_pallas_attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 1e1cd346c1aa..0ba95220818a 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -315,7 +315,7 @@ def is_zero(a): def benchmark(causal=True): configs = [ {'name': name, 'timings': [0.0 for _ in range(len(SEQ_LENS))], 'tokens': tokens } - for name in ["pallas", "triton", "flash_attn", "jax"] #, "triton", "flash_attn"] #, "triton", "flash_attn"]#["jax", "pallas", "triton", "flash_attn"] + for name in ["pallas", "triton", "flash_attn", "jax"] for tokens in [32768] ] From 38a36ecc2937afcc82deae1af60376883daef1f0 Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Mon, 4 Sep 2023 19:14:34 +0800 
Subject: [PATCH 18/19] stash --- jax/_src/pallas/triton/lowering.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 379517fb9853..69d37f6073a3 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1696,13 +1696,13 @@ def pallas_call_lowering( kernel_call_proto = kernel_call.to_proto(serialized_metadata) return hlo_helpers.custom_call( call_target_name="triton_kernel_call", - result_types=out_types, + out_types=out_types, operands=in_nodes, backend_config=zlib.compress(kernel_call_proto), operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in), result_layouts=triton_lib.avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), - ).results + ) mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda") From 40b70d6623df9c70a2cee6c3514cdd97d34d620c Mon Sep 17 00:00:00 2001 From: Jon Chuang <9093549+jon-chuang@users.noreply.github.com> Date: Fri, 15 Sep 2023 07:07:28 +0800 Subject: [PATCH 19/19] add new optimizations --- examples/benchmark_pallas_attention.py | 26 +++++---- jax/_src/pallas/primitives.py | 4 +- jax/_src/pallas/triton/lowering.py | 7 ++- jax/experimental/pallas/ops/attention.py | 74 ++++++++++++++---------- 4 files changed, 68 insertions(+), 43 deletions(-) diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py index 0ba95220818a..e5ad1b4d8b3e 100644 --- a/examples/benchmark_pallas_attention.py +++ b/examples/benchmark_pallas_attention.py @@ -54,7 +54,7 @@ D_HEAD = 64 N_HEADS = DIM // D_HEAD BATCH, SEQ_LEN = 8, 2048 -SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] +SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] # [4096, 8192, 16384] # TOTAL_RUNS = 150 NUM_OUTER_RUNS = 5 # For randomization of benchmark order @@ -231,7 +231,7 @@ def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): triton_attention = _attention.apply def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax", swap_seq_axis=False): - block_qk_grid = [(128, 64)] #[(64, 32), (128, 32), (64, 64)if mode == "pallas" else [(None, None)] + block_qk_grid = [(128, 64)] #[(64, 32), (128, 32), (64, 64)] if mode == "pallas" else [(None, None)] k1, k2, k3 = random.split(random.PRNGKey(0), 3) if swap_seq_axis: shape = (batch, heads, seq_len, d_model) @@ -248,7 +248,9 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c for block_q, block_k in block_qk_grid: if mode == "pallas": impl = functools.partial( - attention.mha, causal=causal, block_q=block_q, block_k=block_k, num_warps=4, segment_ids=None, debug=False, swap_seq_axis=swap_seq_axis) + attention.mha, + causal=causal, block_q=block_q, block_k=block_k, num_warps=4, + segment_ids=None, debug=False, swap_seq_axis=swap_seq_axis) elif mode == "jax": if seq_len >= 2048: # Handle OOM return None @@ -262,7 +264,8 @@ def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, c t1 = time.time() for _ in range(NUM_INNER_RUNS): - impl(q, k, v).block_until_ready() + res = impl(q, k, v) + res.block_until_ready() estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS min_ms = min(estimate_ms, min_ms) print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") @@ -298,7 +301,7 @@ def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, cau t1 = time.time() for _ in range(NUM_INNER_RUNS): 
fn() - torch.cuda.synchronize() + torch.cuda.synchronize() estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS return estimate_ms @@ -315,14 +318,14 @@ def is_zero(a): def benchmark(causal=True): configs = [ {'name': name, 'timings': [0.0 for _ in range(len(SEQ_LENS))], 'tokens': tokens } - for name in ["pallas", "triton", "flash_attn", "jax"] + for name in ["pallas", "triton"] #, "flash_attn"] #, "jax"] for tokens in [32768] ] bench_fns = { 'jax': functools.partial(benchmark_jax, mode="jax"), 'pallas': functools.partial(benchmark_jax, mode="pallas"), - # 'pallas_swap': functools.partial(benchmark_jax, mode="pallas", swap_seq_axis=True), + 'pallas_swap': functools.partial(benchmark_jax, mode="pallas", swap_seq_axis=True), 'flash_attn': functools.partial(bench_torch, mode="flash_attn"), 'triton': functools.partial(bench_torch, mode="triton"), } @@ -348,9 +351,11 @@ def benchmark(causal=True): config['timings'] = [None if is_zero(t) else t / NUM_OUTER_RUNS for t in config['timings']] len_configs = float(len(configs)) - min_timings = [min([config['timings'][pos] for config in configs if config['timings'][pos] is not None]) for pos in range(len(SEQ_LENS))] + min_timings = [min([config['timings'][pos] + for config in configs if config['timings'][pos] is not None]) + for pos in range(len(SEQ_LENS))] - configs.sort(key=lambda c: ([t for t in c['timings'] if t is not None] or [float.max])[-1], reverse=True) + configs.sort(key=lambda c: ([t for t in c['timings'] if t is not None] or [10. ** 9])[-1], reverse=True) for config_idx, config in enumerate(configs): config_pos = (float(len_configs - config_idx) - len_configs / 2) / len_configs @@ -364,7 +369,8 @@ def benchmark(causal=True): tflops = tflops_from_ms(timing, seq_len, config['tokens']) max_tflops = tflops_from_ms(min_timing, seq_len, config['tokens']) percentage_of_max = tflops / max_tflops - plt.text(seq_len, timing_height, f"{config['name']}: TFLOPs={round(tflops, 1)} ({round(percentage_of_max * 100, 2)}% max) - {round(timing, 1)}ms") + plt.text(seq_len, timing_height, f"{config['name']}: TFLOPs={round(tflops, 1)} \ +({round(percentage_of_max * 100, 2)}% max) - {round(timing, 1)}ms") plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') plt.ylabel('time (ms)') plt.xlabel('Sequence Length') diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 74e9d8fb36b0..4712b4502250 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -398,7 +398,7 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None: _ = swap(x_ref, idx, val, mask=mask, eviction_policy=eviction_policy) def dot(a, b, trans_a: bool = False, trans_b: bool = False, - allow_tf32: bool | None = None, precision=None): + allow_tf32: bool | None = None, precision=None, out_dtype=None): lhs_contract_dim = 0 if trans_a else 1 rhs_contract_dim = 0 if not trans_b else 1 if allow_tf32 is not None: @@ -408,4 +408,4 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, return jax.lax.dot_general( a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())), precision=precision, - preferred_element_type=None).astype(jnp.float32) + preferred_element_type=out_dtype).astype(out_dtype or jnp.float32) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index 69d37f6073a3..2edb2eb5e67f 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -1020,8 +1020,11 @@ def _dot_general_lowering( allow_tf32 = ( precision == 
lax.Precision.HIGH or precision == lax.Precision.DEFAULT ) - return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32) - + if preferred_element_type == jnp.float16: + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32, out_dtype=out_dtype) triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index eaa60dd1a72c..6230d755ed64 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -27,6 +27,9 @@ DEFAULT_MASK_VALUE = -0.7 * float(np.finfo(np.dtype("float32")).max) DELAYED_SOFTMAX_NORMALIZE = True +USE_UNMASKED_LOOP_BODY = False +ALLOW_QK_FP16_ACC = False +ALLOW_PV_FP16_ACC = True def mha_forward_kernel( q_ref, @@ -44,17 +47,21 @@ def mha_forward_kernel( seq_len = q_ref.shape[0] start_q = pl.program_id(0) + pv_acc_dtype = o_ref.dtype if ALLOW_PV_FP16_ACC else jnp.float32 + qk_acc_dtype = q_ref.dtype if ALLOW_QK_FP16_ACC else jnp.float32 + # acc is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) + acc = jnp.zeros((block_q, block_d), pv_acc_dtype) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + stability_factor = jnp.log2(seq_len) if DELAYED_SOFTMAX_NORMALIZE else 0. q_scale = sm_scale * 1.44269504089 if q_scale != 1.: q *= q_scale @@ -68,34 +75,37 @@ def mha_forward_kernel( # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). # Here we only loop over blocks of kv to process entire seq_len, the loop over # blocks of q is carried out by the grid. - def body(start_k, carry): + def body(start_k, carry, masked): acc, m_prev, l_prev = carry k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) - ) - qk = pl.dot(q, k.T) # [block_q, block_k] - # Bring closer to XLA:GPU numerics. - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - - if causal or segment_ids_ref is not None: - mask = None - if segment_ids_ref is not None: - mask = segment_mask(q_segment_ids, kv_segment_ids) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - span_k = start_k * block_k + jnp.arange(block_k) - causal_mask = span_q[:, None] >= span_k[None, :] - mask = ( - causal_mask if mask is None else jnp.logical_and(mask, causal_mask) - ) - # Apply mask to qk. - qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + if masked: + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) + ) + qk = pl.dot(q, k.T, out_dtype=qk_acc_dtype).astype(q_ref.dtype) # [block_q, block_k] + # Bring closer to XLA:GPU numerics. 
+ qk = qk.astype(jnp.float32) + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + mask = segment_mask(q_segment_ids, kv_segment_ids) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + # Apply mask to qk. + qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + else: + qk = pl.dot(q, k.T, out_dtype=qk_acc_dtype).astype(q_ref.dtype) # [block_q, block_k] + # Bring closer to XLA:GPU numerics. + qk = qk.astype(jnp.float32) m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) alpha = jnp.exp2(m_prev - m_curr) p = jnp.exp2(qk - m_curr[:, None]) @@ -103,25 +113,31 @@ def body(start_k, carry): if DELAYED_SOFTMAX_NORMALIZE: l_curr = jnp.sum(p, axis=1) + alpha * l_prev - acc *= (0 * l_prev + alpha)[:, None] + acc = (acc * alpha[:, None]).astype(acc.dtype) else: l_prev *= alpha l_curr = jnp.sum(p, axis=1) + l_prev l_rcp = 1. / l_curr p = p * l_rcp[:, None] - acc *= (l_prev * l_rcp)[:, None] + + acc = (acc * (l_prev * l_rcp)[:, None]).astype(acc.dtype) p = p.astype(jnp.float16) - acc = acc + pl.dot(p.astype(v.dtype), v) + acc += pl.dot(p.astype(v.dtype), v, out_dtype=acc.dtype) return acc, m_curr, l_curr if causal: # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) + causal_lower_bound = lax.div(block_q * start_q, block_k) if USE_UNMASKED_LOOP_BODY else upper_bound else: upper_bound = pl.cdiv(seq_len, block_k) # type: ignore - acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, + causal_lower_bound = upper_bound + must_mask = segment_ids_ref is not None + acc, m_i, l_i = lax.fori_loop(causal_lower_bound, upper_bound, functools.partial(body, masked=causal or must_mask), (acc, m_i, l_i)) + acc, m_i, l_i = lax.fori_loop(0, causal_lower_bound, functools.partial(body, masked=must_mask), + (acc, m_i, l_i)) if DELAYED_SOFTMAX_NORMALIZE: acc = acc / l_i[:, None]
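For readers following the DELAYED_SOFTMAX_NORMALIZE path: the final acc = acc / l_i[:, None] is the single normalization that replaces the per-block 1/l rescaling inside the loop. A self-contained NumPy sketch of that delayed-normalization online softmax (illustrative only; it uses base-e rather than the kernels' pre-scaled exp2):

import numpy as np

def online_softmax_attention(q, k, v, block_k=4):
  # softmax(q @ k.T) @ v computed one k/v block at a time. `acc` and `l`
  # stay unnormalized inside the loop; a single divide at the end
  # normalizes, mirroring the DELAYED_SOFTMAX_NORMALIZE branch.
  m = np.full(q.shape[0], -np.inf)          # running row-wise max
  l = np.zeros(q.shape[0])                  # running sum of exponentials
  acc = np.zeros((q.shape[0], v.shape[1]))  # unnormalized output
  for start in range(0, k.shape[0], block_k):
    kb, vb = k[start:start + block_k], v[start:start + block_k]
    qk = q @ kb.T
    m_new = np.maximum(m, qk.max(axis=1))
    alpha = np.exp(m - m_new)               # rescales previous contributions
    p = np.exp(qk - m_new[:, None])
    l = alpha * l + p.sum(axis=1)
    acc = alpha[:, None] * acc + p @ vb
    m = m_new
  return acc / l[:, None]                   # delayed normalization

rng = np.random.default_rng(0)
q, k, v = (rng.standard_normal((n, 8)) for n in (4, 16, 16))
s = q @ k.T
p_ref = np.exp(s - s.max(axis=1, keepdims=True))
ref = (p_ref / p_ref.sum(axis=1, keepdims=True)) @ v
np.testing.assert_allclose(online_softmax_attention(q, k, v), ref, atol=1e-10)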