diff --git a/jetstream_pt/cache_manager.py b/jetstream_pt/cache_manager.py
index 75bd75c0..13789f91 100644
--- a/jetstream_pt/cache_manager.py
+++ b/jetstream_pt/cache_manager.py
@@ -91,20 +91,33 @@ def __init__(
       cache_v: torch.Tensor,  # previous cache
       position: int,  # position to store the cache
       sharding,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
     self.cache_v = cache_v
     self.pos = position
     self.sharding = sharding
+    self.env = env
 
   def update(self, key, value):
     """Update kv cache"""
     keyj, valuej = torchjax.to_torch((key, value))
-    # pylint: disable-next=all
-    self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
-    # pylint: disable-next=all
-    self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    if self.env.ring_buffer:
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      # pylint: disable-next=all
+      self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set(
+          keyj.squeeze(2)
+      )
+      # pylint: disable-next=all
+      self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set(
+          valuej.squeeze(2)
+      )
     return self.cache_k, self.cache_v
 
   def state(self):
@@ -113,13 +126,13 @@ def state(self):
     return self.cache_k.jax(), self.cache_v.jax()
 
   @classmethod
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
     k = jnp.zeros(shape, device=device, dtype=default_dtype)
     v = jnp.zeros(shape, device=device, dtype=default_dtype)
     k, v = torchjax.to_torch((k, v))
-    return cls(k, v, 0, device)
+    return cls(k, v, 0, device, env=env)
 
 
 # pylint: disable-next=all
@@ -155,6 +168,7 @@ def __init__(
       cache_v_scaler,
       input_pos,  # used to write cache
       sharding=None,
+      env=None,
   ):
     super().__init__()
     self.cache_k = cache_k
@@ -163,6 +177,7 @@ def __init__(
     self.v_scaler = cache_v_scaler
     self.input_pos = input_pos
     self.sharding = sharding
+    self.env = env
 
   def state(self):
     """Get kv cache state"""
@@ -174,7 +189,7 @@ def scalers(self):
 
   @classmethod
   # pylint: disable-next=all
-  def empty(cls, shape, device, bf16_enable):
+  def empty(cls, shape, device, bf16_enable, env):
     """Create empty kv caches"""
     cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
     cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
@@ -185,7 +200,7 @@ def empty(cls, shape, device, bf16_enable):
     cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
         (cache_k, cache_v, kscaler, vscaler)
     )
-    return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
+    return cls(cache_k, cache_v, kscaler, vscaler, 0, device, env=env)
 
   def quantize(self, val):
     """Quantize value"""
@@ -198,8 +213,15 @@ def update(self, xk, xv):
     """Update kv cache"""
     k_quant, kscale = self.quantize(xk)
     v_quant, vscale = self.quantize(xv)
-    self.cache_k[:, :, self.input_pos, :] = k_quant
-    self.cache_v[:, :, self.input_pos, :] = v_quant
-    self.k_scaler[:, :, self.input_pos, :] = kscale
-    self.v_scaler[:, :, self.input_pos, :] = vscale
+    if self.env.ring_buffer:
+      self.cache_k[:, :, self.input_pos, :] = k_quant
+      self.cache_v[:, :, self.input_pos, :] = v_quant
+      self.k_scaler[:, :, self.input_pos, :] = kscale
+      self.v_scaler[:, :, self.input_pos, :] = vscale
+    else:
+      batch = jnp.arange(self.env.batch_size)
+      self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2)
+      self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2)
+      self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2)
+      self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2)
     return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler
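The core of the cache_manager.py change is the write pattern in update(): with ring_buffer=True every slot writes the new token at one shared position, while the left-aligned mode scatters each batch slot at its own position via advanced indexing. A minimal standalone sketch of the two patterns on a plain jax.numpy cache; the array names and shapes below are illustrative, not taken from the patch:

import jax.numpy as jnp

batch, heads, seqlen, dim = 2, 4, 8, 16
cache = jnp.zeros((batch, heads, seqlen, dim))
new_kv = jnp.ones((batch, heads, 1, dim))  # one new token per slot

# Ring-buffer mode: a single shared position (shape (1,)) for all slots.
shared_pos = jnp.full((1,), 3)
ring_cache = cache.at[:, :, shared_pos].set(new_kv)

# Left-aligned mode: slot i writes at its own position per_slot_pos[i];
# pairing jnp.arange(batch) with per_slot_pos scatters one row per slot.
per_slot_pos = jnp.array([3, 5])
batch_idx = jnp.arange(batch)
left_cache = cache.at[batch_idx, :, per_slot_pos].set(new_kv.squeeze(2))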
diff --git a/jetstream_pt/config.py b/jetstream_pt/config.py
index 5ad29078..e38066b1 100644
--- a/jetstream_pt/config.py
+++ b/jetstream_pt/config.py
@@ -83,6 +83,12 @@
     "for performance tuning and debugging only",
     required=False,
 )
+flags.DEFINE_bool(
+    "ring_buffer",
+    True,
+    "Whether to enable ring buffer",
+    required=False,
+)
 flags.DEFINE_float(
     "temperature",
     1.0,
@@ -175,6 +181,7 @@ def create_engine_from_config_flags():
       sampling_algorithm=FLAGS.sampling_algorithm,
       nucleus_topp=FLAGS.nucleus_topp,
       topk=FLAGS.topk,
+      ring_buffer=FLAGS.ring_buffer,
   )
   print("Initialize engine", time.perf_counter() - start)
 
diff --git a/jetstream_pt/engine.py b/jetstream_pt/engine.py
index ced821ec..a878488d 100644
--- a/jetstream_pt/engine.py
+++ b/jetstream_pt/engine.py
@@ -65,7 +65,7 @@ class DecodeState:
       Tuple[jax.Array, jax.Array]
   ]  # only present in quantized kv
   current_position: int
-  lens: jax.Array  # [batch_size, 1]
+  lens: jax.Array  # [batch_size, 1], the output token length
   start: jax.Array  # [batch_size, 1], the starting pos for each slot
   input_pos: jax.Array  # [batch_size, 1] input pos for each slot
   mask: jax.Array  # [batch_size, seqlen] -inf for invalid; 0 for valid
@@ -157,7 +157,9 @@ def _call_model_generate(
   ):
     if self.env.quant_config.enable_kv_quantization:
       caches_obj = [
-          cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
+          cache_manager.Int8KVCacheGenerate(
+              k, v, ks, vs, input_indexes, env=self.env
+          )
           for (k, v), (ks, vs) in torchjax.to_torch(
               list(zip(caches, cache_scales))
           )
@@ -165,7 +167,7 @@ def _call_model_generate(
     else:
       caches_obj = [
           cache_manager.KVCacheGenerate(
-              k, v, input_indexes, self.cache_sharding
+              k, v, input_indexes, self.cache_sharding, env=self.env
           )
           for k, v in torchjax.to_torch(caches)
       ]
@@ -295,11 +297,16 @@ def _insert_no_wrap(
   ):
     scales = []
     caches = []
-    pos = decode_state.current_position - prefix.seq_len
+    if self.env.ring_buffer:
+      current_pos = decode_state.current_position
+    else:
+      current_pos = prefix.seq_len
+
+    pos = current_pos - prefix.seq_len
     tokens = decode_state.tokens.at[slot].set(prefix.token)
 
     x = jnp.arange(0, self.env.cache_sequence_length)
-    cond = jnp.logical_and(x <= decode_state.current_position, x >= pos)
+    cond = jnp.logical_and(x <= current_pos, x >= pos)
     mask_insert = jnp.where(cond, 0, float("-inf"))
     mask = decode_state.mask.at[slot].set(mask_insert)
     start = decode_state.start.at[slot].set(
@@ -470,18 +477,22 @@ def insert(
     #   prefix,
     #   decode_state,
     # )
-    start_insert = decode_state.current_position - prefix.seq_len
-    end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seclen
-    return jax.lax.cond(
-        jnp.logical_and(
-            start_insert >= 0, end_insert < self.env.cache_sequence_length
-        ),
-        self._insert_no_wrap,
-        self._insert_wrap,
-        prefix,
-        decode_state,
-        slot,
-    )
+    if self.env.ring_buffer:
+      start_insert = decode_state.current_position - prefix.seq_len
+      end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seqlen
+      return jax.lax.cond(
+          jnp.logical_and(
+              start_insert >= 0, end_insert < self.env.cache_sequence_length
+          ),
+          self._insert_no_wrap,
+          self._insert_wrap,
+          prefix,
+          decode_state,
+          slot,
+      )
+    # Left aligned, starts from 0, guaranteed no wrap
+    else:
+      return self._insert_no_wrap(prefix, decode_state, slot)
 
   def precompute_ragged_block_indices(self, decode_state: DecodeState):
     """Precompute the ragged attention block indices. Ragged attention iterates the grid
@@ -545,10 +556,13 @@ def generate(
   ) -> tuple[DecodeState, engine_api.ResultTokens]:
     # seq_len = padded_tokens.shape[0]
     pos = decode_state.current_position
-    input_indexes = jnp.full((1,), pos)
-
-    # fill mask first
-    mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    if self.env.ring_buffer:
+      input_indexes = jnp.full((1,), pos)
+      mask = decode_state.mask.at[:, decode_state.current_position].set(0)
+    else:
+      input_indexes = decode_state.input_pos
+      batch = jnp.arange(self.env.batch_size)
+      mask = decode_state.mask.at[batch, decode_state.input_pos].set(0)
     ragged_batch_index, ragged_block_index = (
         self.precompute_ragged_block_indices(decode_state)
     )
@@ -570,7 +584,19 @@ def generate(
     )
 
     next_token = self._sampling(logits, self.env.batch_size)
-    lens = decode_state.lens + 1
+    if self.env.ring_buffer:
+      input_pos = decode_state.input_pos + 1
+      lens = decode_state.lens + 1
+    else:
+      input_pos = jnp.where(
+          decode_state.input_pos == 0,
+          0,
+          (decode_state.input_pos + 1) % self.env.cache_len,
+      )
+      lens = jnp.where(
+          decode_state.lens == 0, 0, (decode_state.lens + 1) % self.env.cache_len
+      )
+
     data = jnp.concatenate(
         [
             decode_state.tokens,
@@ -597,15 +623,14 @@ def generate(
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
         lens,
         decode_state.start,
-        decode_state.input_pos + 1,
+        input_pos,
         mask,
     )
     print(
         "new_pos",
         (decode_state.current_position + 1) % self.env.cache_sequence_length,
     )
-    print("cache_seq_len", self.env.cache_sequence_length)
-
+    print(f"new_token: {jnp.squeeze(next_token)}")
     return new_decode_state, result_tokens
 
   # pylint: disable-next=all
@@ -782,6 +807,7 @@ def create_pytorch_engine(
     sampling_algorithm="greedy",
     nucleus_topp=None,
     topk=None,
+    ring_buffer=True,
 ) -> PyTorchEngine:
   """Returns: The pytorch engine."""
 
@@ -851,6 +877,7 @@ def create_pytorch_engine(
      sampling_algorithm=sampling_algorithm,
      nucleus_topp=nucleus_topp,
      topk=topk,
+      ring_buffer=ring_buffer,
   )
 
   if shard_on_batch and sharding_config:
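In generate(), the two modes differ in how the attention mask and positions advance: the ring buffer opens one shared mask column at current_position and increments every slot's input_pos, while the left-aligned mode opens each slot's own column and wraps that slot's position modulo cache_len, leaving inactive slots (input_pos == 0) untouched. A small sketch of that per-slot bookkeeping, assuming toy shapes and illustrative names:

import jax.numpy as jnp

batch_size, cache_len = 4, 8
# -inf marks cache positions attention must ignore; 0 marks valid ones.
mask = jnp.full((batch_size, cache_len), float("-inf"))
input_pos = jnp.array([3, 5, 0, 1])  # slot 2 is inactive (pos == 0)

# Left-aligned mode: each slot opens the mask at its own position.
batch = jnp.arange(batch_size)
mask = mask.at[batch, input_pos].set(0)

# Advance per-slot positions with wrap-around, keeping inactive slots at 0.
input_pos = jnp.where(input_pos == 0, 0, (input_pos + 1) % cache_len)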
diff --git a/jetstream_pt/environment.py b/jetstream_pt/environment.py
index 5311f8c2..fce606d9 100644
--- a/jetstream_pt/environment.py
+++ b/jetstream_pt/environment.py
@@ -100,6 +100,9 @@ class JetEngineEnvironmentData:
   # Starting position
   starting_position: int = 512
 
+  # Ring buffer
+  ring_buffer: bool = True
+
   # Variables used in token sampling
   # sampling algorithm to use ("greedy", "weighted", "neucleus", "topk")
   sampling_algorithm: str = "greedy"
@@ -120,11 +123,13 @@ class JetEngineEnvironment:
   def __init__(self, data: JetEngineEnvironmentData):
     self._data = data
 
+    self.batch_size = self._data.batch_size
     self.seq_len = self._data.max_input_sequence_length
     self.cache_len = self._data.cache_sequence_length
     self.ragged_mha = self._data.ragged_mha
     self.block_size = self._data.block_size
     self.starting_position = self._data.starting_position
+    self.ring_buffer = self._data.ring_buffer
 
     P = jax.sharding.PartitionSpec
     num_of_partitions = jax.device_count()
@@ -202,13 +207,13 @@ def make_caches_generate(self):
       if self._data.quant_config.enable_kv_quantization:
         caches.append(
             cache_manager.Int8KVCacheGenerate.empty(
-                shape, self.cache_sharding, self.bf16_enable
+                shape, self.cache_sharding, self.bf16_enable, env=self
             )
         )
       else:
         caches.append(
             cache_manager.KVCacheGenerate.empty(
-                shape, self.cache_sharding, self.bf16_enable
+                shape, self.cache_sharding, self.bf16_enable, env=self
             )
         )
     return caches
diff --git a/tests/test_model_impl.py b/tests/test_model_impl.py
index b0dfc151..4d4ddfd6 100644
--- a/tests/test_model_impl.py
+++ b/tests/test_model_impl.py
@@ -85,7 +85,7 @@ def _make_one_cache_for_generate(self, env, pos):
         (cache_array_k, cache_array_v)
     )
     cache_decode = cache_manager.KVCacheGenerate(
-        cache_array_k, cache_array_v, pos, None
+        cache_array_k, cache_array_v, pos, None, env
     )
     return cache_decode
 
diff --git a/tests/test_quantization.py b/tests/test_quantization.py
index e2f2764e..581553d1 100644
--- a/tests/test_quantization.py
+++ b/tests/test_quantization.py
@@ -72,7 +72,10 @@ def test_kv_cache(self):
     """test kv cache quantization"""
     cache_shape = (3, 2, 100, 2)  # bs, num heads, seqlen, dim
     with jax.default_device(jax.devices("cpu")[0]):
-      cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, False)
+      env, _ = helpers.make_env_tiny()
+      cache = cache_manager.Int8KVCacheGenerate.empty(
+          cache_shape, None, False, env
+      )
       # seqlen is 1
       k = self._xla_tensor((3, 2, 1, 2))
       v = self._xla_tensor((3, 2, 1, 2))
@@ -101,7 +104,7 @@ def test_kv_kernel(self):
 
       cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))
 
-      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None)
+      cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env)
 
       # 1 is seqlen
       xq = jax.random.normal(key, (3, 2, 1, 2))
@@ -119,7 +122,13 @@ def test_kv_kernel(self):
       cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3))
       cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3))
       cache_int = cache_manager.Int8KVCacheGenerate(
-          cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, [0], None
+          cache_k_int,
+          cache_v_int,
+          cache_k_scaler,
+          cache_v_scaler,
+          [0],
+          None,
+          env,
       )
       attention_quant = layers.Int8KVAttentionKernel(env)
       int_res = attention_quant(xq, xk, xv, None, cache_int)
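For the quantized path, Int8KVCacheGenerate.update() first quantizes the incoming keys and values and then performs the same per-slot scatter for both the int8 values and their scalers. A simplified standalone sketch of that flow; the symmetric per-head quantization here is a stand-in for the repo's quantize helper, and all names and shapes are illustrative:

import jax.numpy as jnp

batch, heads, seqlen, dim = 2, 4, 8, 16
cache_k = jnp.zeros((batch, heads, seqlen, dim), dtype=jnp.int8)
k_scaler = jnp.ones((batch, heads, seqlen, 1))

# One new token of keys per slot.
xk = jnp.linspace(-1.0, 1.0, batch * heads * dim).reshape(batch, heads, 1, dim)

# Simplified symmetric quantization: map each row's max magnitude to 127.
scale = jnp.amax(jnp.abs(xk), axis=-1, keepdims=True) / 127.0
k_quant = jnp.clip(jnp.round(xk / scale), -128, 127).astype(jnp.int8)

# Left-aligned write: scatter values and scalers at each slot's own position.
input_pos = jnp.array([3, 5])
b = jnp.arange(batch)
cache_k = cache_k.at[b, :, input_pos, :].set(k_quant.squeeze(2))
k_scaler = k_scaler.at[b, :, input_pos, :].set(scale.squeeze(2))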