Add GPU Pallas flash kernel.
jwyang-google committed Feb 24, 2025
1 parent 5faba12 commit 963519b
Showing 3 changed files with 24 additions and 16 deletions.
2 changes: 1 addition & 1 deletion .github/CODEOWNERS
@@ -1,2 +1,2 @@
 # Changes in this file should match with requiredReviewers in file .github/workflows/AddLabel.yml
-* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan @richjames0 @rni418 @gagika
+* @gobbleturk @khatwanimohit @bvandermoon @vipannalla @RissyRan
5 changes: 1 addition & 4 deletions .github/workflows/AddLabel.yml
@@ -51,16 +51,13 @@ jobs:
         process.exit(1)
       }
-      // This list should match with CODEOWNERS.
+      // This list should match with CODEOWNERS
       let requiredReviewers = {
         gobbleturk: "",
         khatwanimohit: "",
         bvandermoon: "",
         vipannalla: "",
         RissyRan: "",
-        richjames0: "",
-        rni418: "",
-        gagika: "",
       }
       const reviews = await github.rest.pulls.listReviews({
         owner,
33 changes: 22 additions & 11 deletions MaxText/layers/attentions.py
@@ -26,6 +26,7 @@
 from jax.experimental import shard_map
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_kernel
 from jax.experimental.pallas.ops.tpu.splash_attention import splash_attention_mask
+from jax.experimental.pallas.ops.gpu import attention
 import jax.numpy as jnp
 import common_types
 from kernels.ragged_attention import ragged_gqa
@@ -237,17 +238,27 @@ def apply_attention(
     ):
       return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
     elif self.attention_kernel == "flash" or self.attention_kernel == "autoselected":
-      if isinstance(key, KVTensor):
-        key = key.dequant()
-      if isinstance(value, KVTensor):
-        value = value.dequant()
-
-      if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
-        raise ValueError(
-            """Decode not supported with flash attention.
-            Use `dot_product` instead."""
-        )
-      return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
+      if jax.devices()[0].platform == "tpu":
+        if isinstance(key, KVTensor):
+          key = key.dequant()
+        if isinstance(value, KVTensor):
+          value = value.dequant()
+
+        if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+          raise ValueError(
+              """Decode not supported with flash attention.
+              Use `dot_product` instead."""
+          )
+        return self.tpu_flash_attention(query, key, value, decoder_segment_ids, self.attn_logits_soft_cap), None, None
+      else:
+        if model_mode == common_types.MODEL_MODE_AUTOREGRESSIVE:
+          # Fall back to dot_product; the Pallas GPU flash attention kernel does not support the decode stage.
+          return self.apply_attention_dot(query, key, value, decoder_segment_ids, model_mode)
+        else:
+          key = jnp.repeat(key, self.num_query_heads // self.num_kv_heads, axis=2)
+          value = jnp.repeat(value, self.num_query_heads // self.num_kv_heads, axis=2)
+          out = attention.mha(query, key, value, decoder_segment_ids, sm_scale=1.0, causal=True)
+          return out, None, None
     elif self.attention_kernel == "cudnn_flash_te":
       if isinstance(key, KVTensor):
         key = key.dequant()
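For reference, the snippet below is a minimal standalone sketch of the new GPU branch; it is not part of the commit. It assumes the [batch, seq_len, num_heads, head_dim] layout used by MaxText, random float16 inputs, segment_ids=None (no sequence packing), and an NVIDIA GPU where the Pallas/Triton kernel is available; the shapes and dtype are purely illustrative.

import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu import attention

batch, seq_len, num_query_heads, num_kv_heads, head_dim = 2, 1024, 8, 2, 128

rng = jax.random.PRNGKey(0)
q_rng, k_rng, v_rng = jax.random.split(rng, 3)
query = jax.random.normal(q_rng, (batch, seq_len, num_query_heads, head_dim), dtype=jnp.float16)
key = jax.random.normal(k_rng, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)
value = jax.random.normal(v_rng, (batch, seq_len, num_kv_heads, head_dim), dtype=jnp.float16)

# The kernel expects query, key, and value to have the same number of heads, so the
# grouped KV heads are repeated up to the query head count (axis 2 is the head axis),
# mirroring the jnp.repeat calls in the diff above.
key = jnp.repeat(key, num_query_heads // num_kv_heads, axis=2)
value = jnp.repeat(value, num_query_heads // num_kv_heads, axis=2)

# segment_ids=None disables the packing mask; causal=True applies the autoregressive mask.
out = attention.mha(query, key, value, None, sm_scale=1.0, causal=True)
print(out.shape)  # (2, 1024, 8, 128)

Note that the commit passes sm_scale=1.0, so no additional softmax scaling is applied inside the kernel, and autoregressive decode still takes the dot_product path because the GPU kernel has no decode support.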
