Passing ppo_epochs to dp_actor.py
jayl940712 committed Feb 22, 2025
1 parent d02b8c7 commit 3007421
Showing 1 changed file with 66 additions and 65 deletions.
verl/workers/actor/dp_actor.py: 131 changes (66 additions, 65 deletions)
@@ -216,74 +216,75 @@ def update_policy(self, data: DataProto):
         dataloader = batch.split(self.config.ppo_mini_batch_size)

         metrics = {}
-        for batch_idx, data in enumerate(dataloader):
-            # split batch into micro_batches
-            mini_batch = data
-            if self.config.use_dynamic_bsz:
-                max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
-                micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
-            else:
-                self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
-                # split batch into micro_batches
-                micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
+        for epoch in range(self.config.ppo_epochs):
+            for batch_idx, data in enumerate(dataloader):
+                # split batch into micro_batches
+                mini_batch = data
+                if self.config.use_dynamic_bsz:
+                    max_token_len = self.config.ppo_max_token_len_per_gpu * self.ulysses_sequence_parallel_size
+                    micro_batches, _ = rearrange_micro_batches(batch=mini_batch, max_token_len=max_token_len)
+                else:
+                    self.gradient_accumulation = self.config.ppo_mini_batch_size // self.config.ppo_micro_batch_size_per_gpu
+                    # split batch into micro_batches
+                    micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)

-            self.actor_optimizer.zero_grad()
+                self.actor_optimizer.zero_grad()

-            for data in micro_batches:
-                data = data.cuda()  # actor device is cpu when using offload
-                responses = data['responses']
-                response_length = responses.size(1)
-                attention_mask = data['attention_mask']
-                response_mask = attention_mask[:, -response_length:]
-                old_log_prob = data['old_log_probs']
-                advantages = data['advantages']
+                for data in micro_batches:
+                    data = data.cuda()  # actor device is cpu when using offload
+                    responses = data['responses']
+                    response_length = responses.size(1)
+                    attention_mask = data['attention_mask']
+                    response_mask = attention_mask[:, -response_length:]
+                    old_log_prob = data['old_log_probs']
+                    advantages = data['advantages']

-                clip_ratio = self.config.clip_ratio
-                entropy_coeff = self.config.entropy_coeff
+                    clip_ratio = self.config.clip_ratio
+                    entropy_coeff = self.config.entropy_coeff

-                # all return: (bsz, response_length)
-                entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)
+                    # all return: (bsz, response_length)
+                    entropy, log_prob = self._forward_micro_batch(micro_batch=data, temperature=temperature)

-                pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
-                                                                              log_prob=log_prob,
-                                                                              advantages=advantages,
-                                                                              eos_mask=response_mask,
-                                                                              cliprange=clip_ratio)
-                # compute entropy loss from entropy
-                entropy_loss = verl_F.masked_mean(entropy, response_mask)
+                    pg_loss, pg_clipfrac, ppo_kl = core_algos.compute_policy_loss(old_log_prob=old_log_prob,
+                                                                                  log_prob=log_prob,
+                                                                                  advantages=advantages,
+                                                                                  eos_mask=response_mask,
+                                                                                  cliprange=clip_ratio)
+                    # compute entropy loss from entropy
+                    entropy_loss = verl_F.masked_mean(entropy, response_mask)

-                # compute policy loss
-                policy_loss = pg_loss - entropy_loss * entropy_coeff
+                    # compute policy loss
+                    policy_loss = pg_loss - entropy_loss * entropy_coeff

-                if self.config.use_kl_loss:
-                    ref_log_prob = data['ref_log_prob']
-                    # compute kl loss
-                    kld = core_algos.kl_penalty(logprob=log_prob,
-                                                ref_logprob=ref_log_prob,
-                                                kl_penalty=self.config.kl_loss_type)
-                    kl_loss = masked_mean(kld, response_mask)
+                    if self.config.use_kl_loss:
+                        ref_log_prob = data['ref_log_prob']
+                        # compute kl loss
+                        kld = core_algos.kl_penalty(logprob=log_prob,
+                                                    ref_logprob=ref_log_prob,
+                                                    kl_penalty=self.config.kl_loss_type)
+                        kl_loss = masked_mean(kld, response_mask)

-                    policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
-                    metrics['actor/kl_loss'] = kl_loss.detach().item()
-                    metrics['actor/kl_coef'] = self.config.kl_loss_coef
+                        policy_loss = policy_loss + kl_loss * self.config.kl_loss_coef
+                        metrics['actor/kl_loss'] = kl_loss.detach().item()
+                        metrics['actor/kl_coef'] = self.config.kl_loss_coef

-                if self.config.use_dynamic_bsz:
-                    # relative to the dynamic bsz
-                    loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
-                else:
-                    loss = policy_loss / self.gradient_accumulation
-                loss.backward()
+                    if self.config.use_dynamic_bsz:
+                        # relative to the dynamic bsz
+                        loss = policy_loss * (len(data) / self.config.ppo_mini_batch_size)
+                    else:
+                        loss = policy_loss / self.gradient_accumulation
+                    loss.backward()

-                data = {
-                    'actor/entropy_loss': entropy_loss.detach().item(),
-                    'actor/pg_loss': pg_loss.detach().item(),
-                    'actor/pg_clipfrac': pg_clipfrac.detach().item(),
-                    'actor/ppo_kl': ppo_kl.detach().item(),
-                }
-                append_to_dict(metrics, data)
+                    data = {
+                        'actor/entropy_loss': entropy_loss.detach().item(),
+                        'actor/pg_loss': pg_loss.detach().item(),
+                        'actor/pg_clipfrac': pg_clipfrac.detach().item(),
+                        'actor/ppo_kl': ppo_kl.detach().item(),
+                    }
+                    append_to_dict(metrics, data)

-            grad_norm = self._optimizer_step()
-            data = {'actor/grad_norm': grad_norm.detach().item()}
-            append_to_dict(metrics, data)
+                grad_norm = self._optimizer_step()
+                data = {'actor/grad_norm': grad_norm.detach().item()}
+                append_to_dict(metrics, data)
         self.actor_optimizer.zero_grad()
         return metrics
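
For reference, the sketch below mirrors the nested loop structure after this change. It is a minimal, self-contained illustration, not verl code: the batch sizes, the list-based stand-ins for DataProto splitting, and the printed count are assumptions, used only to show how ppo_epochs multiplies the number of optimizer steps taken per training batch (one step per mini-batch per epoch, with gradient accumulation across micro-batches inside each step).

# Minimal sketch (not verl code): how ppo_epochs scales the number of optimizer steps.
# All sizes below are illustrative assumptions; in dp_actor.py they come from self.config.
ppo_epochs = 2
train_batch_size = 8           # rollout batch held by the actor worker
ppo_mini_batch_size = 4        # one optimizer step per mini-batch
ppo_micro_batch_size = 2       # forward/backward chunk for gradient accumulation

batch = list(range(train_batch_size))  # stand-in for the DataProto rows
mini_batches = [batch[i:i + ppo_mini_batch_size]
                for i in range(0, train_batch_size, ppo_mini_batch_size)]

optimizer_steps = 0
for epoch in range(ppo_epochs):            # outer loop introduced by this commit
    for mini_batch in mini_batches:        # existing mini-batch loop
        gradient_accumulation = ppo_mini_batch_size // ppo_micro_batch_size
        micro_batches = [mini_batch[i:i + ppo_micro_batch_size]
                         for i in range(0, len(mini_batch), ppo_micro_batch_size)]
        for micro_batch in micro_batches:
            pass                           # forward + backward here, scaling loss by 1 / gradient_accumulation
        optimizer_steps += 1               # optimizer.step() once per mini-batch

# 2 epochs * (8 / 4) mini-batches = 4 optimizer steps for this training batch
assert optimizer_steps == ppo_epochs * (train_batch_size // ppo_mini_batch_size)
print(f"optimizer steps per training batch: {optimizer_steps}")

With ppo_epochs left at 1 this reduces to the previous behavior of a single pass over the mini-batches; larger values reuse the same rollout data (and the same old_log_probs) for additional PPO updates.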
