Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

V5e8 ray #159

Merged
merged 5 commits into from
Jul 30, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Note: Get address ip and port information from ray head.
Here is an example to run the server with ray for llama2 7B model:

```bash
python run_server_with_ray.py --tpu_chips=16 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
python run_server_with_ray.py --tpu_chips=16 --num_hosts=4 --worker_chips=4 -model_name=$model_name --size=7b --batch_size=96 --max_cache_length=2048 --quantize_weights=$quantize --quantize_type=$quantize_type --quantize_kv_cache=$quantize --checkpoint_path=$output_ckpt_dir --tokenizer_path=$tokenizer_path --sharding_config="default_shardings/llama.yaml"
```

# Run benchmark
Expand Down
2 changes: 1 addition & 1 deletion install_everything.sh
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ pip install tensorflow-text
pip install tensorflow
pip install huggingface_hub

pip install ray[default]==2.22.0
pip install ray[default]==2.33.0
# torch cpu
pip install torch==2.3.1+cpu --index-url https://download.pytorch.org/whl/cpu
pip install tensorflow flatbuffers absl-py sentencepiece seqio google-cloud-storage
Expand Down
7 changes: 3 additions & 4 deletions jetstream_pt/ray_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def create_pytorch_ray_engine(
sharding_config=None,
is_disaggregated: bool = False,
num_hosts: int = 0,
worker_chips: int = 0,
decode_pod_slice_name: str = None,
enable_jax_profiler: bool = False,
jax_profiler_port: int = 9999,
Expand All @@ -230,9 +231,7 @@ def create_pytorch_ray_engine(
)
ray.init(ignore_reinit_error=True)
pod_name = tpu.get_current_pod_name()
num_hosts = (
num_hosts if is_disaggregated else tpu.get_current_pod_worker_count()
)
num_hosts = num_hosts if num_hosts > 0 else tpu.get_current_pod_worker_count()
print(f"pod_name:{pod_name}, number of host: {num_hosts}")
assert (
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit, consider adding more assertion to check the number of hosts * working chips == total devices.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, added.

pod_name is not None
Expand All @@ -242,7 +241,7 @@ def create_pytorch_ray_engine(
), f"num_hosts (current value {num_hosts}) should be a positive number"
# pylint: disable-next=all
engine_worker_with_tpu_resource = PyTorchRayWorker.options(
resources={"TPU": 4},
resources={"TPU": worker_chips if worker_chips > 0 else 4},
runtime_env=RuntimeEnv(env_vars={"JAX_PLATFORMS": "tpu,cpu"}),
)
engine_workers = []
Expand Down
10 changes: 10 additions & 0 deletions run_interactive_multiple_host.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@
from jetstream_pt import ray_engine
from jetstream_pt.config import FLAGS

_NUM_HOSTS = flags.DEFINE_integer(
"num_hosts", 0, "Number of TPU host", required=False
)

_WORKER_CHIPS = flags.DEFINE_integer(
"worker_chips", 4, "Number of TPU chips per worker", required=False
)


def create_engine():
"""create a pytorch engine"""
Expand All @@ -43,6 +51,8 @@ def create_engine():
quantize_kv=FLAGS.quantize_kv_cache,
max_cache_length=FLAGS.max_cache_length,
sharding_config=FLAGS.sharding_config,
num_hosts=_NUM_HOSTS.value,
worker_chips=_WORKER_CHIPS.value,
)

print("Initialize engine", time.perf_counter() - start)
Expand Down
8 changes: 7 additions & 1 deletion run_server_with_ray.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,11 @@
"is_disaggregated", False, "Disaggregated serving if it's True"
)

flags.DEFINE_integer("num_hosts", 4, "Number of TPU host", required=False)
flags.DEFINE_integer("num_hosts", 0, "Number of TPU host", required=False)

flags.DEFINE_integer(
"worker_chips", 4, "Number of TPU chips per worker", required=False
)

flags.DEFINE_string("decode_pod_slice_name", "", "Decode pod slice name")

Expand All @@ -68,6 +72,8 @@ def create_engine():
sharding_config=FLAGS.sharding_config,
enable_jax_profiler=FLAGS.enable_jax_profiler,
jax_profiler_port=FLAGS.jax_profiler_port,
num_hosts=FLAGS.num_hosts,
worker_chips=FLAGS.worker_chips,
)

print("Initialize engine", time.perf_counter() - start)
Expand Down
Loading