[DO NOT MERGE] Merged PRs for verl integration #2849

Draft: wants to merge 589 commits into base: main

Commits (589)
664eeaa
more
fzyzcjy Jan 8, 2025
4564450
more
fzyzcjy Jan 8, 2025
f8a31a7
more
fzyzcjy Jan 8, 2025
a73b3fb
cp
fzyzcjy Jan 8, 2025
f1faa4a
more
fzyzcjy Jan 8, 2025
f14b4fd
more
fzyzcjy Jan 8, 2025
4a53cba
more
fzyzcjy Jan 8, 2025
e08c5cc
more
fzyzcjy Jan 8, 2025
2c3b4cb
more
fzyzcjy Jan 8, 2025
e4d224d
more
fzyzcjy Jan 8, 2025
007dceb
more
fzyzcjy Jan 8, 2025
3048da5
more
fzyzcjy Jan 8, 2025
2568788
more
fzyzcjy Jan 8, 2025
f258286
more
fzyzcjy Jan 8, 2025
3ccc372
fix
fzyzcjy Jan 8, 2025
73af0d3
more
fzyzcjy Jan 8, 2025
967788d
Merge branch 'main' into feat/refactor_many
fzyzcjy Jan 8, 2025
58e4d92
update
fzyzcjy Jan 8, 2025
20a18c8
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 8, 2025
2bbb75c
more
fzyzcjy Jan 8, 2025
066515b
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 8, 2025
81b70b9
fmt
fzyzcjy Jan 8, 2025
e1dd31f
fmt
fzyzcjy Jan 8, 2025
61d972a
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 8, 2025
4ce7f7c
simp
fzyzcjy Jan 8, 2025
6baabdb
more
fzyzcjy Jan 8, 2025
8fc897e
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 8, 2025
e0a81cd
Merge branch 'main' into feat/refactor_many
fzyzcjy Jan 8, 2025
9919be9
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 8, 2025
d1faea4
more
fzyzcjy Jan 8, 2025
de472bc
fmt
fzyzcjy Jan 8, 2025
2ed857c
more
fzyzcjy Jan 8, 2025
3ece2bc
fmt
fzyzcjy Jan 8, 2025
7dd4371
fmt
fzyzcjy Jan 8, 2025
700f3d3
cp
fzyzcjy Jan 8, 2025
1c1a886
more
fzyzcjy Jan 8, 2025
179da2e
more
fzyzcjy Jan 8, 2025
df6fd3a
more
fzyzcjy Jan 8, 2025
370ed33
more
fzyzcjy Jan 8, 2025
b46abca
more
fzyzcjy Jan 8, 2025
732f258
more
fzyzcjy Jan 8, 2025
1a1e6f0
more
fzyzcjy Jan 8, 2025
6c6ecf0
more
fzyzcjy Jan 8, 2025
2de5826
more
fzyzcjy Jan 8, 2025
3c6e15c
fix
fzyzcjy Jan 8, 2025
698d771
fix
fzyzcjy Jan 8, 2025
12ff201
fmt
fzyzcjy Jan 8, 2025
435d5d6
Revert "fmt"
fzyzcjy Jan 8, 2025
b047d04
more
fzyzcjy Jan 8, 2025
707df61
more
fzyzcjy Jan 8, 2025
8853190
more
fzyzcjy Jan 8, 2025
b6699de
more
fzyzcjy Jan 8, 2025
67c702f
more
fzyzcjy Jan 8, 2025
76268bd
cp original
fzyzcjy Jan 8, 2025
ff5e752
rm adhoc
fzyzcjy Jan 8, 2025
905efc3
try rm
fzyzcjy Jan 10, 2025
8e1be9e
empty
fzyzcjy Jan 10, 2025
68ff88d
cp
fzyzcjy Jan 10, 2025
fe49867
more
fzyzcjy Jan 10, 2025
7fcf036
more
fzyzcjy Jan 10, 2025
16cb914
cp
fzyzcjy Jan 10, 2025
6ca3732
more
fzyzcjy Jan 10, 2025
12f6b61
more
fzyzcjy Jan 10, 2025
834e121
more
fzyzcjy Jan 10, 2025
319e00a
more
fzyzcjy Jan 10, 2025
7ce6fde
more
fzyzcjy Jan 10, 2025
aaa01a4
more
fzyzcjy Jan 10, 2025
2de8d7e
more
fzyzcjy Jan 10, 2025
ca292a5
more
fzyzcjy Jan 10, 2025
c1e766a
more
fzyzcjy Jan 10, 2025
2f43fba
more
fzyzcjy Jan 10, 2025
b75a87b
pass around
fzyzcjy Jan 10, 2025
87e81e8
more around
fzyzcjy Jan 10, 2025
0857374
more
fzyzcjy Jan 10, 2025
0a4113a
more
fzyzcjy Jan 10, 2025
f38f234
more
fzyzcjy Jan 10, 2025
46258f6
more
fzyzcjy Jan 10, 2025
488e065
more
fzyzcjy Jan 10, 2025
ff78931
more
fzyzcjy Jan 10, 2025
aaf343c
more
fzyzcjy Jan 10, 2025
dd346fd
more
fzyzcjy Jan 10, 2025
4e41524
more
fzyzcjy Jan 10, 2025
b08deb5
more
fzyzcjy Jan 10, 2025
899a4f6
rename
fzyzcjy Jan 10, 2025
f68e3a6
more
fzyzcjy Jan 10, 2025
c1fe9d0
more
fzyzcjy Jan 10, 2025
3559b90
more
fzyzcjy Jan 10, 2025
d8356ae
more
fzyzcjy Jan 10, 2025
12994e3
more
fzyzcjy Jan 10, 2025
a11d5a5
more
fzyzcjy Jan 10, 2025
a52ee3a
more
fzyzcjy Jan 10, 2025
d44fe54
more
fzyzcjy Jan 10, 2025
a902a7e
more
fzyzcjy Jan 10, 2025
e988bd3
more
fzyzcjy Jan 10, 2025
1912a7a
more
fzyzcjy Jan 10, 2025
965b1b7
more
fzyzcjy Jan 10, 2025
b0424c7
more
fzyzcjy Jan 10, 2025
38c648b
more
fzyzcjy Jan 10, 2025
38d7265
more
fzyzcjy Jan 10, 2025
af295b1
more
fzyzcjy Jan 10, 2025
4476c1e
more
fzyzcjy Jan 10, 2025
85df0a6
more
fzyzcjy Jan 10, 2025
be75a19
more
fzyzcjy Jan 10, 2025
8df7965
fmt
fzyzcjy Jan 10, 2025
10941d0
fmt
fzyzcjy Jan 10, 2025
7cb9e41
more
fzyzcjy Jan 10, 2025
9ea1cb1
more
fzyzcjy Jan 10, 2025
de0b1d4
more
fzyzcjy Jan 10, 2025
5301ce6
cp
fzyzcjy Jan 10, 2025
26b728f
more
fzyzcjy Jan 10, 2025
6cbbb32
more
fzyzcjy Jan 10, 2025
e78b98f
Revert "more"
fzyzcjy Jan 10, 2025
205a07b
more
fzyzcjy Jan 10, 2025
d7e8652
more
fzyzcjy Jan 10, 2025
2ddc725
more
fzyzcjy Jan 10, 2025
a64a610
more
fzyzcjy Jan 10, 2025
f132829
more
fzyzcjy Jan 10, 2025
8980295
extract
fzyzcjy Jan 10, 2025
9396760
forward
fzyzcjy Jan 10, 2025
48ee5c4
more
fzyzcjy Jan 10, 2025
701bc73
more
fzyzcjy Jan 10, 2025
0f1b017
more
fzyzcjy Jan 10, 2025
5e8009c
more
fzyzcjy Jan 10, 2025
d1284ab
more
fzyzcjy Jan 10, 2025
e1ae075
more
fzyzcjy Jan 10, 2025
1b25247
more
fzyzcjy Jan 10, 2025
d051f86
more
fzyzcjy Jan 10, 2025
e9020ea
rename
fzyzcjy Jan 10, 2025
6e1b545
more
fzyzcjy Jan 10, 2025
bb95a8a
more
fzyzcjy Jan 10, 2025
1729ae3
more
fzyzcjy Jan 10, 2025
1609898
more
fzyzcjy Jan 10, 2025
50104b8
more
fzyzcjy Jan 10, 2025
cd65f62
more
fzyzcjy Jan 10, 2025
da9fb86
more
fzyzcjy Jan 10, 2025
5242996
more
fzyzcjy Jan 10, 2025
d8be5eb
more
fzyzcjy Jan 10, 2025
8419d6a
more
fzyzcjy Jan 10, 2025
165b968
more
fzyzcjy Jan 10, 2025
e676e9b
more
fzyzcjy Jan 10, 2025
f789a41
more
fzyzcjy Jan 10, 2025
55b20f4
more
fzyzcjy Jan 10, 2025
8ffe95d
more
fzyzcjy Jan 10, 2025
934acd0
more
fzyzcjy Jan 10, 2025
d1aea30
more
fzyzcjy Jan 10, 2025
85737e5
Revert "more"
fzyzcjy Jan 10, 2025
fc86a06
Revert "more"
fzyzcjy Jan 10, 2025
51519a0
Revert "more"
fzyzcjy Jan 10, 2025
af2c780
Revert "more"
fzyzcjy Jan 10, 2025
4315918
Revert "more"
fzyzcjy Jan 10, 2025
429ba05
more
fzyzcjy Jan 10, 2025
d4670f5
more
fzyzcjy Jan 10, 2025
598742d
mv
fzyzcjy Jan 10, 2025
b4dbd4b
mv
fzyzcjy Jan 10, 2025
95f4158
more
fzyzcjy Jan 10, 2025
8910aff
Revert "more"
fzyzcjy Jan 10, 2025
bf2e87d
Revert "mv"
fzyzcjy Jan 10, 2025
511512d
Revert "mv"
fzyzcjy Jan 10, 2025
a1e0b6f
more
fzyzcjy Jan 10, 2025
4ec9df0
more
fzyzcjy Jan 10, 2025
78da75d
more
fzyzcjy Jan 10, 2025
557a97b
more
fzyzcjy Jan 10, 2025
f74d5a0
more
fzyzcjy Jan 10, 2025
8698234
more
fzyzcjy Jan 10, 2025
2a965bc
more
fzyzcjy Jan 10, 2025
8b39ac3
more
fzyzcjy Jan 10, 2025
dc72169
more
fzyzcjy Jan 10, 2025
9023a77
more
fzyzcjy Jan 10, 2025
d3197c2
more
fzyzcjy Jan 10, 2025
865b429
more
fzyzcjy Jan 10, 2025
e119bd0
more
fzyzcjy Jan 10, 2025
523ff85
more
fzyzcjy Jan 10, 2025
6f1d381
fmt
fzyzcjy Jan 10, 2025
347b5ea
dbg
fzyzcjy Jan 10, 2025
a272797
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 10, 2025
84b165f
fmt
fzyzcjy Jan 10, 2025
6ba871e
Merge branch 'main' into feat/refactor_many
fzyzcjy Jan 10, 2025
c046621
merge
fzyzcjy Jan 10, 2025
e09e9ce
fmt
fzyzcjy Jan 10, 2025
881ef94
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 10, 2025
a980d86
merge
fzyzcjy Jan 11, 2025
e1cb699
Merge branch 'feat/refactor_layer' into feat/device_mesh
fzyzcjy Jan 11, 2025
5e04faa
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 11, 2025
bd9a505
more tests
fzyzcjy Jan 11, 2025
04a4948
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 11, 2025
62a8c9a
move to 2gpu
fzyzcjy Jan 11, 2025
061e87d
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 11, 2025
be93f62
more tests
fzyzcjy Jan 11, 2025
f40d605
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 11, 2025
883aef7
Merge branch 'main' into feat/refactor_many
fzyzcjy Jan 12, 2025
fd14677
merge
fzyzcjy Jan 12, 2025
19a15d8
Merge branch 'feat/refactor_many' into feat/refactor_layer
fzyzcjy Jan 12, 2025
7596441
merge
fzyzcjy Jan 12, 2025
926b465
Merge branch 'feat/refactor_layer' into feat/device_mesh
fzyzcjy Jan 12, 2025
4786085
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 12, 2025
4f04950
logging
fzyzcjy Jan 12, 2025
765213e
merge
fzyzcjy Jan 12, 2025
4e4e335
Merge branch 'feat/refactor_layer' into feat/device_mesh
fzyzcjy Jan 12, 2025
380760b
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 12, 2025
136fac9
lint
fzyzcjy Jan 12, 2025
9080854
Merge branch 'feat/refactor_layer' into feat/device_mesh
fzyzcjy Jan 12, 2025
e884e54
Merge branch 'feat/device_mesh' into feat/weight_dtensor
fzyzcjy Jan 12, 2025
6ebce3e
bump to test flaky ci
fzyzcjy Jan 12, 2025
e23a7a7
fix ci
fzyzcjy Jan 12, 2025
3507a92
pass around
fzyzcjy Jan 12, 2025
9b5b704
more
fzyzcjy Jan 12, 2025
67a2ad7
more
fzyzcjy Jan 12, 2025
f68c601
impl
fzyzcjy Jan 12, 2025
a2a188a
more
fzyzcjy Jan 12, 2025
a4e6f70
more
fzyzcjy Jan 12, 2025
ee3e311
more
fzyzcjy Jan 12, 2025
8be0a2e
more
fzyzcjy Jan 12, 2025
bc5c234
more
fzyzcjy Jan 12, 2025
4a5f2dd
fmt
fzyzcjy Jan 12, 2025
a86a0db
fmt
fzyzcjy Jan 12, 2025
69bb1ab
doc
fzyzcjy Jan 12, 2025
65f25fa
fmt
fzyzcjy Jan 12, 2025
13d944a
import
fzyzcjy Jan 12, 2025
92f40ba
bump ci to test flaky
fzyzcjy Jan 12, 2025
cad7433
fix ci (seems vllm version problem)
fzyzcjy Jan 12, 2025
eae26b0
Revert "bump ci to test flaky"
fzyzcjy Jan 12, 2025
c7b0c7a
fix ci (seems vllm version problem)
fzyzcjy Jan 13, 2025
2cc75cb
Merge branch 'feat/memory_saver' into feat/overall_verl
fzyzcjy Jan 13, 2025
e691e00
more
fzyzcjy Jan 13, 2025
c7586f3
more
fzyzcjy Jan 13, 2025
9d2b897
more
fzyzcjy Jan 13, 2025
d624733
more
fzyzcjy Jan 13, 2025
d550864
more
fzyzcjy Jan 13, 2025
dc17988
more
fzyzcjy Jan 13, 2025
0dcade1
more
fzyzcjy Jan 13, 2025
aaa3773
more
fzyzcjy Jan 13, 2025
a83eadf
fmt
fzyzcjy Jan 13, 2025
98aa981
more
fzyzcjy Jan 13, 2025
3fa43be
lint
fzyzcjy Jan 13, 2025
5189341
more
fzyzcjy Jan 13, 2025
15fc8f9
fix test
fzyzcjy Jan 13, 2025
2db9804
adhoc temp
fzyzcjy Jan 13, 2025
62785a9
more
fzyzcjy Jan 13, 2025
d9bb69b
adhoc
fzyzcjy Jan 13, 2025
1a8c5c7
more
fzyzcjy Jan 13, 2025
12a5f20
Revert "adhoc"
fzyzcjy Jan 13, 2025
56a1d40
Revert "adhoc temp"
fzyzcjy Jan 13, 2025
43eb65e
adhoc
fzyzcjy Jan 13, 2025
d227cc7
lint
fzyzcjy Jan 13, 2025
ddeb99e
more adhoc version
fzyzcjy Jan 13, 2025
96938d9
fmt
fzyzcjy Jan 13, 2025
8c1602d
more demo
fzyzcjy Jan 13, 2025
2ea41c4
cleanup demo
fzyzcjy Jan 13, 2025
e630132
more cleanup demo
fzyzcjy Jan 13, 2025
0a8a7d2
fmt
fzyzcjy Jan 13, 2025
8 changes: 7 additions & 1 deletion .github/workflows/pr-test.yml
@@ -52,7 +52,7 @@ jobs:
     runs-on: 1-gpu-runner
     strategy:
       matrix:
-        range: [0-6, 6-16, 16-23, 23-30, 30-100]
+        range: [ 0-6, 6-16, 16-23, 23-30, 30-100 ]
     steps:
       - name: Checkout code
         uses: actions/checkout@v3
@@ -107,6 +107,12 @@ jobs:
         cd test/srt
         python3 test_update_weights_from_distributed.py

+      - name: Test EngineFragment
+        timeout-minutes: 10
+        run: |
+          cd test/srt
+          python3 test_fragment.py
+
       - name: Evaluate MoE EP accuracy (TP=2)
         timeout-minutes: 10
         run: |
285 changes: 285 additions & 0 deletions examples/runtime/engine/adhoc_verl_torchrun.py
@@ -0,0 +1,285 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import List

import torch
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import CPUOffload
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import MixedPrecision
from torch.distributed.fsdp.api import (
    ShardedStateDictConfig,
    ShardingStrategy,
    StateDictType,
)
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from sglang.srt.distributed import ParallelProcessGroups
from sglang.srt.server.engine_fragment import EngineFragment


def main():
    assert torch.cuda.is_available(), "CUDA must be present to run FSDP vLLM example"
    local_rank, rank, world_size = initialize_global_process_group()

    # NOTE MODIFIED path-related logic
    # local_cache_path = '~/.cache/verl/rlhf'
    # local_cache_path = os.path.expanduser(local_cache_path)
    hdfs_path = "Qwen/Qwen2-7B-Instruct"
    local_model_path = hdfs_path
    # from verl.utils.fs import copy_local_path_from_hdfs
    # local_model_path = copy_local_path_from_hdfs(src=hdfs_path, cache_dir=local_cache_path)
    tokenizer = AutoTokenizer.from_pretrained(local_model_path, trust_remote_code=True)
    actor_model_config = AutoConfig.from_pretrained(
        local_model_path, trust_remote_code=True
    )
    with torch.device("cuda"):
        actor_model = AutoModelForCausalLM.from_pretrained(
            local_model_path, trust_remote_code=True
        )
        actor_model.to(torch.bfloat16)

    max_prompt_length = 16
    response_length = 32
    preencode_prompts = [
        "The president of the United States is",
        "The capital of France is",
        "The future of AI is",
    ]
    tokenizer.pad_token = tokenizer.eos_token
    prompts = tokenizer(
        preencode_prompts, return_tensors="pt", padding=True, padding_side="left"
    )  # NOTE MODIFIED add
    input_ids = prompts["input_ids"]
    attention_mask = prompts["attention_mask"]
    # from verl.utils.torch_functional import pad_sequence_to_length
    input_ids = pad_sequence_to_length(
        input_ids, max_prompt_length, tokenizer.pad_token_id, left_pad=True
    ).cuda()
    attention_mask = pad_sequence_to_length(
        attention_mask, max_prompt_length, 0, left_pad=True
    ).cuda()

    from transformers import GenerationConfig

    generation_config = GenerationConfig(do_sample=False)
    actor_model.cuda()
    output = actor_model.generate(
        input_ids=input_ids,
        attention_mask=attention_mask,
        max_new_tokens=32,
        # max_length=max_length,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        generation_config=generation_config,
        # renormalize_logits=True,
        output_scores=False,  # this is potentially very large
        return_dict_in_generate=True,
        use_cache=False,
    )  # may OOM when use_cache = True
    seq = output.sequences
    response = seq[:, max_prompt_length:]

    print(f"hf response: {tokenizer.batch_decode(response)}")

    tensor_model_parallel_size = 4
    device_mesh = init_device_mesh(
        "cuda", mesh_shape=(world_size,), mesh_dim_names=["fsdp"]
    )

    mixed_precision = MixedPrecision(
        param_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
        buffer_dtype=torch.float32,
    )
    fsdp_model = FSDP(
        actor_model,
        use_orig_params=True,
        auto_wrap_policy=None,
        device_id=torch.cuda.current_device(),
        sharding_strategy=ShardingStrategy.FULL_SHARD,
        mixed_precision=mixed_precision,
        cpu_offload=CPUOffload(offload_params=False),
        sync_module_states=False,
        device_mesh=device_mesh,
    )

    FSDP.set_state_dict_type(
        fsdp_model,
        state_dict_type=StateDictType.SHARDED_STATE_DICT,
        state_dict_config=ShardedStateDictConfig(),
    )

    state_dict = fsdp_model.state_dict()

    if rank == 0:
        lines = ["------------------------ state_dict ------------------------"]
        for k, v in state_dict.items():
            v_local = v.to_local()
            lines.append(
                f"{k}\t: {v.shape=} {v_local.shape=} {v.dtype=} {v_local.dtype=} {type(v)=} {type(v_local)=}"
            )
        print("\n".join(lines))

    # NOTE MODIFIED
    # sampling_params = SamplingParams(temperature=0,
    #                                  top_p=1,
    #                                  n=1,
    #                                  max_tokens=response_length,
    #                                  logprobs=1,
    #                                  ignore_eos=True,
    #                                  detokenize=False)
    sampling_params = dict(
        temperature=0, top_p=1, n=1, max_new_tokens=response_length, ignore_eos=True
    )

    tp_size, dp_size = 4, 1
    kwargs = dict(mesh_shape=(tp_size, dp_size, 1), mesh_dim_names=["tp", "dp", "pp"])
    inference_device_mesh_device = init_device_mesh("cuda", **kwargs)
    inference_device_mesh_cpu = init_device_mesh("cpu", **kwargs)
    print(f"{inference_device_mesh_device=} {inference_device_mesh_cpu=}")

    print(actor_model_config)
    # llm = LLM(model=None,
    #           tokenizer=tokenizer,
    #           model_hf_config=actor_model_config,
    #           tensor_parallel_size=tensor_model_parallel_size,
    #           enforce_eager=True,
    #           dtype='bfloat16',
    #           load_format='dummy_dtensor',
    #           gpu_memory_utilization=0.1,
    #           trust_remote_code=True)
    changed_model_path = local_model_path.replace("-Instruct", "")
    assert changed_model_path != local_model_path
    print(f"{changed_model_path=}")
    llm = EngineFragment(
        model_path=changed_model_path,  # use model of same type but different weight to test update_weights
        tp_size=tensor_model_parallel_size,
        dtype="bfloat16",
        memory_saver=True,
        mem_fraction_static=0.6,
        nccl_port=12345,
        # TODO `tp_rank` (and maybe `nccl_port`?) can be removed later, will do it when #2827's depending PRs are merged
        tp_rank=rank,
        gpu_id=rank,
        parallel_process_groups=ParallelProcessGroups.from_devices_meshes(
            device_mesh_device=inference_device_mesh_device,
            device_mesh_cpu=inference_device_mesh_cpu,
            dim_tp="tp",
            dim_pp="pp",
        ),
    )

    print("Sleep to have time checking memory consumption")
    time.sleep(5)
    print("release_gpu_occupation")
    llm.release_gpu_occupation()
    print("Sleep again to have time checking memory consumption")
    time.sleep(5)
    print("resume_gpu_occupation")
    llm.resume_gpu_occupation()

    t = time.time()
    if 0:
        # most naive way
        state_dict_full = {k: v.full_tensor() for k, v in state_dict.items()}
        print(f"gather full tensor: {time.time() - t:.2f}")
        llm.update_weights_from_tensor([(k, v) for k, v in state_dict_full.items()])
    else:
        llm.update_weights_from_tensor([(k, v) for k, v in state_dict.items()])
    print(f"gather + update weights: {time.time() - t:.2f}")

    input_ids = input_ids.cuda()
    attention_mask = attention_mask.cuda()
    idx_list = []
    batch_size = input_ids.shape[0]

    pad_token_id = (
        tokenizer.pad_token_id
        if tokenizer.pad_token_id is not None
        else tokenizer.eos_token_id
    )
    # from verl.workers.rollout.vllm_rollout.vllm_rollout import _pre_process_inputs
    for i in range(batch_size):
        idx_list.append(_pre_process_inputs(pad_token_id, input_ids[i]))
    print("start generation")
    # outputs = llm.generate(prompt_token_ids=idx_list, sampling_params=sampling_params, use_tqdm=False)
    outputs = llm.generate(input_ids=idx_list, sampling_params=sampling_params)

    # vllm_output = outputs[0].cuda()
    if torch.distributed.get_rank() == 0:
        print(f"hf response: {tokenizer.batch_decode(response)}")
        # print(f'vllm response: {tokenizer.batch_decode(vllm_output)}')
        print(f'vllm response: {[o["text"] for o in outputs]}')

    llm.shutdown()


# COPIED FROM verl
def initialize_global_process_group(timeout_second=36000):
    from datetime import timedelta

    import torch.distributed

    # NOTE MODIFIED should provide backend=None to have nccl+gloo
    # torch.distributed.init_process_group('nccl', timeout=timedelta(seconds=timeout_second))
    torch.distributed.init_process_group(timeout=timedelta(seconds=timeout_second))

    local_rank = int(os.environ["LOCAL_RANK"])
    rank = int(os.environ["RANK"])
    world_size = int(os.environ["WORLD_SIZE"])

    if torch.distributed.is_initialized():
        torch.cuda.set_device(local_rank)
    return local_rank, rank, world_size


# COPIED FROM verl
def pad_sequence_to_length(tensors, max_seq_len, pad_token_id, left_pad=False):
    """
    pad a 2D tensors (e.g. responses, logprobs) in the last dim to max_seq_length.
    input shape: [bs, seq_length]
    output shape: [bs, max_seq_length]
    (0, max_seq_len - tensors.shape[-1]) means right pad to max_seq_length and no left pad
    """
    if tensors.shape[-1] >= max_seq_len:
        return tensors
    pad_tuple = (
        (max_seq_len - tensors.shape[-1], 0)
        if left_pad
        else (0, max_seq_len - tensors.shape[-1])
    )
    return F.pad(tensors, pad_tuple, "constant", pad_token_id)


# COPIED FROM verl
# NOTE(sgm): add for verl. We can optimize it by making the dataloader yield List[int] without padding.
def _pre_process_inputs(pad_token_id, prompt_token_ids: torch.Tensor) -> List[int]:
    # remove the left padding in the prompt token_id
    # pad_token_id = self.llm_engine.tokenizer.pad_token_id if self.llm_engine.tokenizer.pad_token_id is not None else self.llm_engine.tokenizer.eos_token_id
    non_pad_index = torch.nonzero(prompt_token_ids != pad_token_id, as_tuple=False)[0][
        0
    ]
    token_ids = prompt_token_ids[non_pad_index:].tolist()
    return token_ids


if __name__ == "__main__":
    """
    Run it: LD_PRELOAD=/usr/local/lib/python3.10/dist-packages/torch_memory_saver_cpp.cpython-310-x86_64-linux-gnu.so torchrun --nproc_per_node=4 adhoc_verl_torchrun.py
    """
    main()
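
For readers who only skim the long example above, here is a condensed sketch of the core verl-integration flow it demonstrates: a per-rank EngineFragment built on the same device meshes as FSDP training, GPU memory released and resumed around a training step, the FSDP-sharded state dict pushed into the engine, and a rollout generated. This is a summary sketch, not a tested standalone script; it reuses only the APIs that appear in the example above, and the model path, mesh shapes, and port are placeholder values.

```python
# Condensed sketch of the flow in adhoc_verl_torchrun.py above (illustrative only;
# run under torchrun with torch_memory_saver preloaded, as in the example's docstring).
from torch.distributed.device_mesh import init_device_mesh

from sglang.srt.distributed import ParallelProcessGroups
from sglang.srt.server.engine_fragment import EngineFragment


def sync_weights_and_rollout(fsdp_model, rank, world_size, prompt_token_ids):
    # 1. Inference meshes (tp, dp, pp) over the same ranks used for FSDP training.
    mesh_kwargs = dict(mesh_shape=(world_size, 1, 1), mesh_dim_names=["tp", "dp", "pp"])
    mesh_device = init_device_mesh("cuda", **mesh_kwargs)
    mesh_cpu = init_device_mesh("cpu", **mesh_kwargs)

    # 2. One EngineFragment per rank, sharing the training process groups.
    llm = EngineFragment(
        model_path="Qwen/Qwen2-7B",  # placeholder; weights are overwritten below
        tp_size=world_size,
        dtype="bfloat16",
        memory_saver=True,
        mem_fraction_static=0.6,
        nccl_port=12345,  # placeholder port
        tp_rank=rank,
        gpu_id=rank,
        parallel_process_groups=ParallelProcessGroups.from_devices_meshes(
            device_mesh_device=mesh_device,
            device_mesh_cpu=mesh_cpu,
            dim_tp="tp",
            dim_pp="pp",
        ),
    )

    # 3. Give GPU memory back to the trainer while it runs, then reclaim it.
    llm.release_gpu_occupation()
    # ... FSDP training step would happen here ...
    llm.resume_gpu_occupation()

    # 4. Push the sharded FSDP state dict into the engine, then roll out.
    llm.update_weights_from_tensor(list(fsdp_model.state_dict().items()))
    outputs = llm.generate(
        input_ids=prompt_token_ids,  # list of token-id lists, left padding stripped
        sampling_params=dict(temperature=0, max_new_tokens=32),
    )
    llm.shutdown()
    return [o["text"] for o in outputs]
```

Note the design choice the example's `if 0:` branch highlights: `update_weights_from_tensor` can be fed either fully gathered tensors (via `v.full_tensor()`) or the sharded DTensors directly; the example defaults to passing the shards, which avoids gathering every full tensor up front.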