Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate disaggregated serving with JetStream #117

Merged
merged 3 commits into from
Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Optional, Union, Tuple, List

import numpy as np
import ray
Expand Down Expand Up @@ -180,7 +180,9 @@ def create_pytorch_ray_engine(
decode_pod_slice_name: str = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
) -> Any:
) -> Union[
PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]
]:

# Return tuple as response: issues/107
supported_models = ["llama-2", "llama-3", "gemma"]
Expand Down Expand Up @@ -254,4 +256,4 @@ def create_pytorch_ray_engine(
is_disaggregated=is_disaggregated,
pod_slice_name=decode_pod_slice_name,
)
return (prefill_engine, decode_engine)
return ([prefill_engine], [decode_engine])
34 changes: 18 additions & 16 deletions run_interactive_disaggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,27 @@ def create_disaggregated_engines():
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

start = time.perf_counter()
prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
is_disaggregated=_IS_DISAGGREGATED.value,
num_hosts=_NUM_HOSTS.value,
decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
prefill_engine_list, decode_engine_list = (
ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
is_disaggregated=_IS_DISAGGREGATED.value,
num_hosts=_NUM_HOSTS.value,
decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
)
)

print("Initialize engine", time.perf_counter() - start)
return (prefill_engine, decode_engine)
return (prefill_engine_list[0], decode_engine_list[0])


# pylint: disable-next=all
Expand Down
61 changes: 56 additions & 5 deletions run_server_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")

# Disaggregated-serving flags: when enabled, prefill and decode run on
# separate TPU pod slices instead of interleaved on a single slice.
flags.DEFINE_bool(
    "is_disaggregated", False, "Disaggregated serving if it's True"
)

# Number of hosts in the prefill pod slice; only consulted when
# is_disaggregated is True.
flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)

# Name of the pod slice that hosts the decode engine (disaggregated mode).
flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")


def create_engine():
"""create a pytorch engine"""
Expand Down Expand Up @@ -64,6 +72,37 @@ def create_engine():
return engine


def create_disaggregated_engine():
  """Create disaggregated prefill/decode PyTorch engines backed by Ray.

  Returns:
    A tuple ``(prefill_engine_list, decode_engine_list)`` as produced by
    ``ray_engine.create_pytorch_ray_engine`` when ``is_disaggregated`` is
    set — one list of prefill engines and one list of decode engines.
  """
  # unsafe_rbg is the faster PRNG implementation recommended for TPU.
  jax.config.update("jax_default_prng_impl", "unsafe_rbg")
  # Surface all TF C++ logs (level 0 = everything) for debuggability.
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

  # Time engine construction; Ray actor startup can dominate server boot.
  start = time.perf_counter()
  prefill_engine_list, decode_engine_list = (
      ray_engine.create_pytorch_ray_engine(
          model_name=FLAGS.model_name,
          tokenizer_path=FLAGS.tokenizer_path,
          ckpt_path=FLAGS.checkpoint_path,
          bf16_enable=FLAGS.bf16_enable,
          param_size=FLAGS.size,
          context_length=FLAGS.context_length,
          batch_size=FLAGS.batch_size,
          quantize_weights=FLAGS.quantize_weights,
          quantize_kv=FLAGS.quantize_kv_cache,
          max_cache_length=FLAGS.max_cache_length,
          sharding_config=FLAGS.sharding_config,
          enable_jax_profiler=FLAGS.enable_jax_profiler,
          jax_profiler_port=FLAGS.jax_profiler_port,
          is_disaggregated=FLAGS.is_disaggregated,
          num_hosts=FLAGS.num_hosts,
          decode_pod_slice_name=FLAGS.decode_pod_slice_name,
      )
  )

  print("Initialize engine", time.perf_counter() - start)
  return (prefill_engine_list, decode_engine_list)


# pylint: disable-next=all
def main(argv: Sequence[str]):
del argv
Expand All @@ -74,12 +113,24 @@ def main(argv: Sequence[str]):

print(f"devices: {devices}")

engine = create_engine()
if FLAGS.is_disaggregated:
prefill_engine_list, decode_engine_list = create_disaggregated_engine()
chips = int(len(devices) / 2)
server_config = ServerConfig(
prefill_slices=(f"tpu={chips}",),
prefill_engine_create_fns=(lambda a: prefill_engine_list[0],),
generate_slices=(f"tpu={chips}",),
generate_engine_create_fns=(lambda a: decode_engine_list[0],),
is_ray_backend=True,
)

else:
engine = create_engine()
server_config = ServerConfig(
interleaved_slices=(f"tpu={len(devices)}",),
interleaved_engine_create_fns=(lambda a: engine,),
)

server_config = ServerConfig(
interleaved_slices=(f"tpu={len(devices)}",),
interleaved_engine_create_fns=(lambda a: engine,),
)
print(f"server_config: {server_config}")

jetstream_server = server_lib.run(
Expand Down
Loading