Skip to content

Commit

Permalink
Integrate disaggregated serving with JetStream (#117)
Browse files Browse the repository at this point in the history
* add disaggregated server with ray support

* add run_server with ray

* format
  • Loading branch information
FanhaiLu1 authored Jun 6, 2024
1 parent 7f6e45f commit 52ec00f
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 24 deletions.
8 changes: 5 additions & 3 deletions jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from collections import defaultdict
from typing import Any, Iterable, Optional, Union
from typing import Any, Iterable, Optional, Union, Tuple, List

import numpy as np
import ray
Expand Down Expand Up @@ -180,7 +180,9 @@ def create_pytorch_ray_engine(
decode_pod_slice_name: str = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
) -> Any:
) -> Union[
PyTorchRayEngine, Tuple[List[PyTorchRayEngine], List[PyTorchRayEngine]]
]:

# Return tuple as reponse: issues/107
supported_models = ["llama-2", "llama-3", "gemma"]
Expand Down Expand Up @@ -254,4 +256,4 @@ def create_pytorch_ray_engine(
is_disaggregated=is_disaggregated,
pod_slice_name=decode_pod_slice_name,
)
return (prefill_engine, decode_engine)
return ([prefill_engine], [decode_engine])
34 changes: 18 additions & 16 deletions run_interactive_disaggregated.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,25 +94,27 @@ def create_disaggregated_engines():
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

start = time.perf_counter()
prefill_engine, decode_engine = ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
is_disaggregated=_IS_DISAGGREGATED.value,
num_hosts=_NUM_HOSTS.value,
decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
prefill_engine_list, decode_engine_list = (
ray_engine.create_pytorch_ray_engine(
model_name=_MODEL_NAME.value,
tokenizer_path=_TOKENIZER_PATH.value,
ckpt_path=_CKPT_PATH.value,
bf16_enable=True,
param_size=_SIZE.value,
context_length=_CONTEXT_LENGTH.value,
batch_size=_BATCH_SIZE.value,
quantize_weights=_QUANTIZE_WEIGHTS.value,
quantize_kv=_QUANTIZE_KV_CACHE.value,
max_cache_length=_MAX_CACHE_LENGTH.value,
sharding_config=_SHARDING_CONFIG.value,
is_disaggregated=_IS_DISAGGREGATED.value,
num_hosts=_NUM_HOSTS.value,
decode_pod_slice_name=_DECODE_POD_SLICE_NAME.value,
)
)

print("Initialize engine", time.perf_counter() - start)
return (prefill_engine, decode_engine)
return (prefill_engine_list[0], decode_engine_list[0])


# pylint: disable-next=all
Expand Down
61 changes: 56 additions & 5 deletions run_server_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,14 @@
flags.DEFINE_bool("enable_jax_profiler", False, "enable jax profiler")
flags.DEFINE_integer("jax_profiler_port", 9999, "port of JAX profiler server")

flags.DEFINE_bool(
"is_disaggregated", False, "Disaggregated serving if it's True"
)

flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)

flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")


def create_engine():
"""create a pytorch engine"""
Expand Down Expand Up @@ -64,6 +72,37 @@ def create_engine():
return engine


def create_disaggregated_engine():
"""create a pytorch engine"""
jax.config.update("jax_default_prng_impl", "unsafe_rbg")
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"

start = time.perf_counter()
prefill_engine_list, decode_engine_list = (
ray_engine.create_pytorch_ray_engine(
model_name=FLAGS.model_name,
tokenizer_path=FLAGS.tokenizer_path,
ckpt_path=FLAGS.checkpoint_path,
bf16_enable=FLAGS.bf16_enable,
param_size=FLAGS.size,
context_length=FLAGS.context_length,
batch_size=FLAGS.batch_size,
quantize_weights=FLAGS.quantize_weights,
quantize_kv=FLAGS.quantize_kv_cache,
max_cache_length=FLAGS.max_cache_length,
sharding_config=FLAGS.sharding_config,
enable_jax_profiler=FLAGS.enable_jax_profiler,
jax_profiler_port=FLAGS.jax_profiler_port,
is_disaggregated=FLAGS.is_disaggregated,
num_hosts=FLAGS.num_hosts,
decode_pod_slice_name=FLAGS.decode_pod_slice_name,
)
)

print("Initialize engine", time.perf_counter() - start)
return (prefill_engine_list, decode_engine_list)


# pylint: disable-next=all
def main(argv: Sequence[str]):
del argv
Expand All @@ -74,12 +113,24 @@ def main(argv: Sequence[str]):

print(f"devices: {devices}")

engine = create_engine()
if FLAGS.is_disaggregated:
prefill_engine_list, decode_engine_list = create_disaggregated_engine()
chips = int(len(devices) / 2)
server_config = ServerConfig(
prefill_slices=(f"tpu={chips}",),
prefill_engine_create_fns=(lambda a: prefill_engine_list[0],),
generate_slices=(f"tpu={chips}",),
generate_engine_create_fns=(lambda a: decode_engine_list[0],),
is_ray_backend=True,
)

else:
engine = create_engine()
server_config = ServerConfig(
interleaved_slices=(f"tpu={len(devices)}",),
interleaved_engine_create_fns=(lambda a: engine,),
)

server_config = ServerConfig(
interleaved_slices=(f"tpu={len(devices)}",),
interleaved_engine_create_fns=(lambda a: engine,),
)
print(f"server_config: {server_config}")

jetstream_server = server_lib.run(
Expand Down

0 comments on commit 52ec00f

Please sign in to comment.