Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make prefilling return first token for loadgen integration #143

Merged
merged 4 commits into from
Jul 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion benchmarks/prefill_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def prefill_benchmark(tokens_list, engine, params, warmup):
# pylint: disable-next=all
warmup_text = "warmup" if warmup else "execute"
it = time.time()
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params,
padded_tokens=prefill_tokens,
true_length=len(prefill_tokens),
Expand Down
4 changes: 2 additions & 2 deletions benchmarks/run_offline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
)

for _ in range(3):
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = engine.insert(
Expand All @@ -53,7 +53,7 @@ def run_prefill_time(engine, params, decode_state, seqlen):
nums = 5
start = time.perf_counter()
for i in range(nums):
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
decode_state = engine.insert(
Expand Down
2 changes: 1 addition & 1 deletion deps/JetStream
2 changes: 2 additions & 0 deletions jetstream_pt/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
flags.DEFINE_string("size", "tiny", "size of model")
flags.DEFINE_bool("quantize_kv_cache", False, "kv_cache_quantize")
flags.DEFINE_integer("max_cache_length", 1024, "kv_cache_quantize")
flags.DEFINE_integer("max_decode_length", 1024, "max length of generated text")
flags.DEFINE_string("sharding_config", "", "config file for sharding")
flags.DEFINE_bool(
"shard_on_batch",
Expand Down Expand Up @@ -173,6 +174,7 @@ def create_engine_from_config_flags():
batch_size=FLAGS.batch_size,
quant_config=quant_config,
max_cache_length=FLAGS.max_cache_length,
max_decode_length=FLAGS.max_decode_length,
sharding_config=sharding_file_name,
shard_on_batch=FLAGS.shard_on_batch,
ragged_mha=FLAGS.ragged_mha,
Expand Down
26 changes: 21 additions & 5 deletions jetstream_pt/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ def __init__(
jax.config.update("jax_enable_x64", False)

self.prefill = jax.jit(
self.prefill, out_shardings=self.get_prefix_destination_sharding()
self.prefill,
out_shardings=(self.get_prefix_destination_sharding(), None),
)
self.insert = jax.jit(
self.insert,
Expand Down Expand Up @@ -243,7 +244,7 @@ def prefill(
existing_prefix: Optional[Prefix] = None,
padded_tokens: PrefillInputs, # PrefillInputs[jax.Array],
true_length: int,
) -> Prefix:
) -> Tuple[Prefix, engine_api.ResultTokens]:
if isinstance(padded_tokens, jax.Array):
batched_token = padded_tokens.reshape(1, -1)
else:
Expand All @@ -260,7 +261,6 @@ def prefill(
)
if len(logits.shape) == 3: # b, seqlen, num words
logits = logits[0] # seqlen, num words

token = sampling_utils.sampling(
logits[true_length - 1],
self.rng,
Expand All @@ -269,7 +269,23 @@ def prefill(
self.env.nucleus_topp,
self.env.temperature,
)

token_out = jnp.reshape(token, (1, 1))
data = jnp.concatenate(
[
token_out, # First token
jnp.ones_like(token_out), # validity of first token
jnp.zeros((1, 1), dtype=jnp.int32), # length = 0
],
axis=-1,
)
length = token_out.shape[1]
result = engine_api.ResultTokens(
data=data,
tokens_idx=(0, length),
valid_idx=(length, 2 * length),
length_idx=(2 * length, 2 * length + 1),
samples_per_slot=1,
)
# truncating to true_length didn't work; it needs to be done outside of jit
# caches = [
# (jax.lax.dynamic_slice_in_dim(
Expand All @@ -278,7 +294,7 @@ def prefill(
# v, seq_len - true_length, true_length, axis=2))
# for k, v in updated_caches
# ]
return Prefix(token, updated_caches, true_length)
return Prefix(token, updated_caches, true_length), result

def shrink_prefix(
self,
Expand Down
2 changes: 1 addition & 1 deletion run_interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def main(argv):
print(f"---- Encoded tokens are: {tokens}")

# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=tokens, true_length=true_length
)
# pylint: disable-next=all
Expand Down
2 changes: 1 addition & 1 deletion run_interactive_disaggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def main(argv):
print(
f"---- Do prefill in prefill engine pod_slice_name: {prefill_engine.pod_slice_name}"
)
prefill_result = prefill_engine.prefill(
prefill_result, _ = prefill_engine.prefill(
params=None, padded_tokens=tokens, true_length=true_length
)
print(
Expand Down
2 changes: 1 addition & 1 deletion run_interactive_multiple_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def main(argv):
print(f"---- Encoded tokens are: {tokens}")

# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=None, padded_tokens=tokens, true_length=true_length
)
# pylint: disable-next=all
Expand Down
12 changes: 6 additions & 6 deletions tests/test_llama_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ def test_jetstream_llama2_seed(self):
decode_state = engine.init_decode_state()
slot = 0
# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=padded_tokens, true_length=true_length
)

Expand Down Expand Up @@ -193,7 +193,7 @@ def _llama_e2e(self, env, model_arg):
decode_state = engine.init_decode_state()
slot = 0
# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=padded_tokens, true_length=true_length
)

Expand Down Expand Up @@ -278,7 +278,7 @@ def test_llama_e2e_two_addtional_tokens(self):
slot = 0

# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please add the first_token to the out_tokens array (Line 288); otherwise out_tokens will not equal expected_output_tokens. The same applies to all other calls.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

params=params, padded_tokens=padded_tokens, true_length=true_length
)

Expand Down Expand Up @@ -350,7 +350,7 @@ def test_llama_e2e_four_addtional_tokens(self):
slot = 0

# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=padded_tokens, true_length=true_length
)

Expand Down Expand Up @@ -416,7 +416,7 @@ def test_llama_with_original_prefill_decode_32(self):
# pylint: disable-next=all
decode_state = engine.init_decode_state()
# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=padded_tokens, true_length=true_length
)
out_tokens = prefill_result.token
Expand Down Expand Up @@ -491,7 +491,7 @@ def test_llama_with_original_prefill_decode(self):
# pylint: disable-next=all
decode_state = engine.init_decode_state()
# pylint: disable-next=all
prefill_result = engine.prefill(
prefill_result, _ = engine.prefill(
params=params, padded_tokens=padded_tokens, true_length=true_length
)
out_tokens = prefill_result.token
Expand Down
Loading