Co-authored-by: Chendong98 <chendong136@huawei.com>
Commit 5975d6d (parent: 0a1b16f)
Showing 16 changed files with 486 additions and 97 deletions.
@@ -0,0 +1,41 @@
set -x


python3 -m verl.trainer.main_ppo \
    algorithm.adv_estimator=grpo \
    data.train_files=$HOME/data/gsm8k/train.parquet \
    data.val_files=$HOME/data/gsm8k/test.parquet \
    data.train_batch_size=32 \
    data.val_batch_size=1312 \
    data.max_prompt_length=64 \
    data.max_response_length=128 \
    actor_rollout_ref.model.path=Qwen/Qwen2-7B-Instruct \
    actor_rollout_ref.actor.optim.lr=1e-6 \
    actor_rollout_ref.model.use_remove_padding=False \
    actor_rollout_ref.actor.ppo_mini_batch_size=32 \
    actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=20 \
    actor_rollout_ref.actor.use_kl_loss=True \
    actor_rollout_ref.actor.kl_loss_coef=0.001 \
    actor_rollout_ref.actor.kl_loss_type=low_var_kl \
    actor_rollout_ref.model.enable_gradient_checkpointing=True \
    actor_rollout_ref.actor.fsdp_config.param_offload=False \
    actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
    actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=20 \
    actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
    actor_rollout_ref.rollout.name=vllm \
    actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
    actor_rollout_ref.rollout.n=5 \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=160 \
    actor_rollout_ref.ref.fsdp_config.param_offload=True \
    actor_rollout_ref.rollout.enable_chunked_prefill=False \
    algorithm.kl_ctrl.kl_coef=0.001 \
    trainer.critic_warmup=0 \
    trainer.logger=['console','wandb'] \
    trainer.project_name='verl_grpo_example_gsm8k' \
    trainer.experiment_name='qwen2_7b_function_rm' \
    trainer.n_gpus_per_node=8 \
    trainer.nnodes=1 \
    trainer.save_freq=-1 \
    trainer.test_freq=5 \
    trainer.total_epochs=15 $@
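The key choices in this script are `algorithm.adv_estimator=grpo` together with `actor_rollout_ref.rollout.n=5`: each prompt is rolled out 5 times and advantages are computed relative to that group of responses instead of coming from a learned critic. A minimal sketch of the group-relative advantage idea behind that estimator (this is just the concept, not verl's actual implementation, and the reward values are made up):

```python
import torch

# Toy outcome rewards for one prompt sampled rollout.n = 5 times.
rewards = torch.tensor([0.0, 1.0, 0.0, 1.0, 1.0])

# Group-relative advantage: normalize each sample's reward by the mean/std
# of its own group (all responses generated for the same prompt).
advantages = (rewards - rewards.mean()) / (rewards.std() + 1e-6)

print(advantages)  # positive for above-average responses, negative otherwise
```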
@@ -44,7 +44,6 @@ dependencies = [
     "ray>=2.10",
     "tensordict<0.6",
     "transformers",
-    "vllm<=0.6.3",
     'wandb',
 ]
@@ -0,0 +1,18 @@
# requirements.txt records the full set of dependencies for development
accelerate
codetiming
datasets
dill
hydra-core
numpy
pandas
peft
pyarrow>=15.0.0
pybind11
pylatexenc
ray
tensordict<0.6
transformers
wandb
vllm
vllm-ascend
@@ -0,0 +1,220 @@
# Adapted from https://github.com/mlcommons/training_results_v1.1/blob/main/NVIDIA/benchmarks/bert/implementations/pytorch/padding.py

import torch
import torch.nn.functional as F
from einops import rearrange, repeat


class IndexFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        # return input[indices]
        return torch.gather(
            rearrange(input, "b ... -> b (...)"), 0, repeat(indices, "z -> z d", d=second_dim)
        ).reshape(-1, *other_shape)

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        grad_output = rearrange(grad_output, "b ... -> b (...)")
        grad_input = torch.zeros(
            [ctx.first_axis_dim, grad_output.shape[1]],
            device=grad_output.device,
            dtype=grad_output.dtype,
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        # grad_input[indices] = grad_output
        grad_input.scatter_(0, repeat(indices, "z -> z d", d=grad_output.shape[1]), grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis = IndexFirstAxis.apply


class IndexPutFirstAxis(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, indices, first_axis_dim):
        ctx.save_for_backward(indices)
        assert indices.ndim == 1
        assert values.ndim >= 2
        output = torch.zeros(
            first_axis_dim, *values.shape[1:], device=values.device, dtype=values.dtype
        )
        # TD [2022-03-04] For some reason torch.scatter is a bit faster than indexing.
        output[indices] = values
        # output.scatter_(0, repeat(indices, 'z -> z d', d=values.shape[1]), values)
        return output

    @staticmethod
    def backward(ctx, grad_output):
        (indices,) = ctx.saved_tensors
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        grad_values = grad_output[indices]
        # grad_values = torch.gather(grad_output, 0, repeat(indices, 'z -> z d', d=grad_output.shape[1]))
        return grad_values, None, None


index_put_first_axis = IndexPutFirstAxis.apply
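These two helpers are the gather/scatter primitives the padding utilities below build on: index_first_axis keeps only the rows named by indices, and index_put_first_axis scatters them back into a zero-filled tensor of the original first-axis size. A minimal sketch of that behaviour, assuming the definitions above are in scope (the tensors and shapes here are purely illustrative):

```python
import torch

x = torch.arange(12.0).reshape(4, 3)           # (first_axis_dim=4, hidden=3)
idx = torch.tensor([1, 3])                     # rows to keep

rows = index_first_axis(x, idx)                # (2, 3): rows 1 and 3 of x
restored = index_put_first_axis(rows, idx, 4)  # (4, 3): zeros except rows 1 and 3

assert torch.equal(rows, x[idx])
assert torch.equal(restored[idx], x[idx])
assert restored[0].abs().sum() == 0            # untouched rows stay zero
```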


class IndexFirstAxisResidual(torch.autograd.Function):
    @staticmethod
    def forward(ctx, input, indices):
        ctx.save_for_backward(indices)
        assert input.ndim >= 2
        ctx.first_axis_dim, other_shape = input.shape[0], input.shape[1:]
        second_dim = other_shape.numel()
        # TD [2022-03-04] For some reason torch.gather is a bit faster than indexing.
        output = input[indices]
        # We don't want to reshape input (b ... -> b (...)) since it could change the channel_last
        # memory format to channel_first. In other words, input might not be contiguous.
        # If we don't detach, Pytorch complains about output being a view and is being modified inplace
        return output, input.detach()

    @staticmethod
    def backward(ctx, grad_output, grad_residual):
        (indices,) = ctx.saved_tensors
        assert grad_output.ndim >= 2
        other_shape = grad_output.shape[1:]
        assert grad_residual.shape[1:] == other_shape
        grad_input = grad_residual
        # grad_input[indices] += grad_output
        indices = indices.reshape(indices.shape[0], *((1,) * (grad_output.ndim - 1)))
        indices = indices.expand_as(grad_output)
        grad_input.scatter_add_(0, indices, grad_output)
        return grad_input.reshape(ctx.first_axis_dim, *other_shape), None


index_first_axis_residual = IndexFirstAxisResidual.apply


def unpad_input(hidden_states, attention_mask, unused_mask=None):
    """
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
        unused_mask: (batch, seqlen), bool / int, 1 means the element is allocated but unused.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask + unused_mask.
        indices: (total_nnz), the indices of masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
        seqused: (batch), returns the number of tokens selected in attention_mask + unused_mask.
    """
    all_masks = (attention_mask + unused_mask) if unused_mask is not None else attention_mask
    seqlens_in_batch = all_masks.sum(dim=-1, dtype=torch.int32)
    used_seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
    indices = torch.nonzero(all_masks.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
        used_seqlens_in_batch,
    )
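To make the unpad_input contract concrete, here is a minimal sketch on a toy batch, assuming the definitions above are in scope; the shapes and mask values are illustrative only:

```python
import torch

# batch=2, seqlen=4, hidden=3; sequence 0 has 2 valid tokens, sequence 1 has 4.
hidden_states = torch.randn(2, 4, 3)
attention_mask = torch.tensor([[1, 1, 0, 0],
                               [1, 1, 1, 1]])

unpadded, indices, cu_seqlens, max_seqlen, seqused = unpad_input(hidden_states, attention_mask)

print(unpadded.shape)    # torch.Size([6, 3])  -> the 2 + 4 valid tokens, padding removed
print(indices.tolist())  # [0, 1, 4, 5, 6, 7]  -> positions in the flattened (b*s) axis
print(cu_seqlens)        # tensor([0, 2, 6], dtype=torch.int32)
print(max_seqlen)        # 4
print(seqused)           # tensor([2, 4], dtype=torch.int32)
```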


def unpad_input_for_concatenated_sequences(hidden_states, attention_mask_in_length):
    """
    Supports concatenating short samples in one sequence. The attention_mask_in_length is utilized to mask other short samples. It helps efficient training of variant lengths-based samples (e.g., the supervised fine-tuning task in large language model).
    The motivation for this function is explained [here](https://github.com/Dao-AILab/flash-attention/issues/432#issuecomment-1668822286).
    For example, if batch = 3 and seqlen = 6, the attention_mask_in_length is:
        ```
        [
          [2, 3, 0, 0, 0, 0],
          [3, 2, 0, 0, 0, 0],
          [6, 0, 0, 0, 0, 0]
        ]
        ```
    , which refers to the 3D-attention mask:
        ```
        [
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0],
            [0, 0, 1, 1, 0, 0],
            [0, 0, 1, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [0, 0, 0, 1, 0, 0],
            [0, 0, 0, 1, 1, 0],
            [0, 0, 0, 0, 0, 1]
          ],
          [
            [1, 0, 0, 0, 0, 0],
            [1, 1, 0, 0, 0, 0],
            [1, 1, 1, 0, 0, 0],
            [1, 1, 1, 1, 0, 0],
            [1, 1, 1, 1, 1, 0],
            [1, 1, 1, 1, 1, 1]
          ]
        ]
        ```.
    Arguments:
        hidden_states: (batch, seqlen, ...)
        attention_mask_in_length: (batch, seqlen), int, a nonzero number (e.g., 1, 2, 3, etc.) means length of concatenated sequence in b-th batch, and 0 means none.
    Return:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices of non-masked tokens from the flattened input sequence.
        cu_seqlens: (batch + 1), the cumulative sequence lengths, used to index into hidden_states.
        max_seqlen_in_batch: int
    """
    length = attention_mask_in_length.sum(dim=-1)
    seqlen = attention_mask_in_length.size(-1)
    attention_mask_2d = torch.arange(seqlen, device=length.device, dtype=length.dtype).expand(
        len(length), seqlen
    ) < length.unsqueeze(1)
    real_indices_idx = torch.nonzero(attention_mask_in_length.flatten(), as_tuple=False).flatten()
    seqlens_in_batch = attention_mask_in_length.flatten()[real_indices_idx]
    indices = torch.nonzero(attention_mask_2d.flatten(), as_tuple=False).flatten()
    max_seqlen_in_batch = seqlens_in_batch.max().item()
    cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
    # TD [2022-03-04] We don't want to index with a bool mask, because Pytorch will expand the
    # bool mask, then call nonzero to get the indices, then index with those. The indices is @dim
    # times larger than it needs to be, wasting memory. It's faster and more memory-efficient to
    # index with integer indices. Moreover, torch's index is a bit slower than it needs to be,
    # so we write custom forward and backward to make it a bit faster.
    return (
        index_first_axis(rearrange(hidden_states, "b s ... -> (b s) ..."), indices),
        indices,
        cu_seqlens,
        max_seqlen_in_batch,
    )


def pad_input(hidden_states, indices, batch, seqlen):
    """
    Arguments:
        hidden_states: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
        indices: (total_nnz), the indices that represent the non-masked tokens of the original padded input sequence.
        batch: int, batch size for the padded sequence.
        seqlen: int, maximum sequence length for the padded sequence.
    Return:
        hidden_states: (batch, seqlen, ...)
    """
    dim = hidden_states.shape[-1]
    # output = torch.zeros((batch * seqlen), dim, device=hidden_states.device, dtype=hidden_states.dtype)
    # output[indices] = hidden_states
    output = index_put_first_axis(hidden_states, indices, batch * seqlen)
    return rearrange(output, "(b s) ... -> b s ...", b=batch)
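As a quick round-trip check of how unpad_input and pad_input fit together, here is a minimal sketch (definitions above assumed in scope; shapes and mask values are illustrative): valid tokens survive the unpad/pad round trip unchanged, and pad positions come back as zeros.

```python
import torch

batch, seqlen, hidden = 2, 4, 3
hidden_states = torch.randn(batch, seqlen, hidden)
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

unpadded, indices, cu_seqlens, max_seqlen, _ = unpad_input(hidden_states, attention_mask)
repadded = pad_input(unpadded, indices, batch, seqlen)      # (2, 4, 3)

valid = attention_mask.bool()
assert torch.equal(repadded[valid], hidden_states[valid])   # valid tokens are preserved
assert torch.all(repadded[~valid] == 0)                     # padded positions are zero-filled
```

This is the usual pattern for running computation on packed (total_nnz, ...) tensors without padding tokens and then restoring the (batch, seqlen, ...) layout afterwards.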