
[ppo] refactor: refactor old_log_prob into a separate function #129

Merged · 2 commits · Jan 24, 2025
verl/trainer/ppo/ray_trainer.py (5 additions, 0 deletions)
@@ -601,6 +601,11 @@ def fit(self):
                 # compute global_valid tokens
                 batch.meta_info['global_token_num'] = torch.sum(batch.batch['attention_mask'], dim=-1).tolist()
 
+                # recompute old_log_probs
+                with _timer('old_log_prob', timing_raw):
+                    old_log_prob = self.actor_rollout_wg.compute_log_prob(batch)
+                    batch = batch.union(old_log_prob)
+
                 if self.use_reference_policy:
                     # compute reference log_prob
                     with _timer('ref', timing_raw):
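For context, the `_timer` helper used here records how long the wrapped block took into the `timing_raw` dict under the given key, which is how the trainer collects per-stage timings. A minimal sketch of that contract (an assumption about the helper's behavior, not the repo's exact implementation):

    import time
    from contextlib import contextmanager

    @contextmanager
    def _timer(name, timing_raw):
        # Measure wall-clock time for the enclosed block and store it
        # under `name` so the trainer can report per-stage timings.
        start = time.perf_counter()
        yield
        timing_raw[name] = time.perf_counter() - start

With that contract, the new block simply times the distributed `compute_log_prob` call the same way the existing `ref` stage is timed.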
verl/workers/fsdp_workers.py (29 additions, 14 deletions)
@@ -446,19 +446,6 @@ def generate_sequences(self, prompts: DataProto):
 
         output = self.rollout_sharding_manager.postprocess_data(output)
 
-        if self._is_actor and recompute_log_prob:
-            # we should always recompute old_log_probs when it is HybridEngine
-            output.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
-            output.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
-            output.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
-            output.meta_info['temperature'] = self.config.rollout.temperature
-            # perform recompute log_prob
-            with self.ulysses_sharding_manager:
-                output = self.ulysses_sharding_manager.preprocess_data(output)
-                old_log_probs = self.actor.compute_log_prob(data=output)
-                output.batch['old_log_probs'] = old_log_probs
-                output = self.ulysses_sharding_manager.postprocess_data(output)
-
         output = output.to('cpu')
 
         if self._is_offload_param:
@@ -469,6 +456,33 @@ def generate_sequences(self, prompts: DataProto):
         log_gpu_memory_usage('After recompute log prob', logger=logger)
         return output
 
+    @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
+    def compute_log_prob(self, data: DataProto):
+        assert self._is_actor
+        data = data.to('cuda')
+        # we should always recompute old_log_probs when it is HybridEngine
+        data.meta_info['micro_batch_size'] = self.config.rollout.log_prob_micro_batch_size
+        data.meta_info['max_token_len'] = self.config.rollout.log_prob_max_token_len_per_gpu
+        data.meta_info['use_dynamic_bsz'] = self.config.rollout.log_prob_use_dynamic_bsz
+        data.meta_info['temperature'] = self.config.rollout.temperature
+        # perform recompute log_prob
+        with self.ulysses_sharding_manager:
+            data = self.ulysses_sharding_manager.preprocess_data(data)
+            old_log_probs = self.actor.compute_log_prob(data=data)
+            data.batch['old_log_probs'] = old_log_probs
+            data = self.ulysses_sharding_manager.postprocess_data(data)
+
+        output = data.select(batch_keys=['old_log_probs'])
+        output = output.to('cpu')
+
+        # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
+        # unshard the root FSDP module
+        if self.world_size > 1:
+            self.actor.actor_module._handle.reshard(True)
+
+        torch.cuda.empty_cache()
+        return output
+
     @register(dispatch_mode=Dispatch.DP_COMPUTE_PROTO)
     def compute_ref_log_prob(self, data: DataProto):
         assert self._is_ref
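Two DataProto operations carry the refactor: the worker returns only the `old_log_probs` column via `select`, and the trainer grafts it onto the rollout batch via `union`. A deliberately simplified, hypothetical sketch of those semantics (the real `verl.DataProto` wraps a TensorDict and also carries a `non_tensor_batch`):

    from dataclasses import dataclass, field

    @dataclass
    class DataProtoSketch:
        batch: dict = field(default_factory=dict)      # tensor columns
        meta_info: dict = field(default_factory=dict)  # scalar/config entries

        def select(self, batch_keys):
            # Keep only the requested columns, e.g. ['old_log_probs'],
            # so the worker ships back the minimum over the wire.
            return DataProtoSketch(
                batch={k: self.batch[k] for k in batch_keys},
                meta_info=dict(self.meta_info),
            )

        def union(self, other):
            # Merge another proto's columns into this one; the trainer uses
            # this to attach the recomputed old_log_probs to the batch.
            merged = dict(self.batch)
            merged.update(other.batch)
            meta = dict(self.meta_info)
            meta.update(other.meta_info)
            return DataProtoSketch(batch=merged, meta_info=meta)

Returning only the needed column also keeps the data-parallel gather behind `Dispatch.DP_COMPUTE_PROTO` small, since presumably only one tensor column has to be concatenated across ranks.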
@@ -490,7 +504,8 @@ def compute_ref_log_prob(self, data: DataProto):
 
         # https://pytorch.org/docs/stable/notes/fsdp.html#fsdp-notes
         # unshard the root FSDP module
-        self.ref_policy.actor_module._handle.reshard(True)
+        if self.world_size > 1:
+            self.ref_policy.actor_module._handle.reshard(True)
 
         torch.cuda.empty_cache()
         return output
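Both reshard call sites are now guarded by `if self.world_size > 1`. A plausible reading: with a single rank, FSDP does not shard parameters, so there is nothing to reshard and the private `_handle` call can be skipped. A defensive sketch of that intent (`maybe_reshard` is a hypothetical helper, not part of this PR; `_handle.reshard` is a private PyTorch FSDP API that may change between versions):

    def maybe_reshard(fsdp_module, world_size: int) -> None:
        # After a forward-only pass (no backward), the root FSDP module can be
        # left holding its gathered, unsharded parameters. Manually resharding
        # frees that memory before torch.cuda.empty_cache() is called.
        handle = getattr(fsdp_module, '_handle', None)  # private FSDP attribute
        if world_size > 1 and handle is not None:
            handle.reshard(True)  # True: free the unsharded flat parameter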