fix for cosine sim attention
lucidrains committed Aug 20, 2022
1 parent 665966e commit 9a48f49
Showing 2 changed files with 6 additions and 10 deletions.
14 changes: 5 additions & 9 deletions flash_attention_jax/cosine_sim_flash_attention.py
@@ -47,7 +47,7 @@ def chunk_scanner(carries, _):

     (_, out, row_sum), _ = lax.scan(chunk_scanner, init = (0, out, row_sum), xs = None, length = math.ceil(k_len / K_CHUNK_SIZE))

-    out = out * (k_len / (row_sum + EPSILON)) # renormalize after acquiring all the correct row sums
+    out = out * (k_len / (row_sum + EPSILON)) # renormalize after acquiring all the correct row sums

     out = out.reshape(q_len, v_dim)
     row_sum = row_sum.reshape(q_len)
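For orientation, the hunk above is the tail of the chunked forward pass: each key chunk contributes an unnormalized partial output and a partial row sum, and the division by the row sum happens only once, after the scan has visited every chunk. A minimal sketch of that accumulation follows; the constant names mirror the diff (COSINE_SIM_SCALE, K_CHUNK_SIZE, EPSILON) but their values, the function name, and the omission of masking are assumptions, not code from the file.

import math
import jax.numpy as jnp
from jax import lax

COSINE_SIM_SCALE = 16   # assumed value; the point is that logits never exceed it
K_CHUNK_SIZE = 1024     # assumed chunk size
EPSILON = 1e-10

def chunked_cosine_sim_attention(q, k, v):
    # q and k are assumed already l2-normalized along the feature dimension
    q_len, dim, k_len, v_dim = *q.shape, *v.shape
    chunk = min(K_CHUNK_SIZE, k_len)

    def chunk_scanner(carries, _):
        chunk_idx, out, row_sum = carries
        k_chunk = lax.dynamic_slice(k, (chunk_idx, 0), (chunk, dim))
        v_chunk = lax.dynamic_slice(v, (chunk_idx, 0), (chunk, v_dim))

        sim = (q @ k_chunk.transpose()) * COSINE_SIM_SCALE
        exp_weights = jnp.exp(sim - COSINE_SIM_SCALE)   # constant upper bound stands in for a running max

        out = out + (exp_weights @ v_chunk) / k_len     # keep the partial output bounded
        row_sum = row_sum + exp_weights.sum(axis = -1, keepdims = True)
        return (chunk_idx + chunk, out, row_sum), None

    init = (0, jnp.zeros((q_len, v_dim)), jnp.zeros((q_len, 1)))
    (_, out, row_sum), _ = lax.scan(chunk_scanner, init, None, length = math.ceil(k_len / chunk))
    return out * (k_len / (row_sum + EPSILON))          # the renormalization shown in the hunk above

The k_len factor cancels the per-chunk division, so the result is the ordinary softmax-weighted sum of values.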
@@ -86,11 +86,9 @@ def flash_attention_forward(q, k, v, key_mask):
     out, (row_sum,) = cosine_sim_flash_attention_after_l2norm(q, k, v, key_mask)
     return out, (q, k, v, key_mask, out, row_sum)

-def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l, m):
+def _query_chunk_flash_attention_backward(q, k, v, key_mask,o, do, l):
     q_len, dim, k_len, v_dim = *q.shape, *v.shape

-    scale = 1 / jnp.sqrt(dim)
-
     def chunk_scanner(carries, _):
         chunk_idx, dq = carries
         k_chunk_sizes = min(K_CHUNK_SIZE, k_len)
@@ -111,7 +109,7 @@ def chunk_scanner(carries, _):
         dp = do @ v_chunk.transpose()

         D = jnp.sum(do * o, axis = -1, keepdims = True)
-        ds = p * scale * (dp - D)
+        ds = p * COSINE_SIM_SCALE * (dp - D)

         dq_chunk = ds @ k_chunk
         dk_chunk = ds.transpose() @ q
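The one functional change in this hunk is the scale on the score gradient: with l2-normalized queries and keys the logits are COSINE_SIM_SCALE * (q k^T), a fixed temperature, so the backward pass reuses that constant rather than the 1/sqrt(dim) factor that the deleted scale variable carried. A dense, unchunked reference for the same formula, offered as a hedged sketch rather than the repository's code (the function name is illustrative and the scale value is assumed):

import jax.numpy as jnp

COSINE_SIM_SCALE = 16  # assumed; must match whatever the forward pass uses

def reference_score_gradient(p, o, do, v):
    # p:  (q_len, k_len) softmax probabilities from the forward pass
    # o:  (q_len, v_dim) attention output, do: upstream gradient of o
    # v:  (k_len, v_dim) values
    dp = do @ v.transpose()                          # gradient w.r.t. the probabilities
    D = jnp.sum(do * o, axis = -1, keepdims = True)  # softmax Jacobian correction term
    ds = p * COSINE_SIM_SCALE * (dp - D)             # gradient w.r.t. the normalized q k^T scores
    return ds

From ds, the context lines then form dq_chunk = ds @ k_chunk and dk_chunk = ds.transpose() @ q, which the commit leaves unchanged.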
@@ -130,14 +128,13 @@ def chunk_scanner(carries, _):

 @jit
 def flash_attention_backward(res, do):
-    q, k, v, key_mask, o, l, m = res
+    q, k, v, key_mask, o, l = res

     q_len, dim = q.shape

     dk = jnp.zeros_like(k)
     dv = jnp.zeros_like(v)

-    m = m.reshape(q_len, 1)
     l = l.reshape(q_len, 1)

     def chunk_scanner(carries, _):
@@ -146,12 +143,11 @@ def chunk_scanner(carries, _):
         chunk_sizes = min(Q_CHUNK_SIZE, q_len)

         q_chunk = lax.dynamic_slice(q, (chunk_idx, 0), slice_sizes = (chunk_sizes, q.shape[-1]))
-        m_chunk = lax.dynamic_slice(m, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
         l_chunk = lax.dynamic_slice(l, (chunk_idx, 0), slice_sizes = (chunk_sizes, 1))
         o_chunk = lax.dynamic_slice(o, (chunk_idx, 0), slice_sizes = (chunk_sizes, o.shape[-1]))
         do_chunk = lax.dynamic_slice(do, (chunk_idx, 0), slice_sizes = (chunk_sizes, do.shape[-1]))

-        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk, m_chunk)
+        dq_chunk, dk_chunk, dv_chunk = _query_chunk_flash_attention_backward(q_chunk, k, v, key_mask, o_chunk, do_chunk, l_chunk)
         return (chunk_idx + chunk_sizes, dk + dk_chunk, dv + dv_chunk), dq_chunk

     (_, dk, dv), dq = lax.scan(chunk_scanner, init = (0, dk, dv), xs = None, length = math.ceil(q_len / Q_CHUNK_SIZE))
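Taken together, the hunks above remove the running row max m from the saved residuals, the backward entry point, and the per-chunk slicing. Because the cosine-similarity logits are bounded above by the fixed scale, the forward pass can stabilize the exponentials by subtracting that constant, so there is no data-dependent max to store or replay in the backward pass. A quick way to exercise the slimmer backward path is sketched below; the exported function name and the 2D, single-head shapes are assumptions, not taken from this page.

import jax
import jax.numpy as jnp
from flash_attention_jax import cosine_sim_flash_attention  # exported name assumed

key = jax.random.PRNGKey(0)
q, k, v = (jax.random.normal(k_, (128, 64)) for k_ in jax.random.split(key, 3))
key_mask = jnp.ones((128,), dtype = bool)

# a scalar loss, so jax.grad drives the custom backward pass under the hood
loss = lambda q, k, v: cosine_sim_flash_attention(q, k, v, key_mask).sum()
dq, dk, dv = jax.grad(loss, argnums = (0, 1, 2))(q, k, v)
print(dq.shape, dk.shape, dv.shape)   # each gradient matches its input shape, (128, 64)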
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'flash-attention-jax',
   packages = find_packages(exclude=[]),
-  version = '0.0.7',
+  version = '0.0.8',
   license='MIT',
   description = 'Flash Attention - in Jax',
   author = 'Phil Wang',
