diff --git a/configs/default_config.yml b/configs/default_config.yml
deleted file mode 100644
index af4a6a3b7..000000000
--- a/configs/default_config.yml
+++ /dev/null
@@ -1,52 +0,0 @@
-model:
-  model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
-  tokenizer_path : "gpt2" # Name of hf tokenizer to load
-  model_type : "AcceleratePPOModel" # Name of accelerate model type to load
-  device : "cuda" # Train device
-  num_layers_unfrozen : -1 # Number of bottom layers to freeze during training
-
-train:
-  n_ctx : 512 # Size of LM context
-  epochs : 10 # Train for max(epochs, total_steps)
-  total_steps : 80000 # Train for max(epochs, total_steps)
-  batch_size : 16 # batch size
-  grad_clip : 1.0 # gradient clipping threshold
-
-  lr_ramp_steps : 100 # learning rate warm up
-  lr_decay_steps : 79000 # learning rate decay
-  weight_decay : 1.0e-6 # weight decay param
-  learning_rate_init : 1.412e-5 # init learning rate
-  learning_rate_target : 1.412e-5 # target final learning rate
-
-  log_interval : 25 # log interval
-  checkpoint_interval : 1000000 # checkpoint interval
-  eval_interval : 100 # eval interval
-
-  pipeline : "PPOPipeline" # prompt pipeline to load
-  orchestrator : "PPOOrchestrator" # orchestrator to load
-
-  input_size : 8 # max input size
-  gen_size : 16 # max gen size
-
-  accelerate : True # Use accelerate
-  accelerate_config_path : "" # Path to accelerate config(for logging purposes)
-
-method:
-  name : 'ppoconfig' # Name of RL method config
-  num_rollouts : 16 # Number of rollouts to collect per epoch
-  chunk_size : 16 # Number of rollouts to collect in one loop of orchestrator
-  ppo_epochs : 4 # Number of ppo epochs
-  init_kl_coef : 0.2 # init kl coefficient
-  target : 6 # target kl coefficient
-  horizon : 10000 # PPO horizon
-  gamma : 1 # PPO discount
-  lam : 0.95 # PPO lambda
-  cliprange : 0.2 # clip range
-  cliprange_value : 0.2 # clip range
-  vf_coef : 0.2 # value term weight
-  gen_kwargs :
-    max_length : 24 # LM max sample gen length
-    min_length : 24 # LM min sample gen length
-    top_k : 0.0 # top k
-    top_p : 1.0 # top p
-    do_sample : True # sample
\ No newline at end of file
diff --git a/configs/ppo_config.yml b/configs/ppo_config.yml
index 8f3e2df36..34ff38c0f 100644
--- a/configs/ppo_config.yml
+++ b/configs/ppo_config.yml
@@ -1,52 +1,52 @@
 model:
-  model_path : "lvwerra/gpt2-imdb"
-  tokenizer_path : "gpt2"
-  model_type : "AcceleratePPOModel"
-  device : "cuda"
-  num_layers_unfrozen : -1
+  model_path : "lvwerra/gpt2-imdb" # Name of hf model to load
+  tokenizer_path : "gpt2" # Name of hf tokenizer to load
+  model_type : "AcceleratePPOModel" # Name of accelerate model type to load
+  device : "cuda" # Train device
+  num_layers_unfrozen : -1 # Number of bottom layers to freeze during training

 train:
-  n_ctx : 512
-  epochs : 10
-  total_steps : 80000
-  batch_size : 16
-  grad_clip : 1.0
+  n_ctx : 512 # Size of LM context
+  epochs : 10 # Train for max(epochs, total_steps)
+  total_steps : 80000 # Train for max(epochs, total_steps)
+  batch_size : 128 # batch size
+  grad_clip : 1.0 # gradient clipping threshold

-  lr_ramp_steps : 100
-  lr_decay_steps : 10000000
-  weight_decay : 1.0e-6
-  learning_rate_init : 1.412e-5
-  learning_rate_target : 1.412e-5
+  lr_ramp_steps : 100 # learning rate warm up
+  lr_decay_steps : 79000 # learning rate decay
+  weight_decay : 1.0e-6 # weight decay param
+  learning_rate_init : 1.412e-4 # init learning rate
+  learning_rate_target : 1.412e-4 # target final learning rate

-  log_interval : 25
-  checkpoint_interval : 1000000
-  eval_interval : 100
+  log_interval : 25 # log interval
+  checkpoint_interval : 1000000 # checkpoint interval
+  eval_interval : 16 # eval interval

