[WIP] chunked prefill #1272
base: main
Conversation
    complete_prompt_padded_length: Optional[int] = None,
    positions: Optional[jax.Array] = None,
    previous_chunk: Optional[Any] = None,
) -> Tuple[Prefix, engine_api.ResultTokens]:
You also need to update the prefill_aot() method.
Also, please update the docbook below.
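Not part of this PR, just for context: a minimal sketch of how a caller might drive the new signature above, assuming the pre-existing params / padded_tokens / true_length arguments and a hypothetical chunk_size; threading the returned Prefix back in as previous_chunk is my reading of the diff, not something the PR confirms.

import jax.numpy as jnp

def chunked_prefill_loop(engine, params, tokens, true_length, chunk_size):
  """Hypothetical driver: feed the padded prompt to prefill() one chunk at a time."""
  previous_chunk, result = None, None
  num_chunks = -(-tokens.shape[0] // chunk_size)  # ceiling division
  for i in range(num_chunks):
    chunk = tokens[i * chunk_size : (i + 1) * chunk_size]
    # Every chunk except the last is fully valid, so its true length equals
    # its padded length; only the last chunk carries padding.
    chunk_true_length = max(0, min(chunk_size, true_length - i * chunk_size))
    previous_chunk, result = engine.prefill(
        params=params,
        padded_tokens=chunk,
        true_length=chunk_true_length,
        complete_prompt_true_length=true_length,
        complete_prompt_padded_length=tokens.shape[0],
        positions=jnp.arange(i * chunk_size, (i + 1) * chunk_size),
        previous_chunk=previous_chunk,
    )
  return previous_chunk, result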
  zero_to_n = jnp.arange(0, padded_tokens.shape[0])
else:
  zero_to_n = jnp.arange(0, complete_prompt_padded_length.shape[0])
ones_to_keep = zero_to_n < complete_prompt_true_length
How is complete_prompt_true_length different from the existing true_length param? Can you re-use the true_length param for chunking too? For all chunks, true_length == padded_length, except for the last chunk.
complete_prompt_true_length is the true length of the complete prompt (without padding); we need it to set the decoder segment ids, which are used in the decode step.
true_length is the true length of the current chunk.
Yes, correct: true_length == padded_length for all chunks except the last chunk. To avoid an if statement checking whether we are on the last chunk, I preferred to pass true_length. Let me know if you have concerns.
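For onlookers, a minimal sketch of how the two lengths interact when building decoder segment ids, following the zero_to_n / ones_to_keep lines quoted above (the function wrapper and batch-dimension handling are illustrative, not the PR's exact code):

import jax.numpy as jnp

def make_decoder_segment_ids(complete_prompt_padded_length, complete_prompt_true_length):
  # Positions over the complete (padded) prompt, not just the current chunk.
  zero_to_n = jnp.arange(0, complete_prompt_padded_length)
  # Positions before the true length belong to the real prompt ...
  ones_to_keep = zero_to_n < complete_prompt_true_length
  # ... and become segment id 1; padding positions become 0.
  return ones_to_keep.astype(jnp.int32)[None, :]  # add a batch dimension

# e.g. make_decoder_segment_ids(8, 5) -> [[1, 1, 1, 1, 1, 0, 0, 0]]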
  cached_prefill_key_vars[0].value = jax.lax.dynamic_update_slice(cached_key_value, key_shaped_for_cache, (next_pos, 0, 0, 0))
  cached_prefill_value_vars[0].value = jax.lax.dynamic_update_slice(cached_value_value, value_shaped_for_cache, (next_pos, 0, 0, 0))
  cached_prefill_segment_id_var.value = decoder_segment_ids
  return jnp.transpose(cached_prefill_key_vars[0].value, (2,0,1,3)), jnp.transpose(cached_prefill_value_vars[0].value, (2,0,1,3)), cached_prefill_segment_id_var.value
else:
  cached_prefill_key_vars[0].value = jax.lax.dynamic_update_slice(cached_prefill_key_vars[0].value, key_shaped_for_cache, (next_pos, 0, 0, 0))
  cached_prefill_value_vars[0].value = jax.lax.dynamic_update_slice(cached_prefill_value_vars[0].value, value_shaped_for_cache, (next_pos, 0, 0, 0))
  cached_prefill_segment_id_var.value = decoder_segment_ids
  return jnp.transpose(cached_prefill_key_vars[0].value, (2,0,1,3)), jnp.transpose(cached_prefill_value_vars[0].value, (2,0,1,3)), cached_prefill_segment_id_var.value
Can you generalize the values used (e.g., (2,0,1,3)) instead of hardcoding them?
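One possible shape for that (a sketch only; the layout tuples below are illustrative stand-ins for wherever the cache axis order is actually defined, not existing MaxText names):

import jax.numpy as jnp

# Illustrative axis orders: the prefill cache stores (sequence, heads, batch, kv)
# while attention consumes (batch, sequence, heads, kv).
CACHE_LAYOUT = ("sequence", "heads", "batch", "kv")
ATTENTION_LAYOUT = ("batch", "sequence", "heads", "kv")

# Derive the permutation once instead of hardcoding (2,0,1,3).
CACHE_TO_ATTENTION_PERM = tuple(CACHE_LAYOUT.index(axis) for axis in ATTENTION_LAYOUT)

def cache_to_attention(cached_array):
  # Equivalent to jnp.transpose(cached_array, (2,0,1,3)) under the layouts above.
  return jnp.transpose(cached_array, CACHE_TO_ATTENTION_PERM)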
@@ -786,11 +817,30 @@ def kv_cache_prefill(
  cached_prefill_key_vars[0].value = key_shaped_for_cache
  cached_prefill_value_vars[0].value = value_shaped_for_cache

  if self.config.use_chunked_prefill:
Please add a code comment to describe what's being done below.
Sure, will send another version with updated comments.
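In the meantime, here is a sketch of the comment plus the update logic as I understand it from the diff (pulled out into a standalone function for illustration; the real code mutates the cache variables in place):

import jax

def update_prefill_cache_for_chunk(cached_key, cached_value, key_chunk, value_chunk, next_pos):
  # Chunked prefill: the cache already holds keys/values for the chunks
  # processed so far in positions [0, next_pos) of the sequence axis.
  # Write this chunk's keys/values at offset next_pos instead of
  # overwriting the whole cache, leaving earlier chunks intact.
  new_key = jax.lax.dynamic_update_slice(cached_key, key_chunk, (next_pos, 0, 0, 0))
  new_value = jax.lax.dynamic_update_slice(cached_value, value_chunk, (next_pos, 0, 0, 0))
  return new_key, new_value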
Description
The current CL is WIP and supports chunked prefill for the unquantized KV cache.
attentions.py has two main changes:
i) fetching the previous cache if there are already-processed chunks;
ii) generating attention masks for a particular chunk, following this article: https://medium.com/byte-sized-ai/llm-inference-optimizations-2-chunked-prefill-764407b3a67a (see the mask sketch after this list).
Minor changes: wiring previous-chunk information into the decoder, model, and attention layers.
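For reviewers who have not read the article, a minimal sketch of the chunk attention mask it describes (names and shapes are illustrative, not this PR's exact implementation): a query in the current chunk may attend to every token of earlier chunks and causally within its own chunk.

import jax.numpy as jnp

def chunk_attention_mask(chunk_start, chunk_len, total_kv_len):
  # Rows: the chunk's query positions [chunk_start, chunk_start + chunk_len).
  # Columns: all cached key positions [0, total_kv_len).
  q_pos = chunk_start + jnp.arange(chunk_len)[:, None]  # (chunk_len, 1)
  k_pos = jnp.arange(total_kv_len)[None, :]             # (1, total_kv_len)
  # A query at absolute position q may attend to keys at positions k <= q:
  # full attention over earlier chunks, causal within the current chunk.
  return k_pos <= q_pos

# e.g. chunk_attention_mask(2, 2, 4) ->
# [[ True,  True,  True, False],
#  [ True,  True,  True,  True]]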
Tests
Checklist
Before submitting this PR, please make sure (put X in square brackets):