Add optional reward scaling #95

Merged 2 commits on Nov 17, 2022
4 changes: 3 additions & 1 deletion configs/ppo_config.yml
@@ -6,7 +6,7 @@ model:

train:
seq_length: 48 # Size of LM context
epochs: 1000 # Train for max(epochs, total_steps)
Contributor:
Thank god we finally decreased this haha

epochs: 100 # Train for max(epochs, total_steps)
total_steps: 10000 # Train for max(epochs, total_steps)
batch_size: 128 # batch size

@@ -35,6 +35,8 @@ method:
cliprange: 0.2 # clip range
cliprange_value: 0.2 # clip range
vf_coef: 2.3 # value term weight
scale_reward: True
clip_reward: 10
Collaborator:
Shouldn't reward generally be in the range [-1,1]?

gen_kwargs:
max_length: 48 # LM max sample gen length
min_length: 48 # LM min sample gen length
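For context on the two new keys: when scale_reward is enabled, the orchestrator divides each batch of reward scores by a running standard deviation, and a non-zero clip_reward then clips the scaled scores to a symmetric range (see the ppo_orchestrator.py changes below). A minimal standalone sketch of that logic; the helper name and arguments are hypothetical, only the behaviour mirrors the diff:

```python
import torch

def postprocess_rewards(scores: torch.Tensor, running_std: torch.Tensor,
                        scale_reward: bool, clip_reward: float) -> torch.Tensor:
    # Divide by the running std of all scores seen so far to keep the reward scale stable.
    if scale_reward:
        scores = scores / running_std
    # A falsy clip_reward (0 or None) disables clipping.
    if clip_reward:
        scores = torch.clip(scores, -clip_reward, clip_reward)
    return scores
```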
2 changes: 2 additions & 0 deletions trlx/model/nn/ppo_models.py
@@ -111,6 +111,8 @@ class PPOConfig(MethodConfig):
cliprange: float
cliprange_value: float
vf_coef: float
scale_reward: bool
clip_reward: float
gen_kwargs: dict

def get_advantages_and_returns(
32 changes: 29 additions & 3 deletions trlx/orchestrator/ppo_orchestrator.py
@@ -8,8 +8,9 @@
from trlx.orchestrator import Orchestrator, register_orchestrator
from trlx.pipeline import BasePipeline
from trlx.utils import Clock
from trlx.utils.modeling import logprobs_from_logits
from trlx.utils.modeling import logprobs_from_logits, RunningMoments

from time import time
import ray
from ray.air import session

@@ -45,6 +46,10 @@ def __init__(
self.rl_model.reward_fn = reward_fn
self.rl_model.metric_fn = metric_fn

self.running = RunningMoments()
self.ref_mean = None
self.ref_std = None

def score(self, samples):
"""
Batched scoring function taking text and generating scalar
@@ -66,14 +71,34 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
self.pipeline_iterator = iter(self.pipeline_loader)
batch = next(self.pipeline_iterator)

exp_generate_time = time()
samples = self.rl_model.generate(**batch)
stats["exp_generate_time"] = time() - exp_generate_time

query_tensors = batch.input_ids
response_tensors = samples[:, query_tensors.shape[1] :]
texts = self.rl_model.tokenizer.batch_decode(
samples, skip_special_tokens=True
)
scores = torch.as_tensor(self.score(texts))
exp_score_time = time()
scores = torch.as_tensor(self.score(texts), device=samples.device)
stats["exp_score_time"] = time() - exp_score_time

if self.ref_mean is None:
Contributor:
Some comments here about what this does would be helpful :)

self.ref_mean, self.ref_std = scores.mean(), scores.std()
Collaborator:
Perhaps naming this ref_mean is a bit misleading? It is not the mean of the reference model but rather the mean of the training model.

all_scores_mean, all_scores_std = self.running.update(scores)

stats["exp_scores_mean"] = all_scores_mean
stats["exp_scores_std"] = all_scores_std
stats["running_mean"] = self.running.mean
stats["running_std"] = self.running.std

if self.rl_model.config.method.scale_reward:
scores /= self.running.std

clip_reward = self.rl_model.config.method.clip_reward
if clip_reward:
scores = torch.clip(scores, -clip_reward, clip_reward)

# Precompute logprobs, values
all_tokens = torch.cat(
@@ -126,7 +151,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
]
ppo_rl_elements += new_ppo_rl_elements

stats = {"exp_time": exp_time}
stats["kl_ctl_value"] = self.rl_model.kl_ctl.value
stats["exp_time"] = exp_time

if not ray.is_initialized():
self.rl_model.accelerator.log(stats, step=iter_count)
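To make the running-statistics flow easier to follow outside the diff view, here is a small single-process usage sketch of the RunningMoments tracker added in trlx/utils/modeling.py below; the loop and the random scores are stand-ins for real rollout batches and reward-model outputs:

```python
import torch
from trlx.utils.modeling import RunningMoments

running = RunningMoments()
for _ in range(3):                      # one update per rollout batch
    scores = torch.randn(128) * 2 + 5   # stand-in for reward-model scores
    batch_mean, batch_std = running.update(scores)
    scaled = scores / running.std       # what scale_reward applies before clipping
```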
75 changes: 71 additions & 4 deletions trlx/utils/modeling.py
@@ -2,12 +2,33 @@

import torch
import torch.nn.functional as F
import torch.distributed as dist
from typing import Tuple


def whiten(values, shift_mean=True):
"""Whiten values."""
mean, var = torch.mean(values), torch.var(values)
whitened = (values - mean) * torch.rsqrt(var + 1e-8)
def get_global_statistics(xs: torch.Tensor) -> Tuple[float, float, int]:
"""
Computes element-wise mean and variance of the tensor across processes
"""
sum_and_count = torch.tensor([xs.sum(), xs.numel()], device=xs.device)
dist.all_reduce(sum_and_count, dist.ReduceOp.SUM)
global_sum, count = sum_and_count
global_mean = global_sum / count

sum_var = torch.sum((xs - global_mean) ** 2)
dist.all_reduce(sum_var, dist.ReduceOp.SUM)
global_var = sum_var / count
return global_mean, global_var, count


def whiten(xs: torch.Tensor, shift_mean=True, distributed=True) -> torch.Tensor:
"""Whitens values"""
Collaborator:
Out of curiosity do we have a reference for whitening? (Some blog post, arxiv paper)

if distributed and dist.is_initialized():
mean, var, _ = get_global_statistics(xs)
else:
var, mean = torch.var_mean(xs)

whitened = (xs - mean) * torch.rsqrt(var + 1e-8)
if not shift_mean:
whitened += mean
return whitened
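On the reference question above: normalizing advantages this way is common in PPO implementations, and the function is plain standardization. A hedged restatement of what the code computes:

$$\operatorname{whiten}(x) = \frac{x - \mu}{\sqrt{\sigma^2 + 10^{-8}}}$$

where $\mu$ and $\sigma^2$ are the mean and variance of $x$, taken per rank via torch.var_mean or pooled across ranks via get_global_statistics when distributed; with shift_mean=False the mean is added back, so only the scale is normalized.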
@@ -34,3 +55,49 @@ def flatten_dict(
else:
items.append((new_key, v))
return dict(items)


def log_stat(stats: dict, name: str, xs: torch.Tensor, mask: torch.Tensor, n: int):
mean = (xs * mask).sum() / n
stats.update(
{
f"{name}/mean": mean,
f"{name}/min": torch.where(mask.bool(), xs, np.inf).min(),
f"{name}/max": torch.where(mask.bool(), xs, -np.inf).max(),
f"{name}/std": torch.sqrt(((xs - mean) * mask).pow(2).sum() / n),
}
)


class RunningMoments:
Collaborator:
Very nice

def __init__(self):
Collaborator:
If we precompute the mean and var of our initial reward distribution ahead of time do we have a way of incorporating that?

"""
Calculates the running mean and standard deviation of a data stream. Modified version of
https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py
"""
self.mean = 0
self.std = 1
self.var = 1
self.count = 1e-24

def update(self, xs: torch.Tensor) -> Tuple[float, float]:
"""Updates running moments from batch's moments computed across ranks"""
if dist.is_initialized():
xs_mean, xs_var, xs_count = get_global_statistics(xs)
else:
xs_count = xs.numel()
xs_var, xs_mean = torch.var_mean(xs, unbiased=False)

delta = xs_mean - self.mean
tot_count = self.count + xs_count

m_a = self.var * self.count
Contributor:
I hate having lots of math and no intuitive explanation of what the math is doing as comments. Please fix.

m_b = xs_var * xs_count
m_2 = m_a + m_b + delta**2 * self.count * xs_count / tot_count

self.mean += delta * xs_count / tot_count
self.var = m_2 / tot_count
self.std = (self.var * tot_count / (tot_count - 1)).sqrt()
self.count = tot_count

return xs_mean, (xs_var * xs_count / (xs_count - 1)).sqrt()
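Regarding the request above for an intuitive explanation: update combines the running moments A (self.mean, self.var, self.count) with the incoming batch moments B using the standard parallel mean/variance merge (often attributed to Chan et al.). In that notation the code computes

$$\delta = \mu_B - \mu_A,\qquad n = n_A + n_B,\qquad \mu = \mu_A + \delta\,\frac{n_B}{n},\qquad n\,\sigma^2 = n_A\sigma_A^2 + n_B\sigma_B^2 + \delta^2\,\frac{n_A n_B}{n}$$

and the trailing .sqrt() expressions apply Bessel's correction ($n/(n-1)$) before converting variances to standard deviations.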