Update ppo value head + print logs #11

Merged
merged 10 commits into from
Oct 6, 2022
52 changes: 0 additions & 52 deletions configs/default_config.yml

This file was deleted.

82 changes: 41 additions & 41 deletions configs/ppo_config.yml
@@ -1,52 +1,52 @@
model:
-model_path : "lvwerra/gpt2-imdb"
-tokenizer_path : "gpt2"
-model_type : "AcceleratePPOModel"
-device : "cuda"
-num_layers_unfrozen : -1
+model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
+tokenizer_path : "gpt2" # Name of hf tokenizer to load
+model_type : "AcceleratePPOModel" # Name of accelerate model type to load
+device : "cuda" # Train device
+num_layers_unfrozen : -1 # Number of bottom layers to freeze during training

train:
-n_ctx : 512
-epochs : 10
-total_steps : 80000
-batch_size : 16
-grad_clip : 1.0
+n_ctx : 512 # Size of LM context
+epochs : 10 # Train for max(epochs, total_steps)
+total_steps : 80000 # Train for max(epochs, total_steps)
+batch_size : 128 # batch size
+grad_clip : 1.0 # gradient clipping threshold

-lr_ramp_steps : 100
-lr_decay_steps : 10000000
-weight_decay : 1.0e-6
-learning_rate_init : 1.412e-5
-learning_rate_target : 1.412e-5
+lr_ramp_steps : 100 # learning rate warm up
+lr_decay_steps : 79000 # learning rate decay
+weight_decay : 1.0e-6 # weight decay param
+learning_rate_init : 1.412e-4 # init learning rate
+learning_rate_target : 1.412e-4 # target final learning rate

-log_interval : 25
-checkpoint_interval : 1000000
-eval_interval : 100
+log_interval : 25 # log interval
+checkpoint_interval : 1000000 # checkpoint interval
+eval_interval : 16 # eval interval

pipeline : "PPOPipeline"
orchestrator : "PPOSentimentOrchestrator"
pipeline : "PPOPipeline" # prompt pipeline to load
orchestrator : "PPOOrchestrator" # orchestrator to load

-input_size : 8
-gen_size : 16
+input_size : 4 # max input size
+gen_size : 48 # max gen size

-accelerate : True
-accelerate_config_path : '/fsx/alex/.cache/huggingface/accelerate/default_config.yaml'
+accelerate : True # Use accelerate
+accelerate_config_path : "" # Path to accelerate config(for logging purposes)

method:
-name : 'ppoconfig'
-num_rollouts : 16
-chunk_size : 16
-ppo_epochs : 4
-init_kl_coef : 0.2
-target : 6
-horizon : 10000
-gamma : 1
-lam : 0.95
-cliprange : 0.2
-cliprange_value : 0.2
-vf_coef : 0.2
+name : 'ppoconfig' # Name of RL method config
+num_rollouts : 128 # Number of rollouts to collect per epoch
+chunk_size : 128 # Number of rollouts to collect in one loop of orchestrator
+ppo_epochs : 4 # Number of ppo epochs
+init_kl_coef : 0.2 # init kl coefficient
+target : 6 # target kl coefficient
+horizon : 10000 # PPO horizon
+gamma : 1 # PPO discount
+lam : 0.95 # PPO lambda
+cliprange : 0.2 # clip range
+cliprange_value : 0.2 # clip range
+vf_coef : 0.2 # value term weight
gen_kwargs :
-max_length : 24
-min_length : 24
-top_k : 0.0
-top_p : 1.0
-do_sample : True
+max_length : 48 # LM max sample gen length
+min_length : 48 # LM min sample gen length
+top_k : 0.0 # top k
+top_p : 1.0 # top p
+do_sample : True # sample
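For orientation, a minimal sketch (not part of this PR) of how a config like the one above is consumed, following the TRLConfig.load_yaml call used in examples/ppo_sentiments.py below; the exact attribute names under cfg.train and cfg.method are assumptions based on the YAML keys, not a documented API.

from trlx.data.configs import TRLConfig

# Load the PPO config shown above (same path as in examples/ppo_sentiments.py).
cfg = TRLConfig.load_yaml("configs/ppo_config.yml")

# Assumed attribute access mirroring the YAML sections above.
print(cfg.train.batch_size)      # 128 after this change
print(cfg.method.num_rollouts)   # 128 rollouts collected per epoch
print(cfg.method.gen_kwargs)     # sampling arguments used when generating responses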
7 changes: 4 additions & 3 deletions examples/ppo_sentiments.py
@@ -11,10 +11,10 @@
from trlx.utils.loading import get_model, get_orchestrator, get_pipeline

if __name__ == "__main__":
cfg = TRLConfig.load_yaml("configs/default_config.yml")
cfg = TRLConfig.load_yaml("configs/ppo_config.yml")

sentiment_pipe = pipeline(
"sentiment-analysis", "lvwerra/distilbert-imdb", device=torch.device(0)
"sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
)

def reward_fn(samples: List[str]):
@@ -28,7 +28,8 @@ def reward_fn(samples: List[str]):
return scores

model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
-wandb.watch(model.model)
+if model.accelerator.is_main_process:
+wandb.watch(model.model)

pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
17 changes: 13 additions & 4 deletions trlx/model/accelerate_base_model.py
@@ -18,6 +18,9 @@
from trlx.pipeline.accelerate_base_pipeline import AccelerateRolloutStorage
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask

+WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
+LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))


@register_model
class AccelerateRLModel(BaseRLModel):
@@ -39,11 +42,17 @@ def __init__(self, config, rollout_storage, train_mode=True):
with open(self.config.train.accelerate_config_path, mode="r") as file:
accelerate_config = yaml.safe_load(file)
config_dict.update(accelerate_config)
-# TODO(dahoas): might need to move this
self.accelerator = Accelerator(log_with="wandb")
-self.accelerator.init_trackers(
-self.config.train.project_name, config=config_dict
-)

+if WORLD_SIZE > 1:
+torch.distributed.barrier(device_ids=[LOCAL_RANK])
+else:
+torch.random.manual_seed(1000)
+if self.accelerator.is_main_process:
+self.accelerator.init_trackers(
+project_name=self.config.train.project_name, config=config_dict
+)

self.opt = torch.optim.AdamW(
self.model.parameters(), lr=self.config.train.learning_rate_init
)
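A minimal standalone sketch (not part of this PR) of the rank-aware initialization pattern introduced above: synchronize ranks when running distributed, and let only the main process create the wandb trackers. The project name and config dict below are placeholders; the calls themselves mirror those in the diff.

import os

import torch
from accelerate import Accelerator

WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))

accelerator = Accelerator(log_with="wandb")

if WORLD_SIZE > 1:
    # Wait for all ranks so no process races ahead during setup.
    torch.distributed.barrier(device_ids=[LOCAL_RANK])
else:
    torch.random.manual_seed(1000)

if accelerator.is_main_process:
    # Only rank 0 initializes the wandb run, avoiding duplicate experiments.
    accelerator.init_trackers(project_name="example_project", config={"seed": 1000})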
11 changes: 10 additions & 1 deletion trlx/model/accelerate_ilql_model.py
@@ -157,7 +157,16 @@ def learn(self):

if opt_steps % self.config.train.eval_interval == 0:
logs.update(stats)
-self.accelerator.log(logs)
+if self.accelerator.is_main_process:
+self.accelerator.log(logs)
+self.accelerator.print(
+"Step: {}, loss_cql: {}, loss_v: {}, reward: {}".format(
+opt_steps,
+logs["loss_cql"],
+logs["loss_v"],
+logs["reward"],
+)
+)

self.accelerator.backward(loss)
self.opt.step()
53 changes: 44 additions & 9 deletions trlx/model/accelerate_ppo_model.py
@@ -10,11 +10,12 @@
from tqdm import tqdm
from transformers import AutoConfig, AutoTokenizer

+import wandb
from trlx.data.accelerate_base_datatypes import PromptBatch
from trlx.data.configs import TRLConfig
from trlx.model import BaseRLModel, register_model
from trlx.model.accelerate_base_model import AccelerateRLModel
-from trlx.model.nn.ppo_models import GPT2HeadWithValueModel
+from trlx.model.nn.ppo_models import GPTHeadWithValueModel
from trlx.pipeline.ppo_pipeline import PPORolloutStorage
from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten
@@ -27,8 +28,8 @@ def __init__(self, config, train_mode=True):
super().__init__(config, self.store)

def get_arch(self, config: TRLConfig):
-# TODO(dahoas): Assumes model is gpt2 based
-return GPT2HeadWithValueModel.from_pretrained(self.config.model.model_path)
+# TODO(dahoas): Assumes model is gpt like
+return GPTHeadWithValueModel(self.config.model.model_path)

def loss(
self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
@@ -82,7 +83,7 @@ def loss(
pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))

model_loss = pg_loss + self.config.method.vf_coef * vf_loss
-return model_loss
+return model_loss, pg_loss, vf_loss

def post_epoch_callback(self):
# TODO(dahoas): are experiences being made for dataloaders on each process or same dataloader
@@ -92,8 +93,34 @@
self.config.method.num_rollouts, self.iter_count
) # Collect more rollouts for training

-def post_backward_callback(self, batch, rewards):
-pass
+def post_backward_callback(self):
+batch = self.logs["batch"]
+if self.accelerator.is_main_process:
+if (
+self.iter_count % self.config.train.eval_interval == 0
+or self.iter_count <= self.config.method.ppo_epochs
+):
+text = self.tokenizer.batch_decode(batch.query_tensors)
+eval_batch: PromptBatch = PromptBatch(
+text=text, tokens=batch.query_tensors
+)
+query_tensors, response_tensors, response_text = self.act(eval_batch)
+gen_texts = [q + r for q, r in zip(eval_batch.text, response_text)]
+scores = self.orch.score(gen_texts)
+mean_score = torch.mean(scores).item()
+rows = list(zip(gen_texts, scores.tolist()))
+stats = {
+"mean_score": mean_score,
+"responses": wandb.Table(columns=["response", "score"], rows=rows),
+"pg_loss": self.logs["pg_loss"],
+"vf_loss": self.logs["vf_loss"],
+}
+self.accelerator.log(stats, step=self.iter_count)
+self.accelerator.print(
+"Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}".format(
+self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"]
+)
+)

def learn(self, log_fn=None, save_fn=None, eval_fn=None):

@@ -117,17 +144,25 @@ def learn(self, log_fn=None, save_fn=None, eval_fn=None):
rewards = batch.rewards.to(self.accelerator.device)

for _ in range(self.config.method.ppo_epochs):
-loss = self.loss(
+loss, pg_loss, vf_loss = self.loss(
query_tensors, response_tensors, logprobs, values, rewards
)
+self.logs = {
+"loss": loss,
+"pg_loss": pg_loss,
+"vf_loss": vf_loss,
+"batch": batch,
+"rewards": rewards,
+}
+# self.post_backward_callback()
+# exit()
self.opt.zero_grad()
self.accelerator.backward(loss)
self.opt.step()
self.scheduler.step()
self.iter_count += 1

-self.post_backward_callback(batch, rewards)

+self.post_backward_callback()
self.accelerator.wait_for_everyone()

self.post_epoch_callback()
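For reference, a minimal self-contained sketch (not part of this PR) of the clipped PPO objective that loss() computes, illustrating why the policy and value terms are now returned separately as (model_loss, pg_loss, vf_loss) so they can be printed and logged. This is a generic PPO formulation with assumed tensor inputs, not trlx's exact implementation.

import torch

def ppo_loss(logprobs, old_logprobs, advantages, values, old_values, returns,
             cliprange=0.2, cliprange_value=0.2, vf_coef=0.2):
    # Clipped policy-gradient (surrogate) loss.
    ratio = torch.exp(logprobs - old_logprobs)
    pg_losses = -advantages * ratio
    pg_losses2 = -advantages * torch.clamp(ratio, 1.0 - cliprange, 1.0 + cliprange)
    pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))

    # Clipped value-function loss.
    values_clipped = old_values + torch.clamp(values - old_values, -cliprange_value, cliprange_value)
    vf_losses = (values - returns) ** 2
    vf_losses2 = (values_clipped - returns) ** 2
    vf_loss = 0.5 * torch.mean(torch.max(vf_losses, vf_losses2))

    # Combined objective, mirroring model_loss = pg_loss + vf_coef * vf_loss above.
    model_loss = pg_loss + vf_coef * vf_loss
    return model_loss, pg_loss, vf_loss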