diff --git a/flash_attention_jax/cosine_sim_flash_attention.py b/flash_attention_jax/cosine_sim_flash_attention.py
index 6f73b75..9a73ba7 100644
--- a/flash_attention_jax/cosine_sim_flash_attention.py
+++ b/flash_attention_jax/cosine_sim_flash_attention.py
@@ -47,7 +47,7 @@ def chunk_scanner(carries, _):
 
     (_, out, row_sum), _ = lax.scan(chunk_scanner, init = (0, out, row_sum), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))
 
-    out = out * (k_len / (row_sum + EPSILON)) # renormalize after acquiring all the correct row sums
+    out = out * (k_len / (row_sum + EPSILON)) # renormalize after acquiring all the correct row sums
 
     out = out.reshape(q_len, v_dim)
     row_sum = row_sum.reshape(q_len)
@@ -86,11 +86,9 @@ def flash_attention_forward(q, k, v, key_mask):
     out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)
 
-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape
 
-    scale = 1 / jnp.sqrt(dim)
-
     def chunk_scanner(carries, _):
         chunk_idx, dq = carries
         k_chunk_sizes = min(K_CHUNK_SIZE, k_len)
@@ -111,7 +109,7 @@ def chunk_scanner(carries, _):
         dp = do @ v_chunk.transpose()
         D = jnp.sum(do * o, axis = -1, keepdims = True)
 
-        ds = p * scale * (dp - D)
+        ds = p * COSINE_SIM_SCALE * (dp - D)
 
         dq_chunk = ds @ k_chunk
         dk_chunk = ds.transpose() @ q
@@ -130,14 +128,13 @@ def chunk_scanner(carries, _):
 
 @jit
 def flash_attention_backward(res, do):
-    q, k, v, key_mask, o, l, m = res
+    q, k, v, key_mask, o, l = res
 
     q_len, dim = q.shape
 
     dk = jnp.zeros_like(k)
     dv = jnp.zeros_like(v)
 
-    m = m.reshape(q_len, 1)
     l = l.reshape(q_len, 1)
 
     def chunk_scanner(carries, _):
@@ -146,12 +143,11 @@ def chunk_scanner(carries, _):
         chunk_sizes = min(Q_CHUNK_SIZE, q_len)
 
         q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, q.shape[-1]))
-        m_chunk = lax.dynamic_slice(m, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
         l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
         o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes = (chunk_sizes, o.shape[-1]))
         do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes = (chunk_sizes, do.shape[-1]))
 
-        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
+        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk)
         return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk
 
     (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
diff --git a/setup.py b/setup.py
index 1a5d826..8a338f3 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'flash-attention-jax',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Flash Attention - in Jax',
   author = 'Phil Wang',
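
A minimal, hedged sketch (not part of the patch above) of why the backward pass can drop the running row-max `m` and the per-head `scale = 1 / jnp.sqrt(dim)`: in cosine-sim attention, q and k are l2-normalized before the dot product, so every logit is bounded by the fixed scale and exp(logit - scale) never exceeds 1, making the max-trick bookkeeping unnecessary. The scale value (16), the `l2norm` helper, and the `cosine_sim_attention_reference` function below are illustrative assumptions, not code from the repo.

import jax
import jax.numpy as jnp

COSINE_SIM_SCALE = 16   # assumed for illustration; the library defines its own constant
EPSILON = 1e-10

def l2norm(t):
    # normalize rows to unit length, so every dot product lies in [-1, 1]
    return t / (jnp.linalg.norm(t, axis = -1, keepdims = True) + EPSILON)

def cosine_sim_attention_reference(q, k, v):
    q, k = l2norm(q), l2norm(k)
    sim = (q @ k.transpose()) * COSINE_SIM_SCALE    # bounded in [-scale, scale]
    attn = jnp.exp(sim - COSINE_SIM_SCALE)          # <= 1 everywhere, so no per-row max is tracked
    attn = attn / (attn.sum(axis = -1, keepdims = True) + EPSILON)
    return attn @ v

q_key, k_key, v_key = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(q_key, (128, 64))
k = jax.random.normal(k_key, (128, 64))
v = jax.random.normal(v_key, (128, 64))
out = cosine_sim_attention_reference(q, k, v)   # shape (128, 64)

The same bound is what lets the chunked forward pass above accumulate exp-weights per key block and renormalize only once at the end with k_len / (row_sum + EPSILON).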