Make prefilling return first token for loadgen integration (#143)
* Make prefilling return first token for loadgen integration

* minor fix and lint

* enable passing of max_decode_length as a flag
sixiang-google authored and wang2yn84 committed Jul 18, 2024
1 parent 17ab200 commit 8675c30
Showing 9 changed files with 36 additions and 18 deletions.
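At every call site the change is mechanical: engine.prefill(...) now returns a (Prefix, ResultTokens) pair instead of a bare Prefix, so callers that only want the KV-cache prefix unpack the tuple and discard the second element. A minimal before/after sketch, assuming an engine, params, decode_state, slot, and tokenized inputs like those in the benchmarks below:

    # Before this commit, prefill returned only the Prefix:
    #   prefill_result = engine.prefill(
    #       params=params, padded_tokens=tokens, true_length=true_length
    #   )

    # After, it also returns the first sampled token as ResultTokens, so a
    # loadgen-style client can report time-to-first-token straight from prefill:
    prefill_result, first_token = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )
    decode_state = engine.insert(prefill_result, decode_state, slot=slot)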
benchmarks/prefill_offline.py (1 addition, 1 deletion)

@@ -82,7 +82,7 @@ def prefill_benchmark(tokens_list, engine, params, warmup):
   # pylint: disable-next=all
   warmup_text = "warmup" if warmup else "execute"
   it = time.time()
-  prefill_result = engine.prefill(
+  prefill_result, _ = engine.prefill(
       params=params,
       padded_tokens=prefill_tokens,
       true_length=len(prefill_tokens),
benchmarks/run_offline.py (2 additions, 2 deletions)

@@ -44,7 +44,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
   )

   for _ in range(3):
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
     decode_state = engine.insert(
@@ -58,7 +58,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
       jax.profiler.start_trace(FLAGS.profiling_output)
       profiler_started = True

-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
     decode_state = engine.insert(
deps/JetStream (1 addition, 1 deletion)
jetstream_pt/config.py (2 additions, 0 deletions)

@@ -31,6 +31,7 @@
 flags.DEFINE_string("size", "tiny", "size of model")
 flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
 flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
+flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
 flags.DEFINE_string("sharding_config", "", "config file for sharding")
 flags.DEFINE_bool(
     "shard_on_batch",
@@ -197,6 +198,7 @@ def create_engine_from_config_flags():
     batch_size=FLAGS.batch_size,
     quant_config=quant_config,
     max_cache_length=FLAGS.max_cache_length,
+    max_decode_length=FLAGS.max_decode_length,
     sharding_config=sharding_file_name,
     shard_on_batch=FLAGS.shard_on_batch,
     ragged_mha=FLAGS.ragged_mha,
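With the flag defined here and threaded into create_engine_from_config_flags(), the decode-length cap becomes settable from the command line like the other engine flags. A minimal sketch of a hypothetical launcher script, assuming only what this diff defines:

    from absl import app, flags

    # Importing the module registers --max_decode_length along with the
    # other engine flags defined in jetstream_pt/config.py.
    from jetstream_pt import config

    FLAGS = flags.FLAGS

    def main(argv):
      # e.g. invoked as: python launcher.py --size=tiny --max_decode_length=2048
      engine = config.create_engine_from_config_flags()
      # The engine is constructed with max_decode_length=FLAGS.max_decode_length.
      print("max decode length:", FLAGS.max_decode_length)

    if __name__ == "__main__":
      app.run(main)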
jetstream_pt/engine.py (21 additions, 5 deletions)

@@ -97,7 +97,8 @@ def __init__(
     jax.config.update("jax_enable_x64", False)

     self.prefill = jax.jit(
-        self.prefill, out_shardings=self.get_prefix_destination_sharding()
+        self.prefill,
+        out_shardings=(self.get_prefix_destination_sharding(), None),
     )
     self.insert = jax.jit(
         self.insert,
@@ -247,7 +248,7 @@ def prefill(
       existing_prefix: Optional[Prefix] = None,
       padded_tokens: PrefillInputs,  # PrefillInputs[jax.Array],
       true_length: int,
-  ) -> Prefix:
+  ) -> Tuple[Prefix, engine_api.ResultTokens]:
     if isinstance(padded_tokens, jax.Array):
       batched_token = padded_tokens.reshape(1, -1)
     else:
@@ -264,7 +265,6 @@ def prefill(
     )
     if len(logits.shape) == 3:  # b, seqlen, num words
       logits = logits[0]  # seqlen, num words
-
     token = sampling_utils.sampling(
         logits[true_length - 1],
         self.rng,
@@ -273,7 +273,23 @@ def prefill(
         self.env.nucleus_topp,
         self.env.temperature,
     )
-
+    token_out = jnp.reshape(token, (1, 1))
+    data = jnp.concatenate(
+        [
+            token_out,  # First token
+            jnp.ones_like(token_out),  # validity of first token
+            jnp.zeros((1, 1), dtype=jnp.int32),  # length = 0
+        ],
+        axis=-1,
+    )
+    length = token_out.shape[1]
+    result = engine_api.ResultTokens(
+        data=data,
+        tokens_idx=(0, length),
+        valid_idx=(length, 2 * length),
+        length_idx=(2 * length, 2 * length + 1),
+        samples_per_slot=1,
+    )
     # truncate to true_length didnt work need to be out side of jit
     # caches = [
     #   (jax.lax.dynamic_slice_in_dim(
@@ -282,7 +298,7 @@ def prefill(
     #     v, seq_len - true_length, true_length, axis=2))
     #   for k, v in updated_caches
     # ]
-    return Prefix(token, updated_caches, true_length)
+    return Prefix(token, updated_caches, true_length), result

   def shrink_prefix(
       self,
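Two details are worth noting. First, because the jitted prefill now returns a tuple, out_shardings must be a matching (prefix_sharding, None) pair; the None entry leaves placement of the small ResultTokens array to the compiler. Second, the packed data row is laid out as [token, validity, length], with the *_idx tuples recording where each slice lives, so a consumer can read the first token back out by index. A minimal consumer-side sketch, assuming only the shapes and fields constructed in this diff:

    prefix, first_token = engine.prefill(
        params=params, padded_tokens=tokens, true_length=true_length
    )

    # first_token.data has shape (1, 3): one prefill slot, packed as
    # [token, validity, length] per the concatenation above.
    tok_lo, tok_hi = first_token.tokens_idx   # (0, 1)
    val_lo, val_hi = first_token.valid_idx    # (1, 2)
    first_token_id = int(first_token.data[0, tok_lo:tok_hi][0])
    token_is_valid = bool(first_token.data[0, val_lo:val_hi][0])

samples_per_slot=1 matches this layout: prefill emits exactly one token for its single slot.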
run_interactive.py (1 addition, 1 deletion)

@@ -66,7 +66,7 @@ def main(argv):
     # pylint: disable-next=all
     if profiling_prefill:
       jax.profiler.start_trace(profiling_output)
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=tokens, true_length=true_length
     )
     # pylint: disable-next=all
run_interactive_disaggregated.py (1 addition, 1 deletion)

@@ -161,7 +161,7 @@ def main(argv):
     print(
         f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
     )
-    prefill_result = prefill_engine.prefill(
+    prefill_result, _ = prefill_engine.prefill(
         params=None, padded_tokens=tokens, true_length=true_length
     )
     print(
run_interactive_multiple_host.py (1 addition, 1 deletion)

@@ -88,7 +88,7 @@ def main(argv):
     print(f"---- Encoded tokens are: {tokens}")

     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=None, padded_tokens=tokens, true_length=true_length
     )
     # pylint: disable-next=all
tests/test_llama_e2e.py (6 additions, 6 deletions)

@@ -128,7 +128,7 @@ def test_jetstream_llama2_seed(self):
     decode_state = engine.init_decode_state()
     slot = 0
     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )

@@ -197,7 +197,7 @@ def _llama_e2e(self, env, model_arg):
     decode_state = engine.init_decode_state()
     slot = 0
     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )

@@ -334,7 +334,7 @@ def test_llama_e2e_two_addtional_tokens(self):
     slot = 0

     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )

@@ -406,7 +406,7 @@ def test_llama_e2e_four_addtional_tokens(self):
     slot = 0

     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )

@@ -472,7 +472,7 @@ def test_llama_with_original_prefill_decode_32(self):
     # pylint: disable-next=all
     decode_state = engine.init_decode_state()
     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )
     out_tokens = prefill_result.token
@@ -547,7 +547,7 @@ def test_llama_with_original_prefill_decode(self):
     # pylint: disable-next=all
     decode_state = engine.init_decode_state()
     # pylint: disable-next=all
-    prefill_result = engine.prefill(
+    prefill_result, _ = engine.prefill(
         params=params, padded_tokens=padded_tokens, true_length=true_length
     )
     out_tokens = prefill_result.token