-  pipeline : "PPOPipeline"
-  orchestrator : "PPOSentimentOrchestrator"
+  pipeline : "PPOPipeline" # prompt pipeline to load
+  orchestrator : "PPOOrchestrator" # orchestrator to load

-  input_size : 8
-  gen_size : 16
+  input_size : 4 # max input size
+  gen_size : 48 # max gen size

-  accelerate : True
-  accelerate_config_path : '/fsx/alex/.cache/huggingface/accelerate/default_config.yaml'
+  accelerate : True # Use accelerate
+  accelerate_config_path : "" # Path to accelerate config(for logging purposes)

 method:
-  name : 'ppoconfig'
-  num_rollouts : 16
-  chunk_size : 16
-  ppo_epochs : 4
-  init_kl_coef : 0.2
-  target : 6
-  horizon : 10000
-  gamma : 1
-  lam : 0.95
-  cliprange : 0.2
-  cliprange_value : 0.2
-  vf_coef : 0.2
+  name : 'ppoconfig' # Name of RL method config
+  num_rollouts : 128 # Number of rollouts to collect per epoch
+  chunk_size : 128 # Number of rollouts to collect in one loop of orchestrator
+  ppo_epochs : 4 # Number of ppo epochs
+  init_kl_coef : 0.2 # init kl coefficient
+  target : 6 # target kl coefficient
+  horizon : 10000 # PPO horizon
+  gamma : 1 # PPO discount
+  lam : 0.95 # PPO lambda
+  cliprange : 0.2 # clip range
+  cliprange_value : 0.2 # clip range
+  vf_coef : 0.2 # value term weight
   gen_kwargs :
-    max_length : 24
-    min_length : 24
-    top_k : 0.0
-    top_p : 1.0
-    do_sample : True
\ No newline at end of file
+    max_length : 48 # LM max sample gen length
+    min_length : 48 # LM min sample gen length
+    top_k : 0.0 # top k
+    top_p : 1.0 # top p
+    do_sample : True # sample
\ No newline at end of file
diff --git a/examples/ppo_sentiments.py b/examples/ppo_sentiments.py
index 50a0a55a4..ecf2805e1 100644
--- a/examples/ppo_sentiments.py
+++ b/examples/ppo_sentiments.py
@@ -11,10 +11,10 @@ from trlx.utils.loading import get_model, get_orchestrator, get_pipeline

 if __name__ == "__main__":
-    cfg = TRLConfig.load_yaml("configs/default_config.yml")
+    cfg = TRLConfig.load_yaml("configs/ppo_config.yml")

     sentiment_pipe = pipeline(
-        "sentiment-analysis", "lvwerra/distilbert-imdb", device=torch.device(0)
+        "sentiment-analysis", "lvwerra/distilbert-imdb", device=-1
     )

     def reward_fn(samples: List[str]):
@@ -28,7 +28,8 @@ def reward_fn(samples: List[str]):
         return scores

     model: AcceleratePPOModel = get_model(cfg.model.model_type)(cfg)
