Add left aligned cache support. #133

Merged (6 commits) on Jun 28, 2024
46 changes: 34 additions & 12 deletions jetstream_pt/cache_manager.py
@@ -91,20 +91,33 @@ def __init__(
cache_v: torch.Tensor, # previous cache
position: int, # position to store the cache
sharding,
env=None,
):
super().__init__()
self.cache_k = cache_k
self.cache_v = cache_v
self.pos = position
self.sharding = sharding
self.env = env

def update(self, key, value):
"""Update kv cache"""
keyj, valuej = torchjax.to_torch((key, value))
# pylint: disable-next=all
self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
# pylint: disable-next=all
self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
if self.env.ring_buffer:
# pylint: disable-next=all
self.cache_k._elem = self.cache_k._elem.at[:, :, self.pos].set(keyj)
# pylint: disable-next=all
self.cache_v._elem = self.cache_v._elem.at[:, :, self.pos].set(valuej)
else:
batch = jnp.arange(self.env.batch_size)
# pylint: disable-next=all
self.cache_k._elem = self.cache_k._elem.at[batch, :, self.pos].set(
keyj.squeeze(2)
)
# pylint: disable-next=all
self.cache_v._elem = self.cache_v._elem.at[batch, :, self.pos].set(
valuej.squeeze(2)
)
return self.cache_k, self.cache_v
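
A standalone sketch of the two write patterns above in plain JAX (shapes and positions are illustrative, not the PR's):

import jax.numpy as jnp

batch, heads, seqlen, dim = 2, 4, 8, 16
cache = jnp.zeros((batch, heads, seqlen, dim))
new = jnp.ones((batch, heads, 1, dim))  # one decoded token per slot

# Ring buffer: every slot writes at the same scalar position.
pos = 5
ring = cache.at[:, :, pos].set(new.squeeze(2))

# Left aligned: slot i writes at its own position, so the position index
# is a per-slot array and the batch axis is indexed explicitly.
input_pos = jnp.array([3, 7])
rows = jnp.arange(batch)
left = cache.at[rows, :, input_pos].set(new.squeeze(2))

The squeeze(2) drops the singleton sequence axis, since advanced indexing with (rows, input_pos) already selects exactly one position per slot.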

def state(self):
@@ -113,13 +126,13 @@ def state(self):
return self.cache_k.jax(), self.cache_v.jax()

@classmethod
def empty(cls, shape, device, bf16_enable):
def empty(cls, shape, device, bf16_enable, env):
"""Create empty kv caches"""
default_dtype = jnp.bfloat16 if bf16_enable else jnp.float32
k = jnp.zeros(shape, device=device, dtype=default_dtype)
v = jnp.zeros(shape, device=device, dtype=default_dtype)
k, v = torchjax.to_torch((k, v))
return cls(k, v, 0, device)
return cls(k, v, 0, device, env=env)
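
Call sites now pass the env through empty(); a minimal sketch mirroring this PR's tests (helpers is the repo's test helper module, and the shape is the one the tests use):

import helpers  # tests/helpers.py
from jetstream_pt import cache_manager

env, _ = helpers.make_env_tiny()
shape = (3, 2, 100, 2)  # bs, num heads, seqlen, dim
cache = cache_manager.KVCacheGenerate.empty(shape, None, False, env=env)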


# pylint: disable-next=all
@@ -155,6 +168,7 @@ def __init__(
cache_v_scaler,
input_pos, # used to write cache
sharding=None,
env=None,
):
super().__init__()
self.cache_k = cache_k
@@ -163,6 +177,7 @@ def __init__(
self.v_scaler = cache_v_scaler
self.input_pos = input_pos
self.sharding = sharding
self.env = env

def state(self):
"""Get kv cache state"""
@@ -174,7 +189,7 @@ def scalers(self):

@classmethod
# pylint: disable-next=all
def empty(cls, shape, device, bf16_enable):
def empty(cls, shape, device, bf16_enable, env):
"""Create empty kv caches"""
cache_k = jnp.zeros(shape, device=device, dtype=jnp.int8)
cache_v = jnp.zeros(shape, device=device, dtype=jnp.int8)
@@ -185,7 +200,7 @@ def empty(cls, shape, device, bf16_enable):
cache_k, cache_v, kscaler, vscaler = torchjax.to_torch(
(cache_k, cache_v, kscaler, vscaler)
)
return cls(cache_k, cache_v, kscaler, vscaler, 0, device)
return cls(cache_k, cache_v, kscaler, vscaler, 0, device, env=env)

def quantize(self, val):
"""Quantize value"""
@@ -198,8 +213,15 @@ def update(self, xk, xv):
"""Update kv cache"""
k_quant, kscale = self.quantize(xk)
v_quant, vscale = self.quantize(xv)
self.cache_k[:, :, self.input_pos, :] = k_quant
self.cache_v[:, :, self.input_pos, :] = v_quant
self.k_scaler[:, :, self.input_pos, :] = kscale
self.v_scaler[:, :, self.input_pos, :] = vscale
if self.env.ring_buffer:
self.cache_k[:, :, self.input_pos, :] = k_quant
self.cache_v[:, :, self.input_pos, :] = v_quant
self.k_scaler[:, :, self.input_pos, :] = kscale
self.v_scaler[:, :, self.input_pos, :] = vscale
else:
batch = jnp.arange(self.env.batch_size)
self.cache_k[batch, :, self.input_pos, :] = k_quant.squeeze(2)
self.cache_v[batch, :, self.input_pos, :] = v_quant.squeeze(2)
self.k_scaler[batch, :, self.input_pos, :] = kscale.squeeze(2)
self.v_scaler[batch, :, self.input_pos, :] = vscale.squeeze(2)
return self.cache_k, self.cache_v, self.k_scaler, self.v_scaler
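
The quantized path applies the same per-slot scatter to both the int8 values and their scalers. A standalone sketch of the idea; the quantization rule and the scaler shape here are illustrative assumptions, not the PR's exact quantize():

import jax.numpy as jnp

def quantize_int8(x):
    # Symmetric quantization over (heads, dim): one scale per (batch, position).
    scale = jnp.amax(jnp.abs(x), axis=(1, 3), keepdims=True) / 127.0 + 1e-8
    q = jnp.clip(jnp.round(x / scale), -128, 127).astype(jnp.int8)
    return q, scale

batch, heads, seqlen, dim = 2, 4, 8, 16
cache = jnp.zeros((batch, heads, seqlen, dim), dtype=jnp.int8)
scaler = jnp.ones((batch, 1, seqlen, 1))

xk = jnp.ones((batch, heads, 1, dim))   # this step's new keys
q, s = quantize_int8(xk)                # q is int8; s has shape (batch, 1, 1, 1)

rows = jnp.arange(batch)
input_pos = jnp.array([3, 7])           # per-slot write positions
cache = cache.at[rows, :, input_pos].set(q.squeeze(2))
scaler = scaler.at[rows, :, input_pos].set(s.squeeze(2))
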
7 changes: 7 additions & 0 deletions jetstream_pt/config.py
@@ -83,6 +83,12 @@
"for performance tuning and debugging only",
required=False,
)
flags.DEFINE_bool(
"ring_buffer",
True,
"Whether to enable ring buffer",
required=False,
)
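
Note that the flag defaults to True, so existing launch commands keep the ring-buffer behavior unchanged; passing --ring_buffer=False to whichever entry point builds the engine from these flags opts in to the left-aligned cache.
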
flags.DEFINE_float(
"temperature",
1.0,
@@ -175,6 +181,7 @@ def create_engine_from_config_flags():
sampling_algorithm=FLAGS.sampling_algorithm,
nucleus_topp=FLAGS.nucleus_topp,
topk=FLAGS.topk,
ring_buffer=FLAGS.ring_buffer,
)

print("Initialize engine", time.perf_counter() - start)
77 changes: 52 additions & 25 deletions jetstream_pt/engine.py
@@ -65,7 +65,7 @@ class DecodeState:
Tuple[jax.Array, jax.Array]
] # only present in quantized kv
current_position: int
lens: jax.Array # [batch_size, 1]
lens: jax.Array # [batch_size, 1], the output token length
start: jax.Array # [batch_size, 1], the starting pos for each slot
input_pos: jax.Array # [batch_size, 1] input pos for each slot
mask: jax.Array # [batch_size, seqlen] -inf for invalid; 0 for valid
@@ -157,15 +157,17 @@ def _call_model_generate(
):
if self.env.quant_config.enable_kv_quantization:
caches_obj = [
cache_manager.Int8KVCacheGenerate(k, v, ks, vs, input_indexes)
cache_manager.Int8KVCacheGenerate(
k, v, ks, vs, input_indexes, env=self.env
)
for (k, v), (ks, vs) in torchjax.to_torch(
list(zip(caches, cache_scales))
)
]
else:
caches_obj = [
cache_manager.KVCacheGenerate(
k, v, input_indexes, self.cache_sharding
k, v, input_indexes, self.cache_sharding, env=self.env
)
for k, v in torchjax.to_torch(caches)
]
@@ -295,11 +297,16 @@ def _insert_no_wrap(
):
scales = []
caches = []
pos = decode_state.current_position - prefix.seq_len
if self.env.ring_buffer:
current_pos = decode_state.current_position
else:
current_pos = prefix.seq_len

pos = current_pos - prefix.seq_len
tokens = decode_state.tokens.at[slot].set(prefix.token)

x = jnp.arange(0, self.env.cache_sequence_length)
cond = jnp.logical_and(x <= decode_state.current_position, x >= pos)
cond = jnp.logical_and(x <= current_pos, x >= pos)
mask_insert = jnp.where(cond, 0, float("-inf"))
mask = decode_state.mask.at[slot].set(mask_insert)
start = decode_state.start.at[slot].set(
@@ -470,18 +477,22 @@ def insert(
# prefix,
# decode_state,
# )
start_insert = decode_state.current_position - prefix.seq_len
end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seqlen
return jax.lax.cond(
jnp.logical_and(
start_insert >= 0, end_insert < self.env.cache_sequence_length
),
self._insert_no_wrap,
self._insert_wrap,
prefix,
decode_state,
slot,
)
if self.env.ring_buffer:
start_insert = decode_state.current_position - prefix.seq_len
end_insert = start_insert + prefix.caches[0][0].shape[2]  # padded seqlen
return jax.lax.cond(
jnp.logical_and(
start_insert >= 0, end_insert < self.env.cache_sequence_length
),
self._insert_no_wrap,
self._insert_wrap,
prefix,
decode_state,
slot,
)
# Left aligned, starts from 0, guaranteed no wrap
else:
return self._insert_no_wrap(prefix, decode_state, slot)
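
The control-flow difference in insert() can be seen in isolation: with a ring buffer the insert region may straddle the end of the cache, so the wrap/no-wrap choice must be made on traced values inside the graph, while a left-aligned prefill always lands at [0, seq_len). A toy sketch in which the two branch functions stand in for the real insert paths:

import jax
import jax.numpy as jnp

CACHE_LEN = 1024  # illustrative

def insert_no_wrap(start):
    return start  # stand-in: write [start, start + seq_len) directly

def insert_wrap(start):
    return start % CACHE_LEN  # stand-in: split the write around the end

def choose_insert(current_position, seq_len, ring_buffer):
    if ring_buffer:
        start = current_position - seq_len
        end = start + seq_len
        # start/end are traced values, so the branch is picked with lax.cond.
        return jax.lax.cond(
            jnp.logical_and(start >= 0, end < CACHE_LEN),
            insert_no_wrap,
            insert_wrap,
            start,
        )
    return insert_no_wrap(0)  # left aligned: starts at 0, never wraps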

def precompute_ragged_block_indices(self, decode_state: DecodeState):
"""Precompute the ragged attention block indices. Ragged attention iterates the grid
@@ -545,10 +556,13 @@ def generate(
) -> tuple[DecodeState, engine_api.ResultTokens]:
# seq_len = padded_tokens.shape[0]
pos = decode_state.current_position
input_indexes = jnp.full((1,), pos)

# fill mask first
mask = decode_state.mask.at[:, decode_state.current_position].set(0)
if self.env.ring_buffer:
input_indexes = jnp.full((1,), pos)
mask = decode_state.mask.at[:, decode_state.current_position].set(0)
else:
input_indexes = decode_state.input_pos
batch = jnp.arange(self.env.batch_size)
mask = decode_state.mask.at[batch, decode_state.input_pos].set(0)

Review thread on this hunk:
Reviewer: What if we changed current_position to [batch_size, 1]? Could we then use the same mask logic for both the ring-buffer and non-ring-buffer cases?
Author: Not really. In the ring-buffer case a single current_position value marks the decoding position for all batches, but in the non-ring-buffer case every batch is at a different position, so we cannot use current_position there.
Reviewer: Right, that is what I mean: if current_position became [batch_size, 1], each slot could carry its own value. In the non-ring-buffer case it would then be the same as input_pos.
Author: That would cause a performance regression. See jax_experiments.py/test7: inserting with batching plus a position array takes much longer, roughly 4-5x.
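
How the two mask updates differ, as a standalone sketch (sizes illustrative; in the real mask, -inf marks invalid positions and 0 valid ones):

import jax.numpy as jnp

batch, seqlen = 3, 8
mask = jnp.full((batch, seqlen), float("-inf"))

# Ring buffer: one shared decode position, so one column opens for all slots.
current_position = 5
ring_mask = mask.at[:, current_position].set(0)

# Left aligned: each slot opens its own position.
input_pos = jnp.array([1, 4, 6])
rows = jnp.arange(batch)
left_mask = mask.at[rows, input_pos].set(0)
# left_mask[i, input_pos[i]] == 0; everything else stays -inf.
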
ragged_batch_index, ragged_block_index = (
self.precompute_ragged_block_indices(decode_state)
)
@@ -570,7 +584,19 @@
)

next_token = self._sampling(logits, self.env.batch_size)
lens = decode_state.lens + 1
if self.env.ring_buffer:
input_pos = decode_state.input_pos + 1
lens = decode_state.lens + 1
else:
input_pos = jnp.where(
decode_state.input_pos == 0,
0,
(decode_state.input_pos + 1) % self.env.cache_len,
)
lens = jnp.where(
decode_state.lens == 0, 0, (decode_state.lens + 1) % self.env.cache_len
)

Review thread on this hunk:
Reviewer: In the non-ring-buffer case, can input_pos grow larger than the cache length?
Author: No.
Reviewer: If not, I don't think we need the %, since it never reaches cache_len.
Author: We don't have control over that: generate() keeps running even when no new prefill results are inserted, so the position keeps advancing.
Reviewer: Thanks for sharing the details!
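
The jnp.where keeps slots that hold no sequence (input_pos == 0) pinned at zero, while active slots advance modulo the cache length; a tiny numeric check (cache_len value illustrative):

import jax.numpy as jnp

cache_len = 8
input_pos = jnp.array([0, 3, 7])  # empty slot, mid-sequence slot, last position
next_pos = jnp.where(input_pos == 0, 0, (input_pos + 1) % cache_len)
# next_pos == [0, 4, 0]: the empty slot stays at 0, the full slot wraps.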

data = jnp.concatenate(
[
decode_state.tokens,
Expand All @@ -597,15 +623,14 @@ def generate(
(decode_state.current_position + 1) % self.env.cache_sequence_length,
lens,
decode_state.start,
decode_state.input_pos + 1,
input_pos,
mask,
)
print(
"new_pos",
(decode_state.current_position + 1) % self.env.cache_sequence_length,
)
print("cache_seq_len", self.env.cache_sequence_length)

print(f"new_token: {jnp.squeeze(next_token)}")
return new_decode_state, result_tokens

# pylint: disable-next=all
@@ -782,6 +807,7 @@ def create_pytorch_engine(
sampling_algorithm="greedy",
nucleus_topp=None,
topk=None,
ring_buffer=True,
) -> PyTorchEngine:
"""Returns: The pytorch engine."""

@@ -851,6 +877,7 @@
sampling_algorithm=sampling_algorithm,
nucleus_topp=nucleus_topp,
topk=topk,
ring_buffer=ring_buffer,
)

if shard_on_batch and sharding_config:
9 changes: 7 additions & 2 deletions jetstream_pt/environment.py
@@ -100,6 +100,9 @@ class JetEngineEnvironmentData:
# Starting position
starting_position: int = 512

# Ring buffer
ring_buffer: bool = True

# Variables used in token sampling
# sampling algorithm to use ("greedy", "weighted", "nucleus", "topk")
sampling_algorithm: str = "greedy"
@@ -120,11 +123,13 @@ class JetEngineEnvironment:
def __init__(self, data: JetEngineEnvironmentData):
self._data = data

self.batch_size = self._data.batch_size
self.seq_len = self._data.max_input_sequence_length
self.cache_len = self._data.cache_sequence_length
self.ragged_mha = self._data.ragged_mha
self.block_size = self._data.block_size
self.starting_position = self._data.starting_position
self.ring_buffer = self._data.ring_buffer
P = jax.sharding.PartitionSpec

num_of_partitions = jax.device_count()
@@ -202,13 +207,13 @@ def make_caches_generate(self):
if self._data.quant_config.enable_kv_quantization:
caches.append(
cache_manager.Int8KVCacheGenerate.empty(
shape, self.cache_sharding, self.bf16_enable
shape, self.cache_sharding, self.bf16_enable, env=self
)
)
else:
caches.append(
cache_manager.KVCacheGenerate.empty(
shape, self.cache_sharding, self.bf16_enable
shape, self.cache_sharding, self.bf16_enable, env=self
)
)
return caches
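
End to end, the setting rides on the env object; a minimal sketch, assuming the remaining JetEngineEnvironmentData fields keep their defaults:

from jetstream_pt import environment

env_data = environment.JetEngineEnvironmentData(ring_buffer=False)
env = environment.JetEngineEnvironment(env_data)
assert env.ring_buffer is False  # consulted by cache updates, insert(), generate()
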
2 changes: 1 addition & 1 deletion tests/test_model_impl.py
@@ -85,7 +85,7 @@ def _make_one_cache_for_generate(self, env, pos):
(cache_array_k, cache_array_v)
)
cache_decode = cache_manager.KVCacheGenerate(
cache_array_k, cache_array_v, pos, None
cache_array_k, cache_array_v, pos, None, env
)
return cache_decode

15 changes: 12 additions & 3 deletions tests/test_quantization.py
@@ -72,7 +72,10 @@ def test_kv_cache(self):
"""test kv cache quantization"""
cache_shape = (3, 2, 100, 2) # bs, num heads, seqlen, dim
with jax.default_device(jax.devices("cpu")[0]):
cache = cache_manager.Int8KVCacheGenerate.empty(cache_shape, None, False)
env, _ = helpers.make_env_tiny()
cache = cache_manager.Int8KVCacheGenerate.empty(
cache_shape, None, False, env
)
# seqlen is 1
k = self._xla_tensor((3, 2, 1, 2))
v = self._xla_tensor((3, 2, 1, 2))
@@ -101,7 +104,7 @@ def test_kv_kernel(self):

cache_k, cache_v = torchjax.to_torch((cache_k_jax, cache_v_jax))

cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None)
cache = cache_manager.KVCacheGenerate(cache_k, cache_v, [0], None, env)

# 1 is seqlen
xq = jax.random.normal(key, (3, 2, 1, 2))
@@ -119,7 +122,13 @@
cache_k_int, cache_k_scaler, _ = quantize_tensor(cache_k, (1, 3))
cache_v_int, cache_v_scaler, _ = quantize_tensor(cache_v, (1, 3))
cache_int = cache_manager.Int8KVCacheGenerate(
cache_k_int, cache_v_int, cache_k_scaler, cache_v_scaler, [0], None
cache_k_int,
cache_v_int,
cache_k_scaler,
cache_v_scaler,
[0],
None,
env,
)
attention_quant = layers.Int8KVAttentionKernel(env)
int_res = attention_quant(xq, xk, xv, None, cache_int)