From 74636749708855c9f1711c8315b5e106d72ca643 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 13:50:48 +0800 Subject: [PATCH 01/17] extract --- verl/trainer/ppo/ray_trainer.py | 36 ++++++++++++++++----------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 27efa448..ba625865 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -17,19 +17,19 @@ """ import os +from contextlib import contextmanager from dataclasses import dataclass, field from enum import Enum from pprint import pprint -from typing import Callable, Type, Tuple, Union +from typing import Type, Dict -from omegaconf import OmegaConf, open_dict import numpy as np from codetiming import Timer - +from omegaconf import OmegaConf, open_dict +from verl import DataProto from verl.single_controller.base import Worker from verl.single_controller.ray import RayResourcePool, RayWorkerGroup, RayClassWithInitArgs from verl.single_controller.ray.base import create_colocated_worker_cls -from verl import DataProto from verl.trainer.ppo import core_algos WorkerType = Type[Worker] @@ -438,26 +438,23 @@ def fit(self): gen_batch = batch.pop(batch_keys=['input_ids', 'attention_mask', 'position_ids']) # generate a batch - with Timer(name='gen', logger=None) as timer: + with _timer('gen', timing_raw): gen_batch_output = self.actor_rollout_wg.generate_sequences(gen_batch) - metrics['timing/gen'] = timer.last batch = batch.union(gen_batch_output) if self.use_reference_policy: # compute reference log_prob - with Timer(name='ref', logger=None) as timer: + with _timer('ref', timing_raw): ref_log_prob = self.ref_policy_wg.compute_ref_log_prob(batch) batch = batch.union(ref_log_prob) - metrics['timing/ref'] = timer.last # compute values - with Timer(name='values', logger=None) as timer: + with _timer('values', timing_raw): values = self.critic_wg.compute_values(batch) batch = batch.union(values) - metrics['timing/values'] = timer.last - with Timer(name='adv', logger=None) as timer: + with _timer('adv', timing_raw): # compute scores. Support both model and function-based. # We first compute the scores using reward model. Then, we call reward_fn to combine # the results from reward model and rule-based results. 
@@ -481,31 +478,27 @@ def fit(self): self.config.algorithm.gamma, self.config.algorithm.lam, adv_estimator=self.config.algorithm.adv_estimator) - metrics['timing/adv'] = timer.last # update critic if self.use_critic: - with Timer(name='update_critic', logger=None) as timer: + with _timer('update_critic', timing_raw): critic_output = self.critic_wg.update_critic(batch) - metrics['timing/update_critic'] = timer.last critic_output_metrics = reduce_metrics(critic_output.meta_info['metrics']) metrics.update(critic_output_metrics) # implement critic warmup if self.config.trainer.critic_warmup <= global_steps: # update actor - with Timer(name='update_actor', logger=None) as timer: + with _timer('update_actor', timing_raw): actor_output = self.actor_rollout_wg.update_actor(batch) - metrics['timing/update_actor'] = timer.last actor_output_metrics = reduce_metrics(actor_output.meta_info['metrics']) metrics.update(actor_output_metrics) # validate if self.val_reward_fn is not None and (global_steps + 1) % self.config.trainer.test_freq == 0: - with Timer(name='testing', logger=None) as timer: + with _timer('testing', timing_raw): val_metrics: dict = self._validate() val_metrics = {f'val/{key}': val for key, val in val_metrics.items()} - metrics['timing/testing'] = timer.last metrics.update(val_metrics) # collect metrics @@ -536,3 +529,10 @@ def fit(self): val_metrics = self._validate() pprint(f'Final validation metrics: {val_metrics}') logger.log(data=val_metrics, step=global_steps) + + +@contextmanager +def _timer(name: str, timing_raw: Dict[str, float]): + with Timer(name=name, logger=None) as timer: + yield + timing_raw[f'timing/{name}'] = timer.last From 826e4948d6f9f8945d30b992a0aa786ef0a6c5f9 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 13:55:35 +0800 Subject: [PATCH 02/17] more --- verl/trainer/ppo/ray_trainer.py | 41 ++++++++++++++++++++++++++------- 1 file changed, 33 insertions(+), 8 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index ba625865..b7a196d7 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -138,23 +138,36 @@ def reduce_metrics(metrics: dict): return metrics -def compute_data_metrics(batch): - # TODO: add response length - sequence_score = batch.batch['token_level_scores'].sum(-1) - sequence_reward = batch.batch['token_level_rewards'].sum(-1) - +def _compute_response_info(batch): response_length = batch.batch['responses'].shape[-1] - advantages = batch.batch['advantages'] prompt_mask = batch.batch['attention_mask'][:, :-response_length] response_mask = batch.batch['attention_mask'][:, -response_length:] prompt_length = prompt_mask.sum(-1).float() response_length = response_mask.sum(-1).float() # (batch_size,) + return dict( + response_mask=response_mask, + prompt_length=prompt_length, + response_length=response_length, + ) + + +def compute_data_metrics(batch): + # TODO: add response length + sequence_score = batch.batch['token_level_scores'].sum(-1) + sequence_reward = batch.batch['token_level_rewards'].sum(-1) + + advantages = batch.batch['advantages'] returns = batch.batch['returns'] values = batch.batch['values'] + response_info = _compute_response_info(batch) + response_mask = response_info['response_mask'] + prompt_length = response_info['prompt_length'] + response_length = response_info['response_length'] + metrics = { # score 'critic/score/mean': torch.mean(sequence_score).detach().item(), @@ -188,6 +201,17 @@ def compute_data_metrics(batch): return metrics +def 
compute_timing_metrics(batch, timing_raw): + response_info = _compute_response_info(batch) + num_prompt_tokens = torch.sum(response_info['prompt_length']).item() + num_response_tokens = torch.sum(response_info['response_length']).item() + + return { + **{f'timing/{name}': value for name, value in timing_raw.items()}, + f'timing_per_token/{name}': TODO, + } + + class RayPPOTrainer(object): """ Note that this trainer runs on the driver process on a single CPU/GPU node. @@ -430,6 +454,7 @@ def fit(self): for epoch in range(self.config.trainer.total_epochs): for batch_dict in self.train_dataloader: metrics = {} + timing_raw = {} batch: DataProto = DataProto.from_single_dict(batch_dict) # batch = batch.to('cuda') @@ -502,8 +527,8 @@ def fit(self): metrics.update(val_metrics) # collect metrics - data_metrics = compute_data_metrics(batch=batch) - metrics.update(data_metrics) + metrics.update(compute_data_metrics(batch=batch)) + metrics.update(compute_timing_metrics(batch=batch, timing_raw=timing_raw)) # TODO: make a canonical logger that supports various backend logger.log(data=metrics, step=global_steps) From 6aa467c7d5367b7847482ef39a6a71384d0a08ed Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 13:57:08 +0800 Subject: [PATCH 03/17] more --- verl/trainer/ppo/ray_trainer.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index b7a196d7..4687e0f2 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -205,10 +205,20 @@ def compute_timing_metrics(batch, timing_raw): response_info = _compute_response_info(batch) num_prompt_tokens = torch.sum(response_info['prompt_length']).item() num_response_tokens = torch.sum(response_info['response_length']).item() + num_overall_tokens = num_prompt_tokens + num_response_tokens + + num_tokens_of_section = { + 'gen': num_response_tokens, + 'ref': num_overall_tokens, + 'values': num_overall_tokens, + 'critic': num_overall_tokens, + 'actor': num_overall_tokens, + } return { **{f'timing/{name}': value for name, value in timing_raw.items()}, - f'timing_per_token/{name}': TODO, + **{f'timing_per_token/{name}': timing_raw[name] / num_tokens + for name, num_tokens in num_tokens_of_section.items()}, } From f0eea62407331f3dc92642812b201908bef9f74a Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 13:58:45 +0800 Subject: [PATCH 04/17] fix --- verl/trainer/ppo/ray_trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 4687e0f2..0855e76d 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -570,4 +570,4 @@ def fit(self): def _timer(name: str, timing_raw: Dict[str, float]): with Timer(name=name, logger=None) as timer: yield - timing_raw[f'timing/{name}'] = timer.last + timing_raw[name] = timer.last From cf475110856e0d2bbba4cfd7d71dfb9ac8cf2667 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 14:00:32 +0800 Subject: [PATCH 05/17] more --- verl/trainer/ppo/ray_trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 0855e76d..b3b6df77 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -217,8 +217,8 @@ def compute_timing_metrics(batch, timing_raw): return { **{f'timing/{name}': value for name, value in timing_raw.items()}, - **{f'timing_per_token/{name}': 
timing_raw[name] / num_tokens - for name, num_tokens in num_tokens_of_section.items()}, + **{f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] + for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())}, } From fce7267918dce260e8fc5842bfe85c25b45c9849 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 14:04:52 +0800 Subject: [PATCH 06/17] more --- verl/trainer/ppo/ray_trainer.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index b3b6df77..823e91b6 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -209,10 +209,7 @@ def compute_timing_metrics(batch, timing_raw): num_tokens_of_section = { 'gen': num_response_tokens, - 'ref': num_overall_tokens, - 'values': num_overall_tokens, - 'critic': num_overall_tokens, - 'actor': num_overall_tokens, + **{name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']}, } return { From 3f4fa246d1e7fafdc1759831d73f5a1731e202b0 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 14:14:50 +0800 Subject: [PATCH 07/17] fmt --- verl/trainer/ppo/ray_trainer.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 823e91b6..12c3ed27 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -209,13 +209,19 @@ def compute_timing_metrics(batch, timing_raw): num_tokens_of_section = { 'gen': num_response_tokens, - **{name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor']}, + **{ + name: num_overall_tokens for name in ['ref', 'values', 'adv', 'update_critic', 'update_actor'] + }, } return { - **{f'timing/{name}': value for name, value in timing_raw.items()}, - **{f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] - for name in set(num_tokens_of_section.keys()) & set(timing_raw.keys())}, + **{ + f'timing/{name}': value for name, value in timing_raw.items() + }, + **{ + f'timing_per_token/{name}': timing_raw[name] / num_tokens_of_section[name] for name in set(num_tokens_of_section.keys( + )) & set(timing_raw.keys()) + }, } From 2a1c193685575a59c4e7a9af9f03192a2d8c883d Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:30:09 +0800 Subject: [PATCH 08/17] more --- verl/utils/tracking.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 5a65f954..6f3c24e7 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -19,7 +19,7 @@ class Tracking(object): - supported_backend = ['wandb', 'console'] + supported_backend = ['wandb', 'mlflow', 'console'] def __init__(self, project_name, experiment_name, default_backend: Union[str, List[str]] = 'console', config=None): if isinstance(default_backend, str): @@ -38,6 +38,11 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li wandb.init(project=project_name, name=experiment_name, config=config) self.logger['wandb'] = wandb + if 'mlflow' in default_backend: + import mlflow + TODO + self.logger['mlflow'] = mlflow + if 'console' in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger self.console_logger = LocalLogger(print_to_console=True) From 50cf7d9263a3c33ab8a64bff956f613f50320017 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:34:33 +0800 Subject: [PATCH 09/17] more --- 
verl/utils/tracking.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 6f3c24e7..7d51a29e 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -40,8 +40,9 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li if 'mlflow' in default_backend: import mlflow - TODO - self.logger['mlflow'] = mlflow + mlflow.start_run() + mlflow.log_params(_convert_config_to_mlflow(config)) + self.logger['mlflow'] = _MlflowLoggingAdapter() if 'console' in default_backend: from verl.utils.logger.aggregate_logger import LocalLogger @@ -52,3 +53,9 @@ def log(self, data, step, backend=None): for default_backend, logger_instance in self.logger.items(): if backend is None or default_backend in backend: logger_instance.log(data=data, step=step) + + +class _MlflowLoggingAdapter: + def log(self, data, step): + import mlflow + mlflow.log_metrics(metrics=data, step=step) From f20da363baecfba8270a1d12d3fbba297ae54dbc Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:36:00 +0800 Subject: [PATCH 10/17] more --- verl/utils/tracking.py | 43 +++++++++++++++++++++++++++++++++++++++--- 1 file changed, 40 insertions(+), 3 deletions(-) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 7d51a29e..db3eb95e 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -14,8 +14,11 @@ """ A unified tracking interface that supports logging data to different backend """ - -from typing import List, Union +import dataclasses +from enum import Enum +from functools import partial +from pathlib import Path +from typing import List, Union, Dict, Any class Tracking(object): @@ -41,7 +44,7 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li if 'mlflow' in default_backend: import mlflow mlflow.start_run() - mlflow.log_params(_convert_config_to_mlflow(config)) + mlflow.log_params(_compute_mlflow_params_from_objects(config)) self.logger['mlflow'] = _MlflowLoggingAdapter() if 'console' in default_backend: @@ -59,3 +62,37 @@ class _MlflowLoggingAdapter: def log(self, data, step): import mlflow mlflow.log_metrics(metrics=data, step=step) + + +def _compute_mlflow_params_from_objects(params) -> Dict[str, Any]: + if params is None: + return {} + + return _flatten_dict(_transform_params_to_json_serializable(params, convert_list_to_dict=True), sep='/') + + +def _transform_params_to_json_serializable(x, convert_list_to_dict: bool): + _transform = partial(_transform_params_to_json_serializable, convert_list_to_dict=convert_list_to_dict) + + if dataclasses.is_dataclass(x): + return _transform(dataclasses.asdict(x)) + if isinstance(x, dict): + return {k: _transform(v) for k, v in x.items()} + if isinstance(x, list): + if convert_list_to_dict: + return {'list_len': len(x)} | {f'{i}': _transform(v) for i, v in enumerate(x)} + else: + return [_transform(v) for v in x] + if isinstance(x, Path): + return str(x) + if isinstance(x, Enum): + return x.value + + return x + + +def _flatten_dict(raw: Dict[str, Any], *, sep: str) -> Dict[str, Any]: + import pandas as pd + ans = pd.json_normalize(raw, sep=sep).to_dict(orient='records')[0] + assert isinstance(ans, dict) + return ans From 0bd5c8227792b7ddb9397796d5d9b2ed207096f7 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:44:44 +0800 Subject: [PATCH 11/17] more --- verl/utils/tracking.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 
db3eb95e..57f5fb44 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -43,7 +43,7 @@ def __init__(self, project_name, experiment_name, default_backend: Union[str, Li if 'mlflow' in default_backend: import mlflow - mlflow.start_run() + mlflow.start_run(run_name=experiment_name) mlflow.log_params(_compute_mlflow_params_from_objects(config)) self.logger['mlflow'] = _MlflowLoggingAdapter() From b83fcbed81bb60badc21a73f670acc040f49d53e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:45:41 +0800 Subject: [PATCH 12/17] more --- verl/workers/fsdp_workers.py | 39 +++++++++++++++++------------------- 1 file changed, 18 insertions(+), 21 deletions(-) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index d1479f26..c884f5e0 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -15,26 +15,26 @@ The main entry point to run the PPO algorithm """ -import os import logging +import os import warnings -import ray + import torch import torch.distributed -from omegaconf import DictConfig, open_dict - -from verl.single_controller.base import Worker -from verl.single_controller.base.decorator import register, Dispatch +import verl.utils.hdfs_io as hdfs_io import verl.utils.torch_functional as verl_F +from omegaconf import DictConfig from verl import DataProto -from verl.utils.model import compute_position_id_with_mask +from verl.single_controller.base import Worker +from verl.single_controller.base.decorator import register, Dispatch +from verl.utils import hf_tokenizer +from verl.utils.debug import log_gpu_memory_usage from verl.utils.fs import copy_local_path_from_hdfs -from verl.utils.fsdp_utils import get_fsdp_wrap_policy, load_fsdp_grad, offload_fsdp_grad, init_fn, get_init_weight_context_manager -from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, load_fsdp_param_and_grad +from verl.utils.fsdp_utils import get_fsdp_wrap_policy, offload_fsdp_grad, init_fn, get_init_weight_context_manager +from verl.utils.fsdp_utils import offload_fsdp_optimizer, offload_fsdp_param_and_grad, load_fsdp_optimizer, \ + load_fsdp_param_and_grad from verl.utils.import_utils import import_external_libs -from verl.utils.debug import log_gpu_memory_usage -import verl.utils.hdfs_io as hdfs_io -from verl.utils import hf_tokenizer +from verl.utils.model import compute_position_id_with_mask logger = logging.getLogger(__file__) logger.setLevel(os.getenv('VERL_PPO_LOGGING_LEVEL', 'WARN')) @@ -95,9 +95,8 @@ def _build_model_optimizer(self, trust_remote_code=False): from verl.utils.model import print_model_size, update_model_config from verl.utils.torch_dtypes import PrecisionType - from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \ - CPUOffload + from transformers import AutoModelForCausalLM, AutoConfig + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision from torch import optim log_gpu_memory_usage('Before init from HF AutoModel', logger=logger) @@ -322,7 +321,7 @@ def update_actor(self, data: DataProto): self.actor_lr_scheduler.step() lr = self.actor_lr_scheduler.get_last_lr()[0] - metrics['actor/lr(1e-4)'] = lr * 1e4 + metrics['actor/lr'] = lr log_gpu_memory_usage('After update policy', logger=logger) @@ -453,15 +452,13 @@ def _build_critic_model_optimizer(self, config): # the following line is necessary from 
verl.utils.model import LambdaLayer, print_model_size, squeeze from verl.utils.torch_dtypes import PrecisionType - from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision, \ - CPUOffload + from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, MixedPrecision from torch import optim local_path = copy_local_path_from_hdfs(config.model.path) # note that the tokenizer between actor and critic may be different. So override tokenizer info with actor info # using random initialized model from any architecture. May not be the same as Actor. # TODO: support loading critic weights from RM. Support using AutoModelForTokenClassification - from transformers import AutoTokenizer tokenizer_path = copy_local_path_from_hdfs(config.model.tokenizer_path) self.tokenizer = hf_tokenizer(tokenizer_path, trust_remote_code=config.model.get('trust_remote_code', False)) @@ -600,7 +597,7 @@ def update_critic(self, data: DataProto): self.critic_lr_scheduler.step() lr = self.critic_lr_scheduler.get_last_lr()[0] - metrics['critic/lr(1e-4)'] = lr * 1e4 + metrics['critic/lr'] = lr output = DataProto(batch=None, meta_info={'metrics': metrics}) if self._is_offload_param: @@ -656,7 +653,7 @@ def __init__(self, config): def _build_model(self, config): # the following line is necessary - from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig + from transformers import AutoModelForSequenceClassification, AutoConfig from torch.distributed.fsdp import FullyShardedDataParallel as FSDP, ShardingStrategy, CPUOffload # download the checkpoint from hdfs From d9c3fef4511d47ea46ed7f38fdfd340f35cb853e Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 16:46:43 +0800 Subject: [PATCH 13/17] fmt --- verl/utils/tracking.py | 1 + 1 file changed, 1 insertion(+) diff --git a/verl/utils/tracking.py b/verl/utils/tracking.py index 57f5fb44..809a8a2b 100644 --- a/verl/utils/tracking.py +++ b/verl/utils/tracking.py @@ -59,6 +59,7 @@ def log(self, data, step, backend=None): class _MlflowLoggingAdapter: + def log(self, data, step): import mlflow mlflow.log_metrics(metrics=data, step=step) From 553e310a868b4d695011eee0c032fbe215f64725 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 17:00:10 +0800 Subject: [PATCH 14/17] more --- verl/trainer/config/ppo_trainer.yaml | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 22835cca..2c7dcc5e 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -15,7 +15,7 @@ actor_rollout_ref: model: path: ~/models/deepseek-llm-7b-chat external_lib: null - override_config: {} + override_config: { } enable_gradient_checkpointing: False actor: strategy: fsdp # This is for backward-compatibility @@ -78,7 +78,7 @@ critic: model: path: ~/models/deepseek-llm-7b-chat tokenizer_path: ${actor_rollout_ref.model.path} - override_config: {} + override_config: { } external_lib: ${actor_rollout_ref.model.external_lib} enable_gradient_checkpointing: False fsdp_config: @@ -90,6 +90,7 @@ critic: min_num_params: 0 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size: 64 + log_prob_micro_batch_size: ${critic.ppo_micro_batch_size} ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 @@ -121,7 +122,7 @@ trainer: total_epochs: 30 project_name: verl_examples experiment_name: gsm8k - 
logger: ['console', 'wandb'] + logger: [ 'console', 'wandb' ] nnodes: 1 n_gpus_per_node: 8 save_freq: -1 From 8505d861a46f2dc56953d5eb7977460e143162fa Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 17:00:55 +0800 Subject: [PATCH 15/17] more --- verl/trainer/config/ppo_trainer.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/verl/trainer/config/ppo_trainer.yaml b/verl/trainer/config/ppo_trainer.yaml index 2c7dcc5e..29da7499 100644 --- a/verl/trainer/config/ppo_trainer.yaml +++ b/verl/trainer/config/ppo_trainer.yaml @@ -90,7 +90,7 @@ critic: min_num_params: 0 ppo_mini_batch_size: ${actor_rollout_ref.actor.ppo_mini_batch_size} ppo_micro_batch_size: 64 - log_prob_micro_batch_size: ${critic.ppo_micro_batch_size} + forward_micro_batch_size: ${critic.ppo_micro_batch_size} ppo_epochs: ${actor_rollout_ref.actor.ppo_epochs} shuffle: ${actor_rollout_ref.actor.shuffle} grad_clip: 1.0 From 2ea1e8a36a2bf0de00001afdb2b6e2c55cbfca92 Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 17:02:34 +0800 Subject: [PATCH 16/17] more --- verl/workers/fsdp_workers.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/verl/workers/fsdp_workers.py b/verl/workers/fsdp_workers.py index c884f5e0..0328a15e 100644 --- a/verl/workers/fsdp_workers.py +++ b/verl/workers/fsdp_workers.py @@ -447,6 +447,7 @@ def __init__(self, config): # normalize config self.config.ppo_mini_batch_size //= torch.distributed.get_world_size() self.config.ppo_micro_batch_size //= torch.distributed.get_world_size() + self.config.forward_micro_batch_size //= torch.distributed.get_world_size() def _build_critic_model_optimizer(self, config): # the following line is necessary @@ -574,7 +575,7 @@ def compute_values(self, data: DataProto): load_fsdp_param_and_grad(module=self.critic_module, device_id=torch.cuda.current_device(), load_grad=self._is_offload_grad) - micro_batch_size = self.config.ppo_micro_batch_size + micro_batch_size = self.config.forward_micro_batch_size data.meta_info['micro_batch_size'] = micro_batch_size values = self.critic.compute_values(data=data) output = DataProto.from_dict(tensors={'values': values}) From 0176a64e559d841046ae2f92c3391d285d2d79af Mon Sep 17 00:00:00 2001 From: fzyzcjy Date: Thu, 2 Jan 2025 20:06:07 +0800 Subject: [PATCH 17/17] mv --- verl/trainer/ppo/ray_trainer.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 12c3ed27..37d65956 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -225,6 +225,13 @@ def compute_timing_metrics(batch, timing_raw): } +@contextmanager +def _timer(name: str, timing_raw: Dict[str, float]): + with Timer(name=name, logger=None) as timer: + yield + timing_raw[name] = timer.last + + class RayPPOTrainer(object): """ Note that this trainer runs on the driver process on a single CPU/GPU node. @@ -567,10 +574,3 @@ def fit(self): val_metrics = self._validate() pprint(f'Final validation metrics: {val_metrics}') logger.log(data=val_metrics, step=global_steps) - - -@contextmanager -def _timer(name: str, timing_raw: Dict[str, float]): - with Timer(name=name, logger=None) as timer: - yield - timing_raw[name] = timer.last