Add left aligned cache support. #133
Changes from all commits: ada79a4, 3998c95, 37a8e89, 73da3cf, bce8a09, 93dae6e
@@ -65,7 +65,7 @@ class DecodeState:
       Tuple[jax.Array, jax.Array]
   ]  # only present in quantized kv
   current_position: int
-  lens: jax.Array  # [batch_size, 1]
+  lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
   input_pos: jax.Array  # [batch_size, 1] input pos for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
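To make the per-slot bookkeeping concrete, here is a minimal sketch of what these fields might hold for a left-aligned cache; the field names come from DecodeState above, but every value and shape below is hypothetical, not taken from the PR:

    import jax.numpy as jnp

    # Hypothetical example: batch_size=2, cache length 8.
    # Slot 0 holds a 3-token prompt and has generated 2 tokens; slot 1 is empty.
    start = jnp.array([[0], [0]])      # left aligned: every slot starts at position 0
    input_pos = jnp.array([[5], [0]])  # next write position per slot (3 prompt + 2 generated)
    lens = jnp.array([[2], [0]])       # output token length per slot
    # mask is 0 for valid positions and -inf for positions not yet filled
    mask = jnp.where(jnp.arange(8)[None, :] < input_pos, 0.0, float("-inf"))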
@@ -157,15 +157,17 @@ def _call_model_generate(
   ):
     if self.env.quant_config.enable_kv_quantization:
       caches_obj = [
-          cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
+          cache_manager.Int8KVCacheGenerate(
+              k, v, ks, vs, input_indexes, env=self.env
+          )
           for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
       ]
     else:
       caches_obj = [
           cache_manager.KVCacheGenerate(
-              k, v, input_indexes, self.cache_sharding
+              k, v, input_indexes, self.cache_sharding, env=self.env
           )
           for k, v in torchjax.to_torch(caches)
       ]
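The cache wrappers now receive env, presumably so the cache update can branch on env.ring_buffer. The cache_manager changes themselves are not part of this excerpt, so the sketch below only illustrates the kind of branch such a cache might contain; the class name, method, and shapes are assumptions, not the PR's code:

    import jax.numpy as jnp

    class SketchKVCache:
      """Illustrative only; not the PR's cache_manager code."""

      def __init__(self, cache, pos, env):
        self.cache = cache   # [batch, seq, dim]
        self.pos = pos       # scalar (ring buffer) or [batch] (left aligned)
        self.env = env

      def update(self, new_entry):  # new_entry: [batch, dim]
        if self.env.ring_buffer:
          # Ring buffer: every slot writes at the same, possibly wrapping, position.
          return self.cache.at[:, self.pos % self.cache.shape[1], :].set(new_entry)
        # Left aligned: each slot writes at its own position; no wrap is needed.
        batch = jnp.arange(self.cache.shape[0])
        return self.cache.at[batch, self.pos, :].set(new_entry)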
@@ -295,11 +297,16 @@ def _insert_no_wrap(
   ):
     scales = []
     caches = []
-    pos = decode_state.current_position - prefix.seq_len
+    if self.env.ring_buffer:
+      current_pos = decode_state.current_position
+    else:
+      current_pos = prefix.seq_len
+
+    pos = current_pos - prefix.seq_len
     tokens = decode_state.tokens.at[slot].set(prefix.token)

     x = jnp.arange(0, self.env.cache_sequence_length)
-    cond = jnp.logical_and(x <= decode_state.current_position, x >= pos)
+    cond = jnp.logical_and(x <= current_pos, x >= pos)
     mask_insert = jnp.where(cond, 0, float("-inf"))
     mask = decode_state.mask.at[slot].set(mask_insert)
     start = decode_state.start.at[slot].set(
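In the left-aligned (non-ring-buffer) path, current_pos is simply the prefill length, so pos becomes 0 and the inserted mask marks positions 0 through current_pos as valid. A small numeric sketch of that mask computation, using the same expressions as the hunk above with made-up sizes:

    import jax.numpy as jnp

    cache_sequence_length = 8   # illustrative
    prefill_len = 3             # illustrative prefix.seq_len

    # Left aligned: current_pos == prefill_len, so pos == 0.
    current_pos = prefill_len
    pos = current_pos - prefill_len

    x = jnp.arange(0, cache_sequence_length)
    cond = jnp.logical_and(x <= current_pos, x >= pos)
    mask_insert = jnp.where(cond, 0, float("-inf"))
    # mask_insert == [0, 0, 0, 0, -inf, -inf, -inf, -inf]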
@@ -470,18 +477,22 @@ def insert(
     # prefix,
     # decode_state,
     # )
-    start_insert = decode_state.current_position - prefix.seq_len
-    end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
-    return jax.lax.cond(
-        jnp.logical_and(
-            start_insert >= 0, end_insert < self.env.cache_sequence_length
-        ),
-        self._insert_no_wrap,
-        self._insert_wrap,
-        prefix,
-        decode_state,
-        slot,
-    )
+    if self.env.ring_buffer:
+      start_insert = decode_state.current_position - prefix.seq_len
+      end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
+      return jax.lax.cond(
+          jnp.logical_and(
+              start_insert >= 0, end_insert < self.env.cache_sequence_length
+          ),
+          self._insert_no_wrap,
+          self._insert_wrap,
+          prefix,
+          decode_state,
+          slot,
+      )
+    # Left aligned, starts from 0, guaranteed no wrap
+    else:
+      return self._insert_no_wrap(prefix, decode_state, slot)

   def precompute_ragged_block_indices(self, decode_state: DecodeState):
     """Precompute the ragged attention block indices. Ragged attention iterates the grid
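The ring-buffer path must handle the case where the prefix would straddle the end of the circular cache, which is why it dispatches through jax.lax.cond between the wrap and no-wrap kernels; the left-aligned path always writes starting at position 0, so it can call _insert_no_wrap directly. A standalone sketch of that wrap test, with illustrative numbers rather than the PR's values:

    import jax.numpy as jnp

    cache_sequence_length = 16   # illustrative
    padded_prefix_len = 8        # stands in for prefix.caches[0][0].shape[2]
    seq_len = 5                  # illustrative actual prefix length

    current_position = 3         # illustrative ring-buffer write head
    start_insert = current_position - seq_len      # -2: insert would wrap backwards
    end_insert = start_insert + padded_prefix_len  # 6
    no_wrap = jnp.logical_and(
        start_insert >= 0, end_insert < cache_sequence_length
    )
    # no_wrap is False here, so the ring-buffer path would take _insert_wrap;
    # the left-aligned path never reaches this check because it always inserts at 0.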
@@ -545,10 +556,13 @@ def generate(
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
     pos = decode_state.current_position
-    input_indexes = jnp.full((1,), pos)
-
-    # fill mask first
-    mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    if self.env.ring_buffer:
+      input_indexes = jnp.full((1,), pos)
+      mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    else:
+      input_indexes = decode_state.input_pos
+      batch = jnp.arange(self.env.batch_size)
+      mask = decode_state.mask.at[batch, decode_state.input_pos].set(0)
     ragged_batch_index, ragged_block_index = (
         self.precompute_ragged_block_indices(decode_state)
     )
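In the ring-buffer case every slot decodes at the same current_position, so a single column of the mask is opened; in the left-aligned case each slot sits at its own input_pos, so the update indexes one position per batch row. A minimal sketch of that per-row update, with illustrative shapes and values:

    import jax.numpy as jnp

    batch_size, seqlen = 3, 6                 # illustrative
    mask = jnp.full((batch_size, seqlen), float("-inf"))
    input_pos = jnp.array([2, 0, 4])          # illustrative per-slot decode positions

    batch = jnp.arange(batch_size)
    mask = mask.at[batch, input_pos].set(0)   # opens one different column per row
    # row 0 -> position 2, row 1 -> position 0, row 2 -> position 4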
@@ -570,7 +584,19 @@ def generate(
     )

     next_token = self._sampling(logits, self.env.batch_size)
-    lens = decode_state.lens + 1
+    if self.env.ring_buffer:
+      input_pos = decode_state.input_pos + 1
+      lens = decode_state.lens + 1
+    else:
+      input_pos = jnp.where(
+          decode_state.input_pos == 0,
+          0,
+          decode_state.input_pos + 1 % self.env.cache_len,
+      )
+      lens = jnp.where(
+          decode_state.lens == 0, 0, decode_state.lens + 1 % self.env.cache_len
+      )

     data = jnp.concatenate(
         [
             decode_state.tokens,

Review discussion on the non-ring-buffer input_pos update:
- In the non-ring-buffer case, can input_pos be larger than cache_len?
- No.
- If not, I feel we don't need the % since it never reaches cache_len.
- We don't have control over this. generate() will keep running if no new prefill results are inserted.
- Thanks for sharing the details!

wang2yn84 marked this conversation as resolved.
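The non-ring-buffer branch keeps empty slots pinned at position 0 and advances active slots by one each decode step. A small illustration of the jnp.where pattern; the values are made up and the modulo term from the PR is omitted here for simplicity:

    import jax.numpy as jnp

    input_pos = jnp.array([[7], [0], [12]])   # illustrative; slot 1 was never prefilled
    next_pos = jnp.where(input_pos == 0, 0, input_pos + 1)
    # next_pos == [[8], [0], [13]]: empty slots stay at 0, active slots move forward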
@@ -597,15 +623,14 @@ def generate(
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
         lens,
         decode_state.start,
-        decode_state.input_pos + 1,
+        input_pos,
         mask,
     )
     print(
         "new_pos",
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
     )
     print("cache_seq_len", self.env.cache_sequence_length)

     print(f"new_token: {jnp.squeeze(next_token)}")
     return new_decode_state, result_tokens

   # pylint: disable-next=all
@@ -782,6 +807,7 @@ def create_pytorch_engine(
     sampling_algorithm="greedy",
     nucleus_topp=None,
     topk=None,
+    ring_buffer=True,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
@@ -851,6 +877,7 @@ def create_pytorch_engine(
       sampling_algorithm=sampling_algorithm,
       nucleus_topp=nucleus_topp,
       topk=topk,
+      ring_buffer=ring_buffer,
   )

   if shard_on_batch and sharding_config:
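The new flag defaults to True, so existing callers keep the ring-buffer behavior; passing ring_buffer=False opts into the left-aligned cache. A hedged usage sketch: only the ring_buffer keyword comes from this diff, and the remaining arguments are placeholders that may not match the full create_pytorch_engine signature:

    # Hypothetical call; only ring_buffer is taken from this PR, the rest are placeholders.
    engine = create_pytorch_engine(
        model_name="llama-2",          # placeholder
        checkpoint_path="/tmp/ckpt",   # placeholder
        ring_buffer=False,             # new flag: disable the ring buffer, use the left-aligned cache
    )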
Review discussion on the mask update:
- What if we change current_position to [batch_size, 1]? Could we then use the same mask logic for both the ring_buffer and non_ring_buffer cases?
- Not really. There is only a single current_position value to indicate the decoding position for all the batches, but in the left-aligned case every slot is at a different position, so we cannot use current_position here.
- Yes. I mean that if we change current_position to [batch_size, 1], different slots can have different current_position values. In the non-ring-buffer case, current_position should be the same as input_pos.
- It would cause a performance regression. Please check jax_experiments.py/test7: inserting with batching plus a position array takes much longer, roughly 4x~5x.
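For context on the trade-off being discussed: with a single scalar position, every slot can share one dynamic-slice update, whereas a [batch_size, 1] position array forces a per-slot (batched) update. The sketch below only contrasts the two indexing patterns with illustrative shapes; it makes no performance claim and is not the code from jax_experiments.py/test7:

    import jax
    import jax.numpy as jnp

    cache = jnp.zeros((4, 16, 8))        # [batch, seq, dim], illustrative
    new_entry = jnp.ones((4, 1, 8))      # one new token per slot

    # Pattern 1: one shared scalar position (ring-buffer style).
    shared_pos = 5
    updated_shared = jax.lax.dynamic_update_slice(cache, new_entry, (0, shared_pos, 0))

    # Pattern 2: a per-slot position array (what the [batch_size, 1] proposal implies).
    per_slot_pos = jnp.array([5, 0, 3, 7])
    batch = jnp.arange(cache.shape[0])
    updated_per_slot = cache.at[batch, per_slot_pos, :].set(new_entry[:, 0, :])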