Support Mlflow, allow forward to have different batch size, compute more metrics #74

Merged · 17 commits · Jan 2, 2025
7 changes: 4 additions & 3 deletions verl/trainer/config/ppo_trainer.yaml
@@ -15,7 +15,7 @@ actor_rollout_ref:
model:
path: ~/models/deepseek-llm-7b-chat
external_lib: null
override_config: {}
override_config: { }
enable_gradient_checkpointing: False
actor:
strategy: fsdp # This is for backward-compatibility
@@ -78,7 +78,7 @@ critic:
model:
path: ~/models/deepseek-llm-7b-chat
tokenizer_path: ${actor_rollout_ref.model.path}
override_config: {}
override_config: { }
external_lib: ${actor_rollout_ref.model.external_lib}
enable_gradient_checkpointing: False
fsdp_config:
@@ -90,6 +90,7 @@ critic:
min_num_params: 0
ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size}
ppo_micro_batch_size: 64
forward_micro_batch_size: ${critic.ppo_micro_batch_size}
ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs}
shuffle: ${actor_rollout_ref.actor.shuffle}
grad_clip: 1.0
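
The new critic.forward_micro_batch_size option (defaulting to ${critic.ppo_micro_batch_size}) lets the value forward pass use a different micro batch size than the PPO update step. The snippet below is only an illustrative sketch of how a worker could chunk its forward pass by this setting; the compute_values_chunked function, the model call signature, and the argument names are hypothetical and not part of this PR's diff.

import torch

def compute_values_chunked(model, input_ids, attention_mask, forward_micro_batch_size):
    """Run the value forward pass in chunks of forward_micro_batch_size rows."""
    values = []
    for start in range(0, input_ids.size(0), forward_micro_batch_size):
        end = start + forward_micro_batch_size
        with torch.no_grad():
            # Hypothetical model call; the real critic worker wires this differently.
            chunk_values = model(input_ids=input_ids[start:end],
                                 attention_mask=attention_mask[start:end])
        values.append(chunk_values)
    return torch.cat(values, dim=0)
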
@@ -121,7 +122,7 @@ trainer:
total_epochs: 30
project_name: verl_examples
experiment_name: gsm8k
logger: ['console', 'wandb']
logger: [ 'console', 'wandb' ]
nnodes: 1
n_gpus_per_node: 8
save_freq: -1
90 changes: 64 additions & 26 deletions verl/trainer/ppo/ray_trainer.py
@@ -17,19 +17,19 @@
"""

import os
from contextlib import contextmanager
from dataclasses import dataclass, field
from enum import Enum
from pprint import pprint
from typing import Callable, Type, Tuple, Union
from typing import Type, Dict

from omegaconf import OmegaConf, open_dict
import numpy as np
from codetiming import Timer

from omegaconf import OmegaConf, open_dict
from verl import DataProto
from verl.single_controller.base import Worker
from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs
from verl.single_controller.ray.base import create_colocated_worker_cls
from verl import DataProto
from verl.trainer.ppo import core_algos

WorkerType = Type[Worker]
@@ -138,23 +138,36 @@ def reduce_metrics(metrics: dict):
return metrics


def compute_data_metrics(batch):
# TODO: add response length
sequence_score = batch.batch['token_level_scores'].sum(-1)
sequence_reward = batch.batch['token_level_rewards'].sum(-1)

def _compute_response_info(batch):
response_length = batch.batch['responses'].shape[-1]

advantages = batch.batch['advantages']
prompt_mask = batch.batch['attention_mask'][:, :-response_length]
response_mask = batch.batch['attention_mask'][:, -response_length:]

prompt_length = prompt_mask.sum(-1).float()
response_length = response_mask.sum(-1).float() # (batch_size,)

return dict(
response_mask=response_mask,
prompt_length=prompt_length,
response_length=response_length,
)
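
A toy illustration (not part of the diff) of what _compute_response_info extracts; the SimpleNamespace below only mimics the two DataProto fields the helper reads.

import torch
from types import SimpleNamespace

fake = SimpleNamespace(batch={
    'responses': torch.zeros(2, 3, dtype=torch.long),     # 2 samples, 3 response positions
    'attention_mask': torch.tensor([[1, 1, 1, 1, 0],      # 2 prompt tokens, 2 response tokens
                                    [1, 1, 1, 1, 1]]),    # 2 prompt tokens, 3 response tokens
})
info = _compute_response_info(fake)
# info['prompt_length']   -> tensor([2., 2.])
# info['response_length'] -> tensor([2., 3.])
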


def compute_data_metrics(batch):
# TODO: add response length
sequence_score = batch.batch['token_level_scores'].sum(-1)
sequence_reward = batch.batch['token_level_rewards'].sum(-1)

advantages = batch.batch['advantages']
returns = batch.batch['returns']
values = batch.batch['values']

response_info = _compute_response_info(batch)
response_mask = response_info['response_mask']
prompt_length = response_info['prompt_length']
response_length = response_info['response_length']

metrics = {
# score
'critic/score/mean': torch.mean(sequence_score).detach().item(),
@@ -188,6 +201,37 @@ def compute_data_metrics(batch):
return metrics


def compute_timing_metrics(batch, timing_raw):
response_info = _compute_response_info(batch)
num_prompt_tokens = torch.sum(response_info['prompt_length']).item()
num_response_tokens = torch.sum(response_info['response_length']).item()
num_overall_tokens = num_prompt_tokens + num_response_tokens

num_tokens_of_section = {
'gen': num_response_tokens,
**{
name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']
},
}

return {
**{
f'timing/{name}': value for name, value in timing_raw.items()
},
**{
f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name]
for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())
},
}
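
A worked example with made-up numbers, reusing the toy fake batch sketched above (4 prompt tokens and 5 response tokens in total):

timing_raw = {'gen': 2.0, 'values': 1.0, 'testing': 0.5}   # hypothetical timings in seconds
metrics = compute_timing_metrics(batch=fake, timing_raw=timing_raw)
# metrics['timing/gen']              -> 2.0
# metrics['timing_per_token/gen']    -> 2.0 / 5   (normalized by response tokens only)
# metrics['timing_per_token/values'] -> 1.0 / 9   (normalized by prompt + response tokens)
# 'testing' keeps its raw timing but gets no per-token entry because it is not
# listed in num_tokens_of_section.
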


@contextmanager
def _timer(name: str, timing_raw: Dict[str, float]):
with Timer(name=name, logger=None) as timer:
yield
timing_raw[name] = timer.last
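
A minimal usage sketch of the new _timer helper (the timed call is a placeholder):

timing_raw = {}
with _timer('gen', timing_raw):
    do_generation()   # placeholder for e.g. self.actor_rollout_wg.generate_sequences(gen_batch)
# timing_raw is now {'gen': <elapsed seconds>} and can be fed to
# compute_timing_metrics at the end of the training step, as fit() does below.
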


class RayPPOTrainer(object):
"""
Note that this trainer runs on the driver process on a single CPU/GPU node.
@@ -430,6 +474,7 @@ def fit(self):
for epoch in range(self.config.trainer.total_epochs):
for batch_dict in self.train_dataloader:
metrics = {}
timing_raw = {}

batch: DataProto = DataProto.from_single_dict(batch_dict)
# batch = batch.to('cuda')
@@ -438,26 +483,23 @@
gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids'])

# generate a batch
with Timer(name='gen', logger=None) as timer:
with _timer('gen', timing_raw):
gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch)
metrics['timing/gen'] = timer.last

batch = batch.union(gen_batch_output)

if self.use_reference_policy:
# compute reference log_prob
with Timer(name='ref', logger=None) as timer:
with _timer('ref', timing_raw):
ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch)
batch = batch.union(ref_log_prob)
metrics['timing/ref'] = timer.last

# compute values
with Timer(name='values', logger=None) as timer:
with _timer('values', timing_raw):
values = self.critic_wg.compute_values(batch)
batch = batch.union(values)
metrics['timing/values'] = timer.last

with Timer(name='adv', logger=None) as timer:
with _timer('adv', timing_raw):
# compute scores. Support both model and function-based.
# We first compute the scores using reward model. Then, we call reward_fn to combine
# the results from reward model and rule-based results.
@@ -481,36 +523,32 @@ def fit(self):
self.config.algorithm.gamma,
self.config.algorithm.lam,
adv_estimator=self.config.algorithm.adv_estimator)
metrics['timing/adv'] = timer.last

# update critic
if self.use_critic:
with Timer(name='update_critic', logger=None) as timer:
with _timer('update_critic', timing_raw):
critic_output = self.critic_wg.update_critic(batch)
metrics['timing/update_critic'] = timer.last
critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics'])
metrics.update(critic_output_metrics)

# implement critic warmup
if self.config.trainer.critic_warmup <= global_steps:
# update actor
with Timer(name='update_actor', logger=None) as timer:
with _timer('update_actor', timing_raw):
actor_output = self.actor_rollout_wg.update_actor(batch)
metrics['timing/update_actor'] = timer.last
actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics'])
metrics.update(actor_output_metrics)

# validate
if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0:
with Timer(name='testing', logger=None) as timer:
with _timer('testing', timing_raw):
val_metrics: dict = self._validate()
val_metrics = {f'val/{key}': val for key, val in val_metrics.items()}
metrics['timing/testing'] = timer.last
metrics.update(val_metrics)

# collect metrics
data_metrics = compute_data_metrics(batch=batch)
metrics.update(data_metrics)
metrics.update(compute_data_metrics(batch=batch))
metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw))

# TODO: make a canonical logger that supports various backend
logger.log(data=metrics, step=global_steps)
56 changes: 53 additions & 3 deletions verl/utils/tracking.py
@@ -14,12 +14,15 @@
"""
A unified tracking interface that supports logging data to different backends
"""

from typing import List, Union
import dataclasses
from enum import Enum
from functools import partial
from pathlib import Path
from typing import List, Union, Dict, Any


class Tracking(object):
supported_backend = ['wandb', 'console']
supported_backend = ['wandb', 'mlflow', 'console']

def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None):
if isinstance(default_backend, str):
@@ -38,6 +41,12 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li
wandb.init(project=project_name, name=experiment_name, config=config)
self.logger['wandb'] = wandb

if 'mlflow' in default_backend:
import mlflow
mlflow.start_run(run_name=experiment_name)
mlflow.log_params(_compute_mlflow_params_from_objects(config))
self.logger['mlflow'] = _MlflowLoggingAdapter()

if 'console' in default_backend:
from verl.utils.logger.aggregate_logger import LocalLogger
self.console_logger = LocalLogger(print_to_console=True)
@@ -47,3 +56,44 @@ def log(self, data, step, backend=None):
for default_backend, logger_instance in self.logger.items():
if backend is None or default_backend in backend:
logger_instance.log(data=data, step=step)
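
With the new backend, MLflow logging is enabled by adding 'mlflow' to the trainer's logger list in the config (e.g. logger: [ 'console', 'mlflow' ]). A minimal usage sketch of the Tracking class itself follows; the config dict and metric values are placeholders, and a reachable MLflow tracking URI (or the default local ./mlruns directory) is assumed.

tracking = Tracking(project_name='verl_examples',
                    experiment_name='gsm8k',
                    default_backend=['console', 'mlflow'],
                    config={'trainer': {'nnodes': 1, 'n_gpus_per_node': 8}})
tracking.log(data={'critic/score/mean': 0.42, 'timing/gen': 2.0}, step=1)
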


class _MlflowLoggingAdapter:

def log(self, data, step):
import mlflow
mlflow.log_metrics(metrics=data, step=step)


def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]:
if params is None:
return {}

return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/')


def _transform_params_to_json_serializable(x, convert_list_to_dict: bool):
_transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict)

if dataclasses.is_dataclass(x):
return _transform(dataclasses.asdict(x))
if isinstance(x, dict):
return {k: _transform(v) for k, v in x.items()}
if isinstance(x, list):
if convert_list_to_dict:
return {'list_len': len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)}
else:
return [_transform(v) for v in x]
if isinstance(x, Path):
return str(x)
if isinstance(x, Enum):
return x.value

return x


def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]:
import pandas as pd
ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0]
assert isinstance(ans, dict)
return ans
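
For reference, a small example of how a nested config turns into flat MLflow params (this assumes pandas is installed, since _flatten_dict relies on pd.json_normalize, and the dict-union operator in the list branch requires Python 3.9+):

cfg = {'trainer': {'nnodes': 1, 'logger': ['console', 'mlflow']}}
_compute_mlflow_params_from_objects(cfg)
# -> {'trainer/nnodes': 1,
#     'trainer/logger/list_len': 2,
#     'trainer/logger/0': 'console',
#     'trainer/logger/1': 'mlflow'}
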