-
Notifications
You must be signed in to change notification settings - Fork 470
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add optional reward scaling #95
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ model: | |
|
||
train: | ||
seq_length: 48 # Size of LM context | ||
epochs: 1000 # Train for max(epochs, total_steps) | ||
epochs: 100 # Train for max(epochs, total_steps) | ||
total_steps: 10000 # Train for max(epochs, total_steps) | ||
batch_size: 128 # batch size | ||
|
||
|
@@ -35,6 +35,8 @@ method: | |
cliprange: 0.2 # clip range | ||
cliprange_value: 0.2 # clip range | ||
vf_coef: 2.3 # value term weight | ||
scale_reward: True | ||
clip_reward: 10 | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Shouldn't the reward generally be in the range [-1, 1]? |
||
gen_kwargs: | ||
max_length: 48 # LM max sample gen length | ||
min_length: 48 # LM min sample gen length | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -8,8 +8,9 @@ | |
from trlx.orchestrator import Orchestrator, register_orchestrator | ||
from trlx.pipeline import BasePipeline | ||
from trlx.utils import Clock | ||
from trlx.utils.modeling import logprobs_from_logits | ||
from trlx.utils.modeling import logprobs_from_logits, RunningMoments | ||
|
||
from time import time | ||
import ray | ||
from ray.air import session | ||
|
||
|
@@ -45,6 +46,10 @@ def __init__( | |
self.rl_model.reward_fn = reward_fn | ||
self.rl_model.metric_fn = metric_fn | ||
|
||
self.running = RunningMoments() | ||
self.ref_mean = None | ||
self.ref_std = None | ||
|
||
def score(self, samples): | ||
""" | ||
Batched scoring function taking text and generating scalar | ||
|
@@ -66,14 +71,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): | |
self.pipeline_iterator = iter(self.pipeline_loader) | ||
batch = next(self.pipeline_iterator) | ||
|
||
exp_generate_time = time() | ||
samples = self.rl_model.generate(**batch) | ||
stats["exp_generate_time"] = time() - exp_generate_time | ||
|
||
query_tensors = batch.input_ids | ||
response_tensors = samples[:, query_tensors.shape[1] :] | ||
texts = self.rl_model.tokenizer.batch_decode( | ||
samples, skip_special_tokens=True | ||
) | ||
scores = torch.as_tensor(self.score(texts)) | ||
exp_score_time = time() | ||
scores = torch.as_tensor(self.score(texts), device=samples.device) | ||
stats["exp_score_time"] = time() - exp_score_time | ||
|
||
if self.ref_mean is None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Some comments here about what this does would be helpful :) |
||
self.ref_mean, self.ref_std = scores.mean(), scores.std() | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Perhaps naming this ref_mean is a bit misleading? It is not the mean of the reference model but rather the mean of the training model. |
||
all_scores_mean, all_scores_std = self.running.update(scores) | ||
|
||
stats["exp_scores_mean"] = all_scores_mean | ||
stats["exp_scores_std"] = all_scores_std | ||
stats["running_mean"] = self.running.mean | ||
stats["running_std"] = self.running.std | ||
|
||
if self.rl_model.config.method.scale_reward: | ||
scores /= self.running.std | ||
|
||
clip_reward = self.rl_model.config.method.clip_reward | ||
if clip_reward: | ||
scores = torch.clip(scores, -clip_reward, clip_reward) | ||
|
||
# Precompute logprobs, values | ||
all_tokens = torch.cat( | ||
|
@@ -126,7 +151,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0): | |
] | ||
ppo_rl_elements += new_ppo_rl_elements | ||
|
||
stats = {"exp_time": exp_time} | ||
stats["kl_ctl_value"] = self.rl_model.kl_ctl.value | ||
stats["exp_time"] = exp_time | ||
|
||
if not ray.is_initialized(): | ||
self.rl_model.accelerator.log(stats, step=iter_count) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,12 +2,33 @@ | |
|
||
import torch | ||
import torch.nn.functional as F | ||
import torch.distributed as dist | ||
from typing import Tuple | ||
|
||
|
||
def whiten(values, shift_mean=True): | ||
"""Whiten values.""" | ||
mean, var = torch.mean(values), torch.var(values) | ||
whitened = (values - mean) * torch.rsqrt(var + 1e-8) | ||
def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]: | ||
""" | ||
Computes element-wise mean and variance of the tensor across processes | ||
""" | ||
sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device) | ||
dist.all_reduce(sum_and_count, dist.ReduceOp.SUM) | ||
global_sum, count = sum_and_count | ||
global_mean = global_sum / count | ||
|
||
sum_var = torch.sum((xs - global_mean) ** 2) | ||
dist.all_reduce(sum_var, dist.ReduceOp.SUM) | ||
global_var = sum_var / count | ||
return global_mean, global_var, count | ||
|
||
|
||
def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor: | ||
"""Whitens values""" | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Out of curiosity, do we have a reference for whitening? (some blog post or arXiv paper) |
||
if distributed and dist.is_initialized(): | ||
mean, var, _ = get_global_statistics(xs) | ||
else: | ||
var, mean = torch.var_mean(xs) | ||
|
||
whitened = (xs - mean) * torch.rsqrt(var + 1e-8) | ||
if not shift_mean: | ||
whitened += mean | ||
return whitened | ||
|
@@ -34,3 +55,49 @@ def flatten_dict( | |
else: | ||
items.append((new_key, v)) | ||
return dict(items) | ||
|
||
|
||
def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int): | ||
mean = (xs * mask).sum() / n | ||
stats.update( | ||
{ | ||
f"{name}/mean": mean, | ||
f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(), | ||
f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(), | ||
f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n), | ||
} | ||
) | ||
|
||
|
||
class RunningMoments: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Very nice |
||
def __init__(self): | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. If we precompute the mean and var of our initial reward distribution ahead of time, do we have a way of incorporating that? |
||
""" | ||
Calculates the running mean and standard deviation of a data stream. Modified version of | ||
https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py | ||
""" | ||
self.mean = 0 | ||
self.std = 1 | ||
self.var = 1 | ||
self.count = 1e-24 | ||
|
||
def update(self, xs: torch.Tensor) -> Tuple[float, float]: | ||
"""Updates running moments from batch's moments computed across ranks""" | ||
if dist.is_initialized(): | ||
xs_mean, xs_var, xs_count = get_global_statistics(xs) | ||
else: | ||
xs_count = xs.numel() | ||
xs_var, xs_mean = torch.var_mean(xs, unbiased=False) | ||
|
||
delta = xs_mean - self.mean | ||
tot_count = self.count + xs_count | ||
|
||
m_a = self.var * self.count | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. I hate having lots of math and no intuitive explanation, in the comments, of what the math is doing. Please fix. |
||
m_b = xs_var * xs_count | ||
m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count | ||
|
||
self.mean += delta * xs_count / tot_count | ||
self.var = m_2 / tot_count | ||
self.std = (self.var * tot_count / (tot_count - 1)).sqrt() | ||
self.count = tot_count | ||
|
||
return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt() |
There was a problem hiding this comment.
Choose a reason for hiding this comment.
The reason will be displayed to describe this comment to others. Learn more.
Thank god we finally decreased this haha