From dcabc515d8f1d1139bc0f7efe9636c2c56f6fcd2 Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Thu, 17 Nov 2022 02:04:48 +0200
Subject: [PATCH 1/2] fix(ppo): optional reward scaling and minibatch advantage whitening

---
 configs/ppo_config.yml                |  2 +-
 trlx/orchestrator/ppo_orchestrator.py | 27 ++++++++--
 trlx/utils/modeling.py                | 75 +++++++++++++++++++++++++--
 3 files changed, 95 insertions(+), 9 deletions(-)

diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml
index 38a36cddc..d6ee43c70 100644
--- a/configs/ppo_config.yml
+++ b/configs/ppo_config.yml
@@ -6,7 +6,7 @@ model:

 train:
   seq_length: 48 # Size of LM context
-  epochs: 1000 # Train for max(epochs, total_steps)
+  epochs: 100 # Train for max(epochs, total_steps)
   total_steps: 10000 # Train for max(epochs, total_steps)
   batch_size: 128 # batch size

diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py
index 7f6096aae..b097d95d3 100644
--- a/trlx/orchestrator/ppo_orchestrator.py
+++ b/trlx/orchestrator/ppo_orchestrator.py
@@ -8,8 +8,9 @@
 from trlx.orchestrator import Orchestrator, register_orchestrator
 from trlx.pipeline import BasePipeline
 from trlx.utils import Clock
-from trlx.utils.modeling import logprobs_from_logits
+from trlx.utils.modeling import logprobs_from_logits, RunningMoments

+from time import time
 import ray
 from ray.air import session

@@ -45,6 +46,10 @@ def __init__(
         self.rl_model.reward_fn = reward_fn
         self.rl_model.metric_fn = metric_fn

+        self.running = RunningMoments()
+        self.ref_mean = None
+        self.ref_std = None
+
     def score(self, samples):
         """
         Batched scoring function taking text and generating scalar
@@ -66,15 +71,28 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
                 self.pipeline_iterator = iter(self.pipeline_loader)
                 batch = next(self.pipeline_iterator)

+            exp_generate_time = time()
             samples = self.rl_model.generate(**batch)
+            stats["exp_generate_time"] = time() - exp_generate_time

             query_tensors = batch.input_ids
             response_tensors = samples[:, query_tensors.shape[1] :]
             texts = self.rl_model.tokenizer.batch_decode(
                 samples, skip_special_tokens=True
             )
-            scores = torch.as_tensor(self.score(texts))
-
+            exp_score_time = time()
+            scores = torch.as_tensor(self.score(texts), device=samples.device)
+            stats["exp_score_time"] = time() - exp_score_time
+
+            if self.ref_mean is None:
+                self.ref_mean, self.ref_std = scores.mean(), scores.std()
+            all_scores_mean, all_scores_std = self.running.update(scores)
+            scores /= self.running.std
+
+            stats["exp_scores_mean"] = all_scores_mean
+            stats["exp_scores_std"] = all_scores_std
+            stats["running_mean"] = self.running.mean
+            stats["running_std"] = self.running.std
             # Precompute logprobs, values
             all_tokens = torch.cat(
                 (query_tensors.to(samples.device), response_tensors), dim=1
@@ -126,7 +144,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
             ]
             ppo_rl_elements += new_ppo_rl_elements

-        stats = {"exp_time": exp_time}
+        stats["kl_ctl_value"] = self.rl_model.kl_ctl.value
+        stats["exp_time"] = exp_time

         if not ray.is_initialized():
             self.rl_model.accelerator.log(stats, step=iter_count)

diff --git a/trlx/utils/modeling.py b/trlx/utils/modeling.py
index 2751edd67..3fe30c330 100644
--- a/trlx/utils/modeling.py
+++ b/trlx/utils/modeling.py
@@ -2,12 +2,33 @@

 import torch
 import torch.nn.functional as F
+import torch.distributed as dist
+from typing import Tuple


-def whiten(values, shift_mean=True):
-    """Whiten values."""
-    mean, var = torch.mean(values), torch.var(values)
-    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
+def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
+    """
+    Computes element-wise mean and variance of the tensor across processes
+    """
+    sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
+    dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)
+    global_sum, count = sum_and_count
+    global_mean = global_sum / count
+
+    sum_var = torch.sum((xs - global_mean) ** 2)
+    dist.all_reduce(sum_var, dist.ReduceOp.SUM)
+    global_var = sum_var / count
+    return global_mean, global_var, count
+
+
+def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
+    """Whitens values"""
+    if distributed and dist.is_initialized():
+        mean, var, _ = get_global_statistics(xs)
+    else:
+        var, mean = torch.var_mean(xs)
+
+    whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
     if not shift_mean:
         whitened += mean
     return whitened
@@ -34,3 +55,49 @@ def flatten_dict(
         else:
             items.append((new_key, v))
     return dict(items)
+
+
+def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int):
+    mean = (xs * mask).sum() / n
+    stats.update(
+        {
+            f"{name}/mean": mean,
+            f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(),
+            f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(),
+            f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
+        }
+    )
+
+
+class RunningMoments:
+    def __init__(self):
+        """
+        Calculates the running mean and standard deviation of a data stream. Modified version of
+        https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py
+        """
+        self.mean = 0
+        self.std = 1
+        self.var = 1
+        self.count = 1e-24
+
+    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
+        """Updates running moments from batch's moments computed across ranks"""
+        if dist.is_initialized():
+            xs_mean, xs_var, xs_count = get_global_statistics(xs)
+        else:
+            xs_count = xs.numel()
+            xs_var, xs_mean = torch.var_mean(xs, unbiased=False)
+
+        delta = xs_mean - self.mean
+        tot_count = self.count + xs_count
+
+        m_a = self.var * self.count
+        m_b = xs_var * xs_count
+        m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count
+
+        self.mean += delta * xs_count / tot_count
+        self.var = m_2 / tot_count
+        self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
+        self.count = tot_count
+
+        return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()

From 267c9d44980e3f151a035636fefc18cdb02848d3 Mon Sep 17 00:00:00 2001
From: reciprocated <56548574+reciprocated@users.noreply.github.com>
Date: Thu, 17 Nov 2022 14:43:15 +0200
Subject: [PATCH 2/2] feat(ppo): add optional reward clipping

---
 configs/ppo_config.yml                | 2 ++
 trlx/model/nn/ppo_models.py           | 2 ++
 trlx/orchestrator/ppo_orchestrator.py | 9 ++++++++-
 3 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml
index d6ee43c70..c526f564e 100644
--- a/configs/ppo_config.yml
+++ b/configs/ppo_config.yml
@@ -35,6 +35,8 @@ method:
   cliprange: 0.2 # clip range
   cliprange_value: 0.2 # clip range
   vf_coef: 2.3 # value term weight
+  scale_reward: True
+  clip_reward: 10
   gen_kwargs:
     max_length: 48 # LM max sample gen length
     min_length: 48 # LM min sample gen length

diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py
index 7c6158e3e..2e4129ea6 100644
--- a/trlx/model/nn/ppo_models.py
+++ b/trlx/model/nn/ppo_models.py
@@ -111,6 +111,8 @@ class PPOConfig(MethodConfig):
     cliprange: float
     cliprange_value: float
     vf_coef: float
+    scale_reward: bool
+    clip_reward: float
     gen_kwargs: dict

     def get_advantages_and_returns(

diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py
index b097d95d3..291e09e02 100644
--- a/trlx/orchestrator/ppo_orchestrator.py
+++ b/trlx/orchestrator/ppo_orchestrator.py
@@ -87,12 +87,19 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
             if self.ref_mean is None:
                 self.ref_mean, self.ref_std = scores.mean(), scores.std()
             all_scores_mean, all_scores_std = self.running.update(scores)
-            scores /= self.running.std

             stats["exp_scores_mean"] = all_scores_mean
             stats["exp_scores_std"] = all_scores_std
             stats["running_mean"] = self.running.mean
             stats["running_std"] = self.running.std
+
+            if self.rl_model.config.method.scale_reward:
+                scores /= self.running.std
+
+            clip_reward = self.rl_model.config.method.clip_reward
+            if clip_reward:
+                scores = torch.clip(scores, -clip_reward, clip_reward)
+
             # Precompute logprobs, values
             all_tokens = torch.cat(
                 (query_tensors.to(samples.device), response_tensors), dim=1
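
Note (not part of the patch series): with the change to trlx/utils/modeling.py above, whiten() uses cross-process statistics only when distributed=True and torch.distributed has an initialized process group; otherwise it falls back to a local torch.var_mean. A minimal single-process sanity check, assuming the patched module is importable:

import torch
from trlx.utils.modeling import whiten

advantages = torch.randn(8, 16) * 5 + 2   # arbitrary, non-normalized values
w = whiten(advantages)                    # no process group -> local torch.var_mean path
print(w.mean().item(), w.std().item())    # approximately 0 and 1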
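
RunningMoments.update() folds each batch's moments into the running moments with the parallel-variance merge (the same delta / m_2 update as in the diff). Below is a small standalone check of that merge rule against the moments of the concatenated data; merge_moments is a hypothetical helper written only for this check, not part of the patch:

import torch

def merge_moments(mean_a, var_a, count_a, mean_b, var_b, count_b):
    # Combine (mean, population variance, count) of two batches, mirroring RunningMoments.update
    delta = mean_b - mean_a
    tot = count_a + count_b
    m_2 = var_a * count_a + var_b * count_b + delta**2 * count_a * count_b / tot
    return mean_a + delta * count_b / tot, m_2 / tot, tot

a, b = torch.randn(100), torch.randn(50) + 3.0
mean_ab, var_ab, _ = merge_moments(
    a.mean(), a.var(unbiased=False), a.numel(),
    b.mean(), b.var(unbiased=False), b.numel(),
)
ref_var, ref_mean = torch.var_mean(torch.cat([a, b]), unbiased=False)
assert torch.allclose(mean_ab, ref_mean, atol=1e-5)
assert torch.allclose(var_ab, ref_var, atol=1e-5)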
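
The two new config fields combine in make_experience as follows: scale_reward divides the scores by the running standard deviation, and a non-zero clip_reward then clamps the result. A condensed sketch of that flow; postprocess_scores is a hypothetical name and `running` stands for the orchestrator's RunningMoments instance:

import torch

def postprocess_scores(scores: torch.Tensor, running, scale_reward: bool, clip_reward: float) -> torch.Tensor:
    running.update(scores)                     # running statistics are always tracked
    if scale_reward:                           # scale_reward: True in ppo_config.yml
        scores = scores / running.std
    if clip_reward:                            # clip_reward: 10 in ppo_config.yml
        scores = torch.clip(scores, -clip_reward, clip_reward)
    return scores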