-    wandb.watch(model.model)
+    if model.accelerator.is_main_process:
+        wandb.watch(model.model)

     pipeline: PPOPipeline = get_pipeline(cfg.train.pipeline)(model.tokenizer, cfg)
     orch: PPOOrchestrator = get_orchestrator(cfg.train.orchestrator)(
diff --git a/trlx/model/accelerate_base_model.py b/trlx/model/accelerate_base_model.py
index 3e5f95656..42c14a1ad 100644
--- a/trlx/model/accelerate_base_model.py
+++ b/trlx/model/accelerate_base_model.py
@@ -18,6 +18,9 @@ from trlx.pipeline.accelerate_base_pipeline import AccelerateRolloutStorage
 from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask

+WORLD_SIZE = int(os.environ.get("WORLD_SIZE", 1))
+LOCAL_RANK = int(os.environ.get("LOCAL_RANK", 0))
+

 @register_model
 class AccelerateRLModel(BaseRLModel):
@@ -39,11 +42,17 @@ def __init__(self, config, rollout_storage, train_mode=True):
             with open(self.config.train.accelerate_config_path, mode="r") as file:
                 accelerate_config = yaml.safe_load(file)
             config_dict.update(accelerate_config)
-        # TODO(dahoas): might need to move this
         self.accelerator = Accelerator(log_with="wandb")
-        self.accelerator.init_trackers(
-            self.config.train.project_name, config=config_dict
-        )
+
+        if WORLD_SIZE > 1:
+            torch.distributed.barrier(device_ids=[LOCAL_RANK])
+        else:
+            torch.random.manual_seed(1000)
+        if self.accelerator.is_main_process:
+            self.accelerator.init_trackers(
+                project_name=self.config.train.project_name, config=config_dict
+            )
+
         self.opt = torch.optim.AdamW(
             self.model.parameters(), lr=self.config.train.learning_rate_init
         )
diff --git a/trlx/model/accelerate_ilql_model.py b/trlx/model/accelerate_ilql_model.py
index 16c2e6ef2..519d1cc41 100644
--- a/trlx/model/accelerate_ilql_model.py
+++ b/trlx/model/accelerate_ilql_model.py
@@ -157,7 +157,16 @@ def learn(self):
                 if opt_steps % self.config.train.eval_interval == 0:
                     logs.update(stats)

-                self.accelerator.log(logs)
+                if self.accelerator.is_main_process:
+                    self.accelerator.log(logs)
+                    self.accelerator.print(
+                        "Step: {}, loss_cql: {}, loss_v: {}, reward: {}".format(
+                            opt_steps,
+                            logs["loss_cql"],
+                            logs["loss_v"],
+                            logs["reward"],
+                        )
+                    )

                 self.accelerator.backward(loss)
                 self.opt.step()
diff --git a/trlx/model/accelerate_ppo_model.py b/trlx/model/accelerate_ppo_model.py
index e66552249..21933660d 100644
--- a/trlx/model/accelerate_ppo_model.py
+++ b/trlx/model/accelerate_ppo_model.py
@@ -10,11 +10,12 @@ from tqdm import tqdm
 from transformers import AutoConfig, AutoTokenizer

+import wandb
 from trlx.data.accelerate_base_datatypes import PromptBatch
 from trlx.data.configs import TRLConfig
 from trlx.model import BaseRLModel, register_model
 from trlx.model.accelerate_base_model import AccelerateRLModel
-from trlx.model.nn.ppo_models import GPT2HeadWithValueModel
+from trlx.model.nn.ppo_models import GPTHeadWithValueModel
 from trlx.pipeline.ppo_pipeline import PPORolloutStorage
 from trlx.utils import Clock, rampup_decay, safe_mkdir, topk_mask
 from trlx.utils.modeling import clip_by_value, logprobs_from_logits, whiten
@@ -27,8 +28,8 @@ def __init__(self, config, train_mode=True):
         super().__init__(config, self.store)

     def get_arch(self, config: TRLConfig):
-        # TODO(dahoas): Assumes model is gpt2 based
-        return GPT2HeadWithValueModel.from_pretrained(self.config.model.model_path)
+        # TODO(dahoas): Assumes model is gpt like
+        return GPTHeadWithValueModel(self.config.model.model_path)

     def loss(
         self, query_tensors, response_tensors, all_logprobs, all_values, all_rewards
@@ -82,7 +83,7 @@ def loss(
         pg_loss = torch.mean(torch.max(pg_losses, pg_losses2))

         model_loss = pg_loss + self.config.method.vf_coef * vf_loss
-        return model_loss
+        return model_loss, pg_loss, vf_loss

     def post_epoch_callback(self):
         # TODO(dahoas): are experiences being made for dataloaders on each process or same dataloader
@@ -92,8 +93,34 @@ def post_epoch_callback(self):
             self.config.method.num_rollouts, self.iter_count
         )  # Collect more rollouts for training

-    def post_backward_callback(self, batch, rewards):
-        pass
+    def post_backward_callback(self):
+        batch = self.logs["batch"]
+        if self.accelerator.is_main_process:
+            if (
+                self.iter_count % self.config.train.eval_interval == 0
+                or self.iter_count <= self.config.method.ppo_epochs
+            ):
+                text = self.tokenizer.batch_decode(batch.query_tensors)
+                eval_batch: PromptBatch = PromptBatch(
+                    text=text, tokens=batch.query_tensors
+                )
+                query_tensors, response_tensors, response_text = self.act(eval_batch)
+                gen_texts = [q + r for q, r in zip(eval_batch.text, response_text)]
+                scores = self.orch.score(gen_texts)
+                mean_score = torch.mean(scores).item()
+                rows = list(zip(gen_texts, scores.tolist()))
+                stats = {
+                    "mean_score": mean_score,
+                    "responses": wandb.Table(columns=["response", "score"], rows=rows),
+                    "pg_loss": self.logs["pg_loss"],
+                    "vf_loss": self.logs["vf_loss"],
+                }
+                self.accelerator.log(stats, step=self.iter_count)
+                self.accelerator.print(
+                    "Step: {}, Mean score: {}, pg_loss: {}, vf_loss: {}".format(
+                        self.iter_count, mean_score, stats["pg_loss"], stats["vf_loss"]
+                    )
+                )

     def learn(self, log_fn=None, save_fn=None, eval_fn=None):
@@ -117,17 +144,25 @@ def learn(self, log_fn=None, save_fn=None, eval_fn=None):
                 rewards = batch.rewards.to(self.accelerator.device)

                 for _ in range(self.config.method.ppo_epochs):
-                    loss = self.loss(
+                    loss, pg_loss, vf_loss = self.loss(
                         query_tensors, response_tensors, logprobs, values, rewards
                     )
+                    self.logs = {
+                        "loss": loss,
+                        "pg_loss": pg_loss,
+                        "vf_loss": vf_loss,
+                        "batch": batch,
+                        "rewards": rewards,
+                    }
+                    # self.post_backward_callback()
+                    # exit()
                     self.opt.zero_grad()
                     self.accelerator.backward(loss)
                     self.opt.step()
                     self.scheduler.step()
                     self.iter_count += 1
-                    self.post_backward_callback(batch, rewards)
-
+                    self.post_backward_callback()
             self.accelerator.wait_for_everyone()
             self.post_epoch_callback()
diff --git a/trlx/model/nn/ppo_models.py b/trlx/model/nn/ppo_models.py
index 40e13ca38..988f5d316 100644
--- a/trlx/model/nn/ppo_models.py
+++ b/trlx/model/nn/ppo_models.py
@@ -1,12 +1,14 @@
 from dataclasses import dataclass
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from torch import nn
 from torch.nn import Identity
-from transformers import (GPT2LMHeadModel, GPT2Model, GPT2PreTrainedModel,
-                          GPT2Tokenizer, top_k_top_p_filtering)
+from transformers import (AutoConfig, AutoModelForCausalLM, GPT2LMHeadModel,
+                          GPT2Model, GPT2PreTrainedModel, GPT2Tokenizer,
+                          GPTJModel, PretrainedConfig, PreTrainedModel,
+                          top_k_top_p_filtering)
 from transformers.modeling_outputs import ModelOutput

@@ -25,83 +27,34 @@ class CausalLMOutputWithCrossAttentions(ModelOutput):

 # Cell

-class ValueHead(nn.Module):
-    """The ValueHead class implements a head for GPT2 that returns a scalar for each output token."""
-
-    def __init__(self, config):
-        super().__init__()
-        self.detach_head = False
-        self.summary_type = (
-            config.summary_type if hasattr(config, "summary_type") else "last"
-        )
-        if self.summary_type == "attn":
-            raise NotImplementedError
-
-        self.summary = Identity()
-        if hasattr(config, "summary_use_proj") and config.summary_use_proj:
-            if (
-                hasattr(config, "summary_proj_to_labels")
-                and config.summary_proj_to_labels
-                and config.num_labels > 0
-            ):
-                num_classes = config.num_labels
-            else:
-                num_classes = config.hidden_size
-            self.summary = nn.Linear(config.hidden_size, num_classes)
-
-        self.activation = Identity()
-        if (
-            hasattr(config, "summary_activation")
-            and config.summary_activation == "tanh"
-        ):
-            self.activation = nn.Tanh()
-
-        self.first_dropout = Identity()
-        if (
-            hasattr(config, "summary_first_dropout")
-            and config.summary_first_dropout > 0
-        ):
-            self.first_dropout = nn.Dropout(config.summary_first_dropout)
-
-        self.last_dropout = Identity()
-        if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
-            self.last_dropout = nn.Dropout(config.summary_last_dropout)
-
-        self.flatten = nn.Flatten()
-
-    def forward(self, hidden_states, cls_index=None):
-        if self.detach_head:
-            output = hidden_states.detach()
-        else:
-            output = hidden_states
-        output = self.first_dropout(output)
-        output = self.summary(output)
-        output = self.activation(output)
-        output = self.last_dropout(output)
-
-        return output
+def make_head(n_embd: int, out: int):
+    return nn.Sequential(
+        nn.Linear(n_embd, n_embd * 2), nn.ReLU(), nn.Linear(n_embd * 2, out)
+    )


 # Cell

-class GPT2HeadWithValueModel(GPT2PreTrainedModel):
-    """The GPT2HeadWithValueModel class implements a GPT2 language model with a secondary, scalar head."""
+class GPTHeadWithValueModel(nn.Module):
+    """The GPTHeadWithValueModel class implements a GPT-type language model with a secondary, scalar head."""

-    def __init__(self, config):
-        super().__init__(config)
-        config.num_labels = 1
-        self.transformer = GPT2Model(config)
-        self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
-        self.v_head = ValueHead(config)
+    def __init__(self, config: Union[PretrainedConfig, str]):
+        super().__init__()
+        if isinstance(config, PretrainedConfig):
+            self.gpt = AutoModelForCausalLM.from_config(config)
+        else:
+            self.gpt = AutoModelForCausalLM.from_pretrained(config)

-        self.init_weights()
+        if hasattr(self.gpt.config, "hidden_size"):
+            self.n_embd = self.gpt.config.hidden_size
+        else:
+            self.n_embd = self.gpt.config.n_embd

-    def get_output_embeddings(self):
-        return self.lm_head
+        self.v_head = make_head(self.n_embd, 1)

-    def detach_value_head(self):
-        self.v_head.detach_head = True
+    def generate(self, input_ids, **x):
+        return self.gpt.generate(input_ids, **x)

     def forward(
         self,
@@ -120,7 +73,7 @@ def forward(
         output_hidden_states=False,
     ):
         loss = None
-        transformer_outputs = self.transformer(
+        transformer_outputs = self.gpt.transformer(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
@@ -129,10 +82,8 @@ def forward(
             head_mask=head_mask,
             inputs_embeds=inputs_embeds,
         )
-
         hidden_states = transformer_outputs[0]
-
-        lm_logits = self.lm_head(hidden_states)
+        lm_logits = self.gpt.lm_head(hidden_states)
         value = self.v_head(hidden_states).squeeze(-1)

         if not return_dict:
diff --git a/trlx/orchestrator/offline_orchestrator.py b/trlx/orchestrator/offline_orchestrator.py
index 2eaccc77b..f8e75f3e0 100644
--- a/trlx/orchestrator/offline_orchestrator.py
+++ b/trlx/orchestrator/offline_orchestrator.py
@@ -36,5 +36,5 @@ def __init__(
         self.model.reward_fn = reward_fn
         self.model.stats_fn = stats_fn

-    def score(samples):
+    def score(self, samples):
         return self.model.reward_fn(samples)
diff --git a/trlx/orchestrator/ppo_orchestrator.py b/trlx/orchestrator/ppo_orchestrator.py
index 3cb2dbe21..293edae6a 100644
--- a/trlx/orchestrator/ppo_orchestrator.py
+++ b/trlx/orchestrator/ppo_orchestrator.py
@@ -65,10 +65,6 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):

             # Precompute logprobs, values
             all_tokens = torch.cat((query_tensors, response_tensors), dim=1)
-            assert (
-                all_tokens.size()[1]
-                == query_tensors.size()[1] + response_tensors.size()[1]
-            )
             with torch.no_grad():
                 logits, _, v = self.rl_model.model(all_tokens)
             # TODO(dahoas): Need to make decision about what to do with ref model: keep on cpu?
@@ -99,16 +95,8 @@ def make_experience(self, num_rollouts: int = 1024, iter_count: int = 0):
             exp_time = clock.tick()

             # Evaluate model on first chunk
-            if i == 0:
-                mean_score = torch.mean(scores).item()
-                rows = list(zip(texts, scores.tolist()))
-                stats = {
-                    "exp_time": exp_time,
-                    "mean_score": mean_score,
-                    "responses": wandb.Table(
-                        columns=["response", "score"], rows=rows[:16]
-                    ),
-                }
+            if i == 0 and self.rl_model.accelerator.is_main_process:
+                stats = {"exp_time": exp_time}
                 self.rl_model.accelerator.log(stats, step=iter_count)

             new_ppo_rl_elements = [
diff --git a/trlx/utils/__init__.py b/trlx/utils/__init__.py
index 48779cbef..d136d460e 100644
--- a/trlx/utils/__init__.py
+++ b/trlx/utils/__init__.py
@@ -1,7 +1,7 @@
 import os
 import time
 from functools import reduce
-from typing import Any, Callable, Iterable, List
+from typing import Any, Iterable, List

 import numpy as np
 import torch