diff --git a/examples/benchmark_pallas_attention.py b/examples/benchmark_pallas_attention.py new file mode 100644 index 000000000000..e5ad1b4d8b3e --- /dev/null +++ b/examples/benchmark_pallas_attention.py @@ -0,0 +1,386 @@ +# Copyright 2023 The JAX Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""An example of benchmarking a Pallas kernel (fused attention). + +Since a common use-case of Pallas may be to write performant tiled kernels with +certain amount of low-level control, users will often want to benchmark their +Pallas implementation against pure JAX implementations (e.g. lowered by XLA) and +implementations in external libraries. + +Here, we show an example benchmarking fused attention for the following: +1. Pallas implementation +2. Pure-JAX implementation +3. Triton implementation (with PyTorch tensor infra) +4. flash_attn implementation + +TODO: +1. cuDNN +2. xformers + +We choose the settings to be similar to those in +https://tridao.me/publications/flash2/flash2.pdf +""" + +import functools +import time +import math + +import matplotlib.pyplot as plt +import triton +import triton.language as tl +from triton import cdiv +import torch + +from jax import random +import random as pyrand +import jax +import jax.numpy as jnp +from jax.experimental.pallas.ops import attention + + +DIM = 2048 +D_HEAD = 64 +N_HEADS = DIM // D_HEAD +BATCH, SEQ_LEN = 8, 2048 +SEQ_LENS = [128, 256, 512, 1024, 2048, 4096, 8192, 16384] # [4096, 8192, 16384] # + +TOTAL_RUNS = 150 +NUM_OUTER_RUNS = 5 # For randomization of benchmark order +NUM_INNER_RUNS = (TOTAL_RUNS + NUM_OUTER_RUNS - 1) // NUM_OUTER_RUNS + +BETWEEN_RUN_SLEEP_TIME_MS = 200 + +DELAYED_SOFTMAX_NORMALIZE = True +SEPARATE_ON_GRAPH = True + +""" +Appendix +""" + +@triton.jit +def _fwd_kernel( + Q, K, V, sm_scale, + L, + Out, + stride_qz, stride_qh, stride_qm, stride_qk, + stride_kz, stride_kh, stride_kn, stride_kk, + stride_vz, stride_vh, stride_vk, stride_vn, + stride_oz, stride_oh, stride_om, stride_on, + Z, H, N_CTX, + BLOCK_M: tl.constexpr, BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, + IS_CAUSAL: tl.constexpr, + DELAYED_SOFTMAX_NORMALIZE: tl.constexpr, +): + start_m = tl.program_id(0) + off_hz = tl.program_id(1) + qvk_offset = off_hz * stride_qh + Q_block_ptr = tl.make_block_ptr( + base=Q + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_qm, stride_qk), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + K_block_ptr = tl.make_block_ptr( + base=K + qvk_offset, + shape=(BLOCK_DMODEL, N_CTX), + strides=(stride_kk, stride_kn), + offsets=(0, 0), + block_shape=(BLOCK_DMODEL, BLOCK_N), + order=(0, 1) + ) + V_block_ptr = tl.make_block_ptr( + base=V + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_vk, stride_vn), + offsets=(0, 0), + block_shape=(BLOCK_N, BLOCK_DMODEL), + order=(1, 0) + ) + # initialize offsets + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + offs_n = tl.arange(0, BLOCK_N) + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], 
dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + # credits to: Adam P. Goucher (https://github.com/apgoucher): + # scale sm_scale by 1/log_2(e) and use + # 2^x instead of exp in the loop because CSE and LICM + # don't work as expected with `exp` in the loop + qk_scale = sm_scale * 1.44269504 + # load q: it will stay in SRAM throughout + q = tl.load(Q_block_ptr) + q = (q * qk_scale).to(K.dtype.element_ty) + lo = 0 + hi = (start_m + 1) * BLOCK_M if IS_CAUSAL else N_CTX + for start_n in range(lo, hi, BLOCK_N): + # -- load k, v -- + k = tl.load(K_block_ptr) + v = tl.load(V_block_ptr) + # -- compute qk --- + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + if IS_CAUSAL: + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + qk += tl.dot(q, k, allow_tf32=True) + # -- compute scaling constant --- + m_i_new = tl.maximum(m_i, tl.max(qk, 1)) + alpha = tl.math.exp2(m_i - m_i_new) + p = tl.math.exp2(qk - m_i_new[:, None]) + + if DELAYED_SOFTMAX_NORMALIZE: + # -- scale and update acc -- + acc *= alpha[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + # -- update m_i and l_i -- + l_i = l_i * alpha + tl.sum(p, 1) + m_i = m_i_new + else: + l_i = l_i * alpha + l_i_new = l_i + tl.sum(p, 1) + l_rcp = 1. / l_i_new + p *= l_rcp[:, None] + # -- scale and update acc -- + acc *= (l_i * l_rcp)[:, None] + acc += tl.dot(p.to(V.dtype.element_ty), v, allow_tf32=True) + # -- update m_i and l_i -- + l_i = l_i_new + m_i = m_i_new + + # update pointers + K_block_ptr = tl.advance(K_block_ptr, (0, BLOCK_N)) + V_block_ptr = tl.advance(V_block_ptr, (BLOCK_N, 0)) + + if DELAYED_SOFTMAX_NORMALIZE: + acc = acc / l_i[:, None] + + # write back l and m + l_ptrs = L + off_hz * N_CTX + offs_m + tl.store(l_ptrs, m_i + tl.math.log2(l_i)) + # write back O + O_block_ptr = tl.make_block_ptr( + base=Out + qvk_offset, + shape=(N_CTX, BLOCK_DMODEL), + strides=(stride_om, stride_on), + offsets=(start_m * BLOCK_M, 0), + block_shape=(BLOCK_M, BLOCK_DMODEL), + order=(1, 0) + ) + tl.store(O_block_ptr, acc.to(K.dtype.element_ty)) + +class _attention(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, causal, sm_scale, sequence_parallel=False): + # only support for Ampere now + capability = torch.cuda.get_device_capability() + if capability[0] < 8: + raise RuntimeError("Flash attention currently only supported for compute capability >= 80") + BLOCK_M = 128 + BLOCK_N = 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + assert Lk in {16, 32, 64, 128} + o = torch.empty_like(q) + grid = (cdiv(q.shape[2], BLOCK_M), q.shape[0] * q.shape[1], 1) + L = torch.empty((q.shape[0] * q.shape[1], q.shape[2]), device=q.device, dtype=torch.float32) + num_warps = 4 if Lk <= 64 else 8 + compiled = _fwd_kernel[grid]( + q, k, v, sm_scale, + L, + o, + q.stride(0), q.stride(1), q.stride(2), q.stride(3), + k.stride(0), k.stride(1), k.stride(2), k.stride(3), + v.stride(0), v.stride(1), v.stride(2), v.stride(3), + o.stride(0), o.stride(1), o.stride(2), o.stride(3), + q.shape[0], q.shape[1], q.shape[2], + BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_DMODEL=Lk, + IS_CAUSAL=causal, + DELAYED_SOFTMAX_NORMALIZE=DELAYED_SOFTMAX_NORMALIZE, + num_warps=num_warps, + num_stages=4) + + from triton.compiler import compiler as tc + # print("IR", compiled.asm['ttir']) + # print("TTGIR", compiled.asm['ttgir']) + # print("IR", compiled.asm['ptx']) + + 
ctx.save_for_backward(q, k, v, o, L) + ctx.grid = grid + ctx.sm_scale = sm_scale + ctx.BLOCK_DMODEL = Lk + ctx.causal = causal + ctx.sequence_parallel = sequence_parallel + return o + +triton_attention = _attention.apply + +def benchmark_jax(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="jax", swap_seq_axis=False): + block_qk_grid = [(128, 64)] #[(64, 32), (128, 32), (64, 64)] if mode == "pallas" else [(None, None)] + k1, k2, k3 = random.split(random.PRNGKey(0), 3) + if swap_seq_axis: + shape = (batch, heads, seq_len, d_model) + else: + shape = (batch, seq_len, heads, d_model) + q = random.normal(k1, shape, dtype=jnp.float16) + k = random.normal(k2, shape, dtype=jnp.float16) + v = random.normal(k3, shape, dtype=jnp.float16) + + + min_ms = float("inf") + + # Perform a grid search and choose the best timing + for block_q, block_k in block_qk_grid: + if mode == "pallas": + impl = functools.partial( + attention.mha, + causal=causal, block_q=block_q, block_k=block_k, num_warps=4, + segment_ids=None, debug=False, swap_seq_axis=swap_seq_axis) + elif mode == "jax": + if seq_len >= 2048: # Handle OOM + return None + impl = functools.partial(attention.mha_reference, segment_ids=None) + else: + raise ValueError("Invalid JAX benchmark mode") + + # Warm up + impl(q, k, v).block_until_ready() + impl(q, k, v).block_until_ready() + + t1 = time.time() + for _ in range(NUM_INNER_RUNS): + res = impl(q, k, v) + res.block_until_ready() + estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS + min_ms = min(estimate_ms, min_ms) + print(f"{mode} (seq_len={seq_len}, block_q={block_q}, block_k={block_k}): {estimate_ms} ms") + return min_ms + +# Mode is one of {"triton", "flash_attn"} +def bench_torch(batch=BATCH, heads=N_HEADS, seq_len=SEQ_LEN, d_model=D_HEAD, causal=True, mode="triton"): + import torch + dtype = torch.float16 + if mode == "triton": + """ + Triton implementation broken in dep of jax-triton: + `RuntimeError: CUDA error: an illegal memory access was encountered` + """ + # from triton.ops import attention as triton_attention + # Use a jitted function from triton nightly 28/08/23 as defined below. 
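# Note (an aside, not part of this benchmark): if the hand-rolled wall-clock
# timing below proves noisy, Triton's own `triton.testing.do_bench(fn)` performs
# warmup and device synchronization internally; its exact signature and return
# value differ across Triton releases, so verify it against the installed
# version before relying on it.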
+ fn = lambda: triton_attention(q, k, v, causal, 1.0) + shape = (batch, heads, seq_len, d_model) + elif mode == "flash_attn": + from flash_attn import flash_attn_func + fn = lambda: flash_attn_func(q, k, v, causal=causal) + shape = (batch, seq_len, heads, d_model) + else: + raise ValueError("Invalid JAX benchmark mode") + q = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + k = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + v = torch.randn(shape, dtype=dtype, device="cuda", requires_grad=True) + + # Warmup + fn() + fn() + torch.cuda.synchronize() + t1 = time.time() + for _ in range(NUM_INNER_RUNS): + fn() + torch.cuda.synchronize() + estimate_ms = 1000 * (time.time() - t1) / NUM_INNER_RUNS + return estimate_ms + +# TODO: implement this +def test_allclose(): + pass + +def tflops_from_ms(timing, seq_len, tokens): + return (4 * seq_len ** 2 * D_HEAD * N_HEADS * tokens / seq_len) / (timing * 10 ** 9) + +def is_zero(a): + return math.isclose(a, 0.0, abs_tol=0.00001) + +def benchmark(causal=True): + configs = [ + {'name': name, 'timings': [0.0 for _ in range(len(SEQ_LENS))], 'tokens': tokens } + for name in ["pallas", "triton"] #, "flash_attn"] #, "jax"] + for tokens in [32768] + ] + + bench_fns = { + 'jax': functools.partial(benchmark_jax, mode="jax"), + 'pallas': functools.partial(benchmark_jax, mode="pallas"), + 'pallas_swap': functools.partial(benchmark_jax, mode="pallas", swap_seq_axis=True), + 'flash_attn': functools.partial(bench_torch, mode="flash_attn"), + 'triton': functools.partial(bench_torch, mode="triton"), + } + fig, ax = plt.subplots() + ax.ticklabel_format(useOffset=False) + + for _ in range(NUM_OUTER_RUNS): + shuffled = configs[:] + pyrand.shuffle(shuffled) + print("ORDERING", [cfg['name'] for cfg in shuffled]) + for config in shuffled: + for s_idx, s in enumerate(SEQ_LENS): + # Randomize order of configs as the order been shown to matter (esp for small runs) + res = bench_fns[config['name']]( + batch=config['tokens'] // s, heads=N_HEADS, seq_len=s, d_model=D_HEAD, causal=causal) + if res is not None and config['timings'][s_idx] is not None: + config['timings'][s_idx] += res + + time.sleep(BETWEEN_RUN_SLEEP_TIME_MS / 1000) + + # preprocess + for config in configs: + config['timings'] = [None if is_zero(t) else t / NUM_OUTER_RUNS for t in config['timings']] + + len_configs = float(len(configs)) + min_timings = [min([config['timings'][pos] + for config in configs if config['timings'][pos] is not None]) + for pos in range(len(SEQ_LENS))] + + configs.sort(key=lambda c: ([t for t in c['timings'] if t is not None] or [10. ** 9])[-1], reverse=True) + + for config_idx, config in enumerate(configs): + config_pos = (float(len_configs - config_idx) - len_configs / 2) / len_configs + ax.plot(SEQ_LENS, config['timings'], label=config['name'] + '_' + str(config['tokens'] // 1000) + 'K_tokens' ) + for seq_len, timing, min_timing in zip(SEQ_LENS, config['timings'], min_timings): + if timing is not None: + if SEPARATE_ON_GRAPH: + timing_height = timing * (1. 
+ 0.1 * config_pos) + else: + timing_height = timing + tflops = tflops_from_ms(timing, seq_len, config['tokens']) + max_tflops = tflops_from_ms(min_timing, seq_len, config['tokens']) + percentage_of_max = tflops / max_tflops + plt.text(seq_len, timing_height, f"{config['name']}: TFLOPs={round(tflops, 1)} \ +({round(percentage_of_max * 100, 2)}% max) - {round(timing, 1)}ms") + plt.title(f'Fused Attention ({"Causal" if causal else "Non-Causal"})') + plt.ylabel('time (ms)') + plt.xlabel('Sequence Length') + ax.set_xticks(SEQ_LENS, SEQ_LENS) + plt.yscale("log") + plt.legend() + plt.show() + +if __name__ == '__main__': + test_allclose() + benchmark() + + diff --git a/jax/_src/pallas/primitives.py b/jax/_src/pallas/primitives.py index 74e9d8fb36b0..4712b4502250 100644 --- a/jax/_src/pallas/primitives.py +++ b/jax/_src/pallas/primitives.py @@ -398,7 +398,7 @@ def store(x_ref, idx, val, *, mask=None, eviction_policy="") -> None: _ = swap(x_ref, idx, val, mask=mask, eviction_policy=eviction_policy) def dot(a, b, trans_a: bool = False, trans_b: bool = False, - allow_tf32: bool | None = None, precision=None): + allow_tf32: bool | None = None, precision=None, out_dtype=None): lhs_contract_dim = 0 if trans_a else 1 rhs_contract_dim = 0 if not trans_b else 1 if allow_tf32 is not None: @@ -408,4 +408,4 @@ def dot(a, b, trans_a: bool = False, trans_b: bool = False, return jax.lax.dot_general( a, b, dimension_numbers=(((lhs_contract_dim,), (rhs_contract_dim,)), ((), ())), precision=precision, - preferred_element_type=None).astype(jnp.float32) + preferred_element_type=out_dtype).astype(out_dtype or jnp.float32) diff --git a/jax/_src/pallas/triton/lowering.py b/jax/_src/pallas/triton/lowering.py index e9b30d7d80f7..2edb2eb5e67f 100644 --- a/jax/_src/pallas/triton/lowering.py +++ b/jax/_src/pallas/triton/lowering.py @@ -391,6 +391,13 @@ def _exp_lowering_rule(ctx: TritonLoweringRuleContext, a): triton_lowering_rules[lax.exp_p] = _exp_lowering_rule +def _exp2_lowering_rule(ctx: TritonLoweringRuleContext, a): + return tl.math.exp2(a, _builder=ctx.builder) + + +triton_lowering_rules[lax.exp2_p] = _exp2_lowering_rule + + def _log_lowering_rule(ctx: TritonLoweringRuleContext, a): return tl.log(a, _builder=ctx.builder) @@ -1013,8 +1020,11 @@ def _dot_general_lowering( allow_tf32 = ( precision == lax.Precision.HIGH or precision == lax.Precision.DEFAULT ) - return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32) - + if preferred_element_type == jnp.float16: + out_dtype = tl.float16 + else: + out_dtype = tl.float32 + return tl.dot(a, b, _builder=ctx.builder, allow_tf32=allow_tf32, out_dtype=out_dtype) triton_lowering_rules[lax.dot_general_p] = _dot_general_lowering @@ -1689,13 +1699,13 @@ def pallas_call_lowering( kernel_call_proto = kernel_call.to_proto(serialized_metadata) return hlo_helpers.custom_call( call_target_name="triton_kernel_call", - result_types=out_types, + out_types=out_types, operands=in_nodes, backend_config=zlib.compress(kernel_call_proto), operand_layouts=triton_lib.avals_to_layouts(ctx.avals_in), result_layouts=triton_lib.avals_to_layouts(ctx.avals_out), operand_output_aliases=dict(input_output_aliases), - ).results + ) mlir.register_lowering(pallas_call_p, pallas_call_lowering, platform="cuda") diff --git a/jax/experimental/pallas/ops/attention.py b/jax/experimental/pallas/ops/attention.py index 0a9ef0de3f72..6230d755ed64 100644 --- a/jax/experimental/pallas/ops/attention.py +++ b/jax/experimental/pallas/ops/attention.py @@ -26,6 +26,10 @@ DEFAULT_MASK_VALUE = -0.7 * 
float(np.finfo(np.dtype("float32")).max) +DELAYED_SOFTMAX_NORMALIZE = True +USE_UNMASKED_LOOP_BODY = False +ALLOW_QK_FP16_ACC = False +ALLOW_PV_FP16_ACC = True def mha_forward_kernel( q_ref, @@ -43,17 +47,25 @@ def mha_forward_kernel( seq_len = q_ref.shape[0] start_q = pl.program_id(0) + pv_acc_dtype = o_ref.dtype if ALLOW_PV_FP16_ACC else jnp.float32 + qk_acc_dtype = q_ref.dtype if ALLOW_QK_FP16_ACC else jnp.float32 + # acc is the buffer where we accumulate the output on sram. # m_i and l_i (see FlashAttention paper) are updated during the k,v loop. m_i = jnp.zeros(block_q, dtype=jnp.float32) - float('inf') l_i = jnp.zeros(block_q, dtype=jnp.float32) # acc is the buffer where we accumulate the output on sram. - acc = jnp.zeros((block_q, block_d), dtype=jnp.float32) + acc = jnp.zeros((block_q, block_d), pv_acc_dtype) # Load q: it will stay in L1 throughout. Indices form a matrix because we # read, compute, and write all in 2d chunks. 1 element ~= 1 CUDA thread index. # q tile has shape [block_q, block_d], block_d == head_dim. q = pl.load(q_ref, (pl.dslice(start_q * block_q, block_q), pl.dslice(None))) + stability_factor = jnp.log2(seq_len) if DELAYED_SOFTMAX_NORMALIZE else 0. + q_scale = sm_scale * 1.44269504089 + if q_scale != 1.: + q *= q_scale + q_segment_ids = ( None if segment_ids_ref is None @@ -63,56 +75,71 @@ def mha_forward_kernel( # (Bc == block_k here), and fast over blocks of q (size Br == block_q here). # Here we only loop over blocks of kv to process entire seq_len, the loop over # blocks of q is carried out by the grid. - def body(start_k, carry): + def body(start_k, carry, masked): acc, m_prev, l_prev = carry k = pl.load(k_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) - kv_segment_ids = ( - None - if segment_ids_ref is None - else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) - ) - qk = jnp.zeros([block_q, block_k], dtype=jnp.float32) - qk += pl.dot(q, k.T) # [block_q, block_k] - if sm_scale != 1.: - qk *= sm_scale # [block_q, block_k] - - # Bring closer to XLA:GPU numerics. - qk = qk.astype(q_ref.dtype) - qk = qk.astype(jnp.float32) - - if causal or segment_ids_ref is not None: - mask = None - if segment_ids_ref is not None: - mask = segment_mask(q_segment_ids, kv_segment_ids) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - span_k = start_k * block_k + jnp.arange(block_k) - causal_mask = span_q[:, None] >= span_k[None, :] - mask = ( - causal_mask if mask is None else jnp.logical_and(mask, causal_mask) - ) - # Apply mask to qk. - qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), slice(None))) + if masked: + kv_segment_ids = ( + None + if segment_ids_ref is None + else pl.load(segment_ids_ref, (pl.dslice(start_k * block_k, block_k),)) + ) + qk = pl.dot(q, k.T, out_dtype=qk_acc_dtype).astype(q_ref.dtype) # [block_q, block_k] + # Bring closer to XLA:GPU numerics. + qk = qk.astype(jnp.float32) + if causal or segment_ids_ref is not None: + mask = None + if segment_ids_ref is not None: + mask = segment_mask(q_segment_ids, kv_segment_ids) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + span_k = start_k * block_k + jnp.arange(block_k) + causal_mask = span_q[:, None] >= span_k[None, :] + mask = ( + causal_mask if mask is None else jnp.logical_and(mask, causal_mask) + ) + # Apply mask to qk. 
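# (Aside: DEFAULT_MASK_VALUE used on the next line is a large finite negative
# number, -0.7 * float32 max, rather than -inf; a plausible reason is that a
# finite value keeps qk - m_curr well defined even when an entire row of a
# block is masked, where -inf - (-inf) would produce NaN, while masked
# probabilities still flush to zero after the exponential.)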
+ qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) + else: + qk = pl.dot(q, k.T, out_dtype=qk_acc_dtype).astype(q_ref.dtype) # [block_q, block_k] + # Bring closer to XLA:GPU numerics. + qk = qk.astype(jnp.float32) m_curr = jnp.maximum(jnp.max(qk, axis=1), m_prev) - l_prev *= jnp.exp(m_prev - m_curr) - p = jnp.exp(qk - m_curr[:, None]) - l_curr = jnp.sum(p, axis=1) + l_prev + alpha = jnp.exp2(m_prev - m_curr) + p = jnp.exp2(qk - m_curr[:, None]) - l_rcp = 1. / l_curr - p = p * l_rcp[:, None] - acc *= (l_prev * l_rcp)[:, None] - p = p.astype(jnp.float16) + if DELAYED_SOFTMAX_NORMALIZE: + l_curr = jnp.sum(p, axis=1) + alpha * l_prev - v = pl.load(v_ref, (pl.dslice(start_k * block_k, block_k), pl.dslice(block_d))) - acc = acc + pl.dot(p.astype(v.dtype), v) + acc = (acc * alpha[:, None]).astype(acc.dtype) + else: + l_prev *= alpha + l_curr = jnp.sum(p, axis=1) + l_prev + + l_rcp = 1. / l_curr + p = p * l_rcp[:, None] + + acc = (acc * (l_prev * l_rcp)[:, None]).astype(acc.dtype) + + p = p.astype(jnp.float16) + acc += pl.dot(p.astype(v.dtype), v, out_dtype=acc.dtype) return acc, m_curr, l_curr if causal: - upper_bound = lax.div(block_q * start_q, block_k) + 1 + # Ceildiv (`pl.cdiv` and `//` do not work due to type of start_q) + upper_bound = lax.div(block_q * (start_q + 1) + block_k - 1, block_k) + causal_lower_bound = lax.div(block_q * start_q, block_k) if USE_UNMASKED_LOOP_BODY else upper_bound else: upper_bound = pl.cdiv(seq_len, block_k) # type: ignore - acc, m_i, l_i = lax.fori_loop(0, upper_bound, body, + causal_lower_bound = upper_bound + must_mask = segment_ids_ref is not None + acc, m_i, l_i = lax.fori_loop(causal_lower_bound, upper_bound, functools.partial(body, masked=causal or must_mask), (acc, m_i, l_i)) + acc, m_i, l_i = lax.fori_loop(0, causal_lower_bound, functools.partial(body, masked=must_mask), + (acc, m_i, l_i)) + if DELAYED_SOFTMAX_NORMALIZE: + acc = acc / l_i[:, None] if residual_refs: l_ref, m_ref = residual_refs @@ -138,7 +165,7 @@ def segment_mask( @functools.partial( - jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13] + jax.custom_vjp, nondiff_argnums=[4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14] ) @functools.partial( jax.jit, @@ -151,6 +178,7 @@ def segment_mask( "num_warps", "num_stages", "grid", + "swap_seq_axis", "interpret", "debug", ], @@ -165,6 +193,7 @@ def mha( block_q: int = 128, block_k: int = 128, backward_pass_impl: str = "triton", + swap_seq_axis: bool = False, num_warps: Optional[int] = None, num_stages: int = 2, grid=None, @@ -172,7 +201,13 @@ def mha( debug: bool = False, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) # Heuristics. 
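The loop body above is the FlashAttention online softmax rewritten in base 2 (folding `sm_scale * log2(e)` into q) with the 1/l normalization deferred to a single division after the loop when `DELAYED_SOFTMAX_NORMALIZE` is set; the other branch instead rescales `p` and `acc` by `1/l_curr` on every iteration, which is algebraically equivalent. Below is a minimal pure-JAX sketch of that rescaling math, checked against a plain softmax reference; the block size, shapes, and float32 dtype are illustrative assumptions, not the kernel's settings.

import jax
import jax.numpy as jnp

def online_softmax_attention(q, k, v, sm_scale=1.0, block_k=32):
  # Fold sm_scale * log2(e) into q so exp2 can be used inside the loop,
  # mirroring `qk_scale = sm_scale * 1.44269504` in the kernels above.
  q = q * (sm_scale * 1.44269504089)
  m = jnp.full((q.shape[0],), -jnp.inf, dtype=jnp.float32)
  l = jnp.zeros((q.shape[0],), dtype=jnp.float32)
  acc = jnp.zeros((q.shape[0], v.shape[-1]), dtype=jnp.float32)
  for start in range(0, k.shape[0], block_k):
    kb, vb = k[start:start + block_k], v[start:start + block_k]
    qk = q @ kb.T                          # already includes the log2(e) factor
    m_new = jnp.maximum(m, qk.max(axis=1))
    alpha = jnp.exp2(m - m_new)            # rescale previously accumulated sums
    p = jnp.exp2(qk - m_new[:, None])
    acc = acc * alpha[:, None] + p @ vb    # unnormalized running output
    l = l * alpha + p.sum(axis=1)
    m = m_new
  return acc / l[:, None]                  # delayed softmax normalization

def reference_attention(q, k, v, sm_scale=1.0):
  return jax.nn.softmax(sm_scale * (q @ k.T), axis=-1) @ v

k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (128, 64), dtype=jnp.float32)
k = jax.random.normal(k2, (128, 64), dtype=jnp.float32)
v = jax.random.normal(k3, (128, 64), dtype=jnp.float32)
print(jnp.max(jnp.abs(online_softmax_attention(q, k, v) -
                      reference_attention(q, k, v))))  # expected to be tiny (~1e-6)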
@@ -187,18 +222,7 @@ def mha( block_q=block_q, block_k=block_k, block_d=head_dim, causal=causal) - - in_specs = [ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - ] + in_specs = [qkv_block_spec for _ in range(3)] in_specs.append( None # type: ignore[arg-type] if segment_ids is None @@ -209,9 +233,7 @@ def mha( kernel, grid=grid_, in_specs=in_specs, - out_specs=pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + out_specs=qkv_block_spec, num_warps=num_warps_, num_stages=num_stages, out_shape=out_shape, @@ -231,6 +253,7 @@ def _mha_forward( block_q: int, block_k: int, backward_pass_impl: str, + swap_seq_axis: bool, num_warps: Optional[int], num_stages: int, grid: Any, @@ -238,7 +261,13 @@ def _mha_forward( debug: bool, ): del backward_pass_impl - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) # Heuristics. @@ -259,17 +288,7 @@ def _mha_forward( jax.ShapeDtypeStruct(shape=(batch_size, num_heads, seq_len), # m dtype=jnp.float32) ] - in_specs = [ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - ] + in_specs = [qkv_block_spec for _ in range(3)] in_specs.append( None # type: ignore[arg-type] if segment_ids is None @@ -280,9 +299,7 @@ def _mha_forward( grid=grid_, in_specs=in_specs, out_specs=[ - pl.BlockSpec( - lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + qkv_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], @@ -314,9 +331,14 @@ def _preprocess_backward_kernel(out_ref, dout_ref, l_ref, do.astype(new_dout_ref.dtype)) pl.store(delta_ref, (off_m,), delta.astype(delta_ref.dtype)) -def _preprocess_backward(out, do, l, block_q: int, +def _preprocess_backward(out, do, l, block_q: int, swap_seq_axis: bool, debug: bool, interpret: bool): - batch_size, seq_len, num_heads, head_dim = out.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = out.shape + out_block_spec = pl.BlockSpec(lambda _, j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = out.shape + out_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) out_shape = [ jax.ShapeDtypeStruct(do.shape, do.dtype), jax.ShapeDtypeStruct(l.shape, l.dtype), @@ -325,12 +347,12 @@ def _preprocess_backward(out, do, l, block_q: int, functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + out_block_spec, + out_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), 
], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + out_block_spec, pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), ], num_warps=4, @@ -384,11 +406,13 @@ def outer_loop(start_k, _): def inner_loop(start_q, carry): dv, dk = carry q = pl.load(q_ref, (pl.ds(start_q * block_q, block_q), slice(None))) - qk = pl.dot(q, k.T) + + q_scale = sm_scale * 1.44269504089 + if q_scale != 1.0: + q_scaled = q * q_scale + qk = pl.dot(q_scaled, k.T) qk = qk.astype(q_ref.dtype) qk = qk.astype(jnp.float32) - if sm_scale != 1.0: - qk *= sm_scale q_segment_ids = ( None @@ -412,7 +436,7 @@ def inner_loop(start_q, carry): qk = jnp.where(mask, qk, DEFAULT_MASK_VALUE) m = pl.load(m_ref, (pl.ds(start_q * block_q, block_q),)) - p = jnp.exp(qk - m[:, None]) + p = jnp.exp2(qk - m[:, None]) do = pl.load(do_scaled_ref, (pl.ds(start_q * block_q, block_q), slice(None))) dv = dv + pl.dot(p.astype(do.dtype).T, do) di = pl.load(delta_ref, (pl.ds(start_q * block_q, block_q),)) @@ -442,18 +466,25 @@ def inner_loop(start_q, carry): def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, - backward_pass_impl: str, num_warps: Optional[int], + backward_pass_impl: str, swap_seq_axis: bool, num_warps: Optional[int], num_stages: int, grid: Any, interpret: bool, debug: bool, res, do): del num_warps, num_stages, grid q, k, v, segment_ids, out, l, m = res - batch_size, seq_len, num_heads, head_dim = q.shape + if swap_seq_axis: + batch_size, num_heads, seq_len, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda j, k: (j, k, 0, 0), (None, None, seq_len, head_dim)) + else: + batch_size, seq_len, num_heads, head_dim = q.shape + qkv_block_spec = pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)) + block_q = min(block_q, seq_len) block_k = min(block_k, seq_len) - do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) + do_scaled, delta = _preprocess_backward(out, do, l, block_q, swap_seq_axis, debug, interpret) if backward_pass_impl == "xla": + # TODO(jon-chuang): Handle the `swap_seq_axis=True` case for "xla" return jax.vjp( functools.partial(mha_reference, sm_scale=sm_scale, causal=causal), q, @@ -470,28 +501,11 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, jax.ShapeDtypeStruct(v.shape, v.dtype), ] - in_specs = [ - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + in_specs = [qkv_block_spec for _ in range(5)] + [ pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), + qkv_block_spec, ] if segment_ids is None: in_specs.insert(3, None) # type: ignore[arg-type] @@ -514,17 +528,7 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int, grid=grid, out_shape=out_shapes, in_specs=in_specs, - out_specs=[ - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim) - ), - pl.BlockSpec( - lambda j, k: (j, 0, 
k, 0), (None, seq_len, None, head_dim) - ), - ], + out_specs=[qkv_block_spec for _ in range(3)], name="mha_backward", debug=debug, interpret=interpret, diff --git a/tests/pallas/pallas_test.py b/tests/pallas/pallas_test.py index 452f0361cfba..ae2e4767e4cc 100644 --- a/tests/pallas/pallas_test.py +++ b/tests/pallas/pallas_test.py @@ -1561,7 +1561,7 @@ def f_ref(q, k, v): dq, dk, dv = jax.grad(f, argnums=(0, 1, 2))(q, k, v) dq_ref, dk_ref, dv_ref = jax.grad(f_ref, argnums=(0, 1, 2))(q, k, v) np.testing.assert_allclose(dq, dq_ref, atol=0.1) - np.testing.assert_allclose(dk, dk_ref, atol=0.08) + np.testing.assert_allclose(dk, dk_ref, atol=0.1) np.testing.assert_allclose(dv, dv_ref, atol=0.05)
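Finally, `test_allclose` in the benchmark script above is left as a TODO. A sketch of such a check against `attention.mha_reference`, exercising the new `swap_seq_axis` layout, is given below; it requires a GPU, and the helper name, tolerance, and the transpose used for the comparison (the reference path is not layout-aware, per the TODO in `_mha_backward`) are assumptions rather than part of this patch.

import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops import attention

def check_mha_matches_reference(batch=2, heads=4, seq_len=512, head_dim=64,
                                causal=True, swap_seq_axis=False):
  """Compares the Pallas forward pass against the pure-JAX reference."""
  k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
  shape = ((batch, heads, seq_len, head_dim) if swap_seq_axis
           else (batch, seq_len, heads, head_dim))
  q, k, v = (jax.random.normal(key, shape, dtype=jnp.float16)
             for key in (k1, k2, k3))
  out = attention.mha(q, k, v, segment_ids=None, causal=causal,
                      swap_seq_axis=swap_seq_axis)
  if swap_seq_axis:
    # The reference expects [batch, seq_len, heads, head_dim].
    q, k, v, out = (x.swapaxes(1, 2) for x in (q, k, v, out))
  ref = attention.mha_reference(q, k, v, segment_ids=None, causal=causal)
  np.testing.assert_allclose(out, ref, atol=0.05)

if __name__ == "__main__":
  for swap in (False, True):
    check_mha_matches_reference(swap_seq_axis=swap)
    print(f"swap_seq_axis={swap}: forward pass matches reference")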