flashinfer: remove contiguous calls
danieldk committed Jan 6, 2025
1 parent 02e3dc4 commit 6defe57
Showing 1 changed file with 2 additions and 4 deletions.
server/text_generation_server/layers/attention/cuda.py: 2 additions & 4 deletions
@@ -60,8 +60,7 @@ def paged_attention(
     from text_generation_server.layers.attention.flashinfer import decode_state
 
     return decode_state.get().forward(
-        # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
-        query.contiguous(),
+        query,
         paged_kv_cache=(kv_cache.key, kv_cache.value),
         logits_soft_cap=softcap,
         sm_scale=softmax_scale,
@@ -231,8 +230,7 @@ def attention(
     softcap = 0.0
 
     return prefill_with_paged_kv_state.get().forward(
-        # TODO: remove `contiguous` call once https://github.com/flashinfer-ai/flashinfer/pull/553 is merged.
-        query.contiguous(),
+        query,
         causal=causal,
         paged_kv_cache=(kv_cache.key, kv_cache.value),
         logits_soft_cap=softcap,
