[chat]: update rm, add wandb and fix bugs #4471

Merged · 34 commits · Sep 20, 2023

Commits
99e0fd0
feat: modify forward fn of critic and reward model
CWHer Aug 16, 2023
482e1e4
feat: modify calc_action_log_probs
CWHer Aug 16, 2023
3c28e62
to: add wandb in sft and rm trainer
CWHer Aug 16, 2023
69da25b
feat: update train_sft
CWHer Aug 16, 2023
a4ab376
feat: update train_rm
CWHer Aug 16, 2023
11babfa
style: modify type annotation and add warning
CWHer Aug 16, 2023
88a5409
feat: pass tokenizer to ppo trainer
CWHer Aug 16, 2023
2097139
to: modify trainer base and maker base
CWHer Aug 16, 2023
18d664c
feat: add wandb in ppo trainer
CWHer Aug 16, 2023
1dd3269
feat: pass tokenizer to generate
CWHer Aug 16, 2023
5af5e58
test: update generate fn tests
CWHer Aug 16, 2023
a0e32aa
test: update train tests
CWHer Aug 17, 2023
aa94aa3
fix: remove action_mask
CWHer Aug 17, 2023
c1c8026
feat: remove unused code
CWHer Aug 17, 2023
b160a26
fix: fix wrong ignore_index
CWHer Aug 17, 2023
a81f004
fix: fix mock tokenizer
CWHer Aug 17, 2023
18f879b
chore: update requirements
CWHer Aug 17, 2023
f4cd1a5
revert: modify make_experience
CWHer Aug 18, 2023
640867e
fix: fix inference
CWHer Aug 18, 2023
a43f481
fix: add padding side
CWHer Aug 18, 2023
6c9fa1d
style: modify _on_learn_batch_end
CWHer Aug 18, 2023
d0166ec
test: use mock tokenizer
CWHer Aug 18, 2023
756f84a
fix: use bf16 to avoid overflow
CWHer Aug 19, 2023
dde4b13
fix: fix workflow
CWHer Aug 19, 2023
e8d1b7b
[chat] fix gemini strategy
flybird11111 Sep 7, 2023
d1084e4
[chat] fix
flybird11111 Sep 7, 2023
93caf5a
sync: update colossalai strategy
CWHer Sep 11, 2023
50488a0
fix: fix args and model dtype
CWHer Sep 20, 2023
2ef1100
fix: fix checkpoint test
CWHer Sep 11, 2023
c26f751
fix: fix requirements
CWHer Sep 11, 2023
82b61da
fix: fix missing import and wrong arg
CWHer Sep 20, 2023
ab6bc55
fix: temporarily skip gemini test in stage 3
CWHer Sep 20, 2023
3434309
style: apply pre-commit
CWHer Sep 20, 2023
1f3d7f1
fix: temporarily skip gemini test in stage 1&2
CWHer Sep 20, 2023
2 changes: 1 addition & 1 deletion .github/workflows/run_chatgpt_examples.yml
@@ -49,5 +49,5 @@ jobs:
NCCL_SHM_DISABLE: 1
MAX_JOBS: 8
SFT_DATASET: /data/scratch/github_actions/chat/data.json
PROMPT_PATH: /data/scratch/github_actions/chat/prompts_en.jsonl
PROMPT_DATASET: /data/scratch/github_actions/chat/prompts_en.jsonl
PRETRAIN_DATASET: /data/scratch/github_actions/chat/alpaca_data.json
4 changes: 2 additions & 2 deletions applications/Chat/benchmarks/benchmark_opt_lora_dummy.py
@@ -138,6 +138,7 @@ def main(args):

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))

@@ -154,6 +155,7 @@
initial_model,
actor_optim,
critic_optim,
tokenizer=tokenizer,
ptx_coef=0,
train_batch_size=args.train_batch_size,
offload_inference_models=args.offload_inference_models,
@@ -162,8 +164,6 @@
temperature=1.0,
top_k=50,
use_cache=True,
pad_token_id=tokenizer.pad_token_id,
eos_token_id=tokenizer.eos_token_id,
callbacks=[performance_evaluator],
)

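The new tokenizer.padding_side = "left" line is load-bearing: batched decoder-only generation reads its next-token logits at the last position, which must be a real token in every row. A toy sketch with made-up token ids:

import torch

pad = 0
right_padded = torch.tensor([[11, 12, pad]])  # last position is padding
left_padded = torch.tensor([[pad, 11, 12]])   # last position is a real token

# generate() takes logits[:, -1, :] as the next-token distribution, so only
# left padding guarantees that slice corresponds to each prompt's final token.
print(right_padded[:, -1], left_padded[:, -1])  # tensor([0]) tensor([12])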
18 changes: 15 additions & 3 deletions applications/Chat/coati/dataset/sft_dataset.py
@@ -13,7 +13,7 @@
# limitations under the License.

import copy
from typing import Dict, Sequence, Tuple
from typing import Dict, Optional, Sequence, Tuple

import torch
from coati.models.chatglm.chatglm_tokenizer import ChatGLMTokenizer
@@ -57,16 +57,18 @@ def _preprocess(
sources, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt"
)

assert sequences_token["attention_mask"].dim() == 2, "seq2seq model should be preprocessed differently"
labels = copy.deepcopy(sequences_token["input_ids"])
for i in range(labels.shape[0]):
source_len = sources_token["attention_mask"][i].sum().item()
pad_len = max_length - sequences_token["attention_mask"][i].sum().item()
if tokenizer.padding_side == "right":
# |prompt|completion|eos|pad|
labels[i][:source_len] = IGNORE_INDEX
labels[i][-pad_len:] = IGNORE_INDEX
elif tokenizer.padding_side == "left":
# |pad|prompt|completion|eos|
labels[i][pad_len : pad_len + source_len] = IGNORE_INDEX
labels[i][: pad_len + source_len] = IGNORE_INDEX
else:
raise RuntimeError()
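To see what the corrected left-padding branch does, here is a minimal sketch (tokenizer-free; token ids and lengths are made up). The fix masks the whole pad + prompt prefix in a single slice instead of only the prompt span:

import torch

IGNORE_INDEX = -100  # matches the constant used in sft_dataset.py
max_length = 8
# |pad|pad|pad| prompt |completion|eos| with left padding
input_ids = torch.tensor([[0, 0, 0, 11, 12, 21, 22, 2]])
attention_mask = torch.tensor([[0, 0, 0, 1, 1, 1, 1, 1]])

labels = input_ids.clone()
source_len = 2                                         # prompt token count
pad_len = max_length - attention_mask[0].sum().item()  # 3 padding tokens
labels[0][: pad_len + source_len] = IGNORE_INDEX       # mask pads and prompt together
# labels -> [-100, -100, -100, -100, -100, 21, 22, 2]: loss only on completion + eos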

@@ -126,13 +128,17 @@ def __init__(self, dataset: Dict, tokenizer: PreTrainedTokenizer, max_length: int

sources = [data["prompt"] for data in dataset]
targets = [data["completion"] + tokenizer.eos_token for data in tqdm(dataset, disable=not is_rank_0())]

logger.info("Tokenizing inputs... This may take some time...")
if isinstance(tokenizer, ChatGLMTokenizer):
self.input_ids, self.labels, self.attention_mask = _preprocess_chatglm(
sources, targets, tokenizer, max_length
)
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

logger.info("Loaded dataset.")

def __len__(self):
length = self.input_ids.shape[0]
return length
@@ -148,7 +154,11 @@ class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""

def __init__(
self, data_path: str, tokenizer: PreTrainedTokenizer, max_datasets_size: int = None, max_length: int = 512
self,
data_path: str,
tokenizer: PreTrainedTokenizer,
max_datasets_size: Optional[int] = None,
max_length: int = 512,
):
super().__init__()
logger.info("Loading data...")
@@ -175,6 +185,8 @@ def __init__(
else:
self.input_ids, self.labels, self.attention_mask = _preprocess(sources, targets, tokenizer, max_length)

logger.info("Loaded dataset.")

def __len__(self):
length = self.input_ids.shape[0]
return length
3 changes: 3 additions & 0 deletions applications/Chat/coati/experience_buffer/naive.py
@@ -1,4 +1,5 @@
import random
import warnings
from typing import List

import torch
@@ -30,9 +31,11 @@ def append(self, experience: Experience) -> None:
experience.to_device(torch.device("cpu"))
items = split_experience_batch(experience)
self.items.extend(items)

if self.limit > 0:
samples_to_remove = len(self.items) - self.limit
if samples_to_remove > 0:
warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
self.items = self.items[samples_to_remove:]

def clear(self) -> None:
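A standalone sketch of the trimming behavior that now emits a warning (plain ints stand in for Experience items):

import warnings

# Simplified stand-in for the FIFO eviction in NaiveExperienceBuffer.append
items, limit = list(range(10)), 8
if limit > 0:
    samples_to_remove = len(items) - limit
    if samples_to_remove > 0:
        warnings.warn(f"Experience buffer is full. Removing {samples_to_remove} samples.")
        items = items[samples_to_remove:]  # oldest samples are dropped first
print(items)  # [2, 3, 4, 5, 6, 7, 8, 9]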
10 changes: 3 additions & 7 deletions applications/Chat/coati/experience_maker/base.py
@@ -3,8 +3,7 @@
from typing import Optional

import torch
import torch.nn as nn
from coati.models.base import Actor
from coati.models.base import Actor, Critic, RewardModel


@dataclass
@@ -59,16 +58,13 @@ def pin_memory(self):


class ExperienceMaker(ABC):
def __init__(
self, actor: Actor, critic: nn.Module, reward_model: nn.Module, initial_model: Actor, kl_coef: float = 0.1
) -> None:
def __init__(self, actor: Actor, critic: Critic, reward_model: RewardModel, initial_model: Actor) -> None:
super().__init__()
self.actor = actor
self.critic = critic
self.reward_model = reward_model
self.initial_model = initial_model
self.kl_coef = kl_coef

@abstractmethod
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
def make_experience(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **generate_kwargs) -> Experience:
pass
27 changes: 21 additions & 6 deletions applications/Chat/coati/experience_maker/naive.py
@@ -1,7 +1,9 @@
import torch
import torch.nn.functional as F
from coati.models.base import Actor, Critic, RewardModel
from coati.models.generation import generate
from coati.models.utils import calc_action_log_probs, compute_reward
from transformers import PreTrainedTokenizer

from .base import Experience, ExperienceMaker

@@ -11,6 +13,19 @@ class NaiveExperienceMaker(ExperienceMaker):
Naive experience maker.
"""

def __init__(
self,
actor: Actor,
critic: Critic,
reward_model: RewardModel,
initial_model: Actor,
tokenizer: PreTrainedTokenizer,
kl_coef: float = 0.1,
) -> None:
super().__init__(actor, critic, reward_model, initial_model)
self.tokenizer = tokenizer
self.kl_coef = kl_coef

@torch.no_grad()
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.actor.eval()
@@ -19,16 +34,16 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
self.reward_model.eval()

# generate sequences
sequences = generate(self.actor, input_ids, **generate_kwargs)
sequences = generate(self.actor, input_ids, self.tokenizer, **generate_kwargs)

# calculate auxiliary tensors
attention_mask = None
pad_token_id = generate_kwargs.get("pad_token_id", None)
pad_token_id = self.tokenizer.pad_token_id
if pad_token_id is not None:
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)

input_len = input_ids.size(1)
eos_token_id = generate_kwargs.get("eos_token_id", None)
eos_token_id = self.tokenizer.eos_token_id
if eos_token_id is None:
action_mask = torch.ones_like(sequences, dtype=torch.bool)
else:
@@ -40,11 +55,11 @@ def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
action_mask = action_mask[:, -(sequences.size(1) - input_len) :]
num_actions = action_mask.size(1)

actor_output = self.actor(sequences, attention_mask)
actor_output = self.actor(sequences, attention_mask)["logits"]
action_log_probs = calc_action_log_probs(actor_output, sequences, num_actions)
base_model_output = self.initial_model(sequences, attention_mask)
base_model_output = self.initial_model(sequences, attention_mask)["logits"]
base_action_log_probs = calc_action_log_probs(base_model_output, sequences, num_actions)
value = self.critic(sequences, action_mask, attention_mask)
value = self.critic(sequences, attention_mask)
r = self.reward_model(sequences, attention_mask)
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)

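Putting the new constructor together, a hedged usage sketch (the four model objects and input_ids are placeholders; the call follows the make_experience signature shown above):

from coati.experience_maker import NaiveExperienceMaker
from transformers import AutoTokenizer

# Assumes actor, critic, reward_model, initial_model are already-built coati models.
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"  # required by the generation path

experience_maker = NaiveExperienceMaker(
    actor, critic, reward_model, initial_model,
    tokenizer=tokenizer, kl_coef=0.1,
)
experience = experience_maker.make_experience(input_ids, max_length=128, do_sample=True)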
2 changes: 1 addition & 1 deletion applications/Chat/coati/models/base/actor.py
@@ -25,7 +25,7 @@ def forward(
self,
input_ids: torch.LongTensor,
attention_mask: Optional[torch.Tensor] = None,
**model_kwargs, # HACK: `generate` method may pass more kwargs
**model_kwargs,
) -> torch.Tensor:
"""Returns model output."""
output = self.model(input_ids, attention_mask=attention_mask, **model_kwargs)
37 changes: 8 additions & 29 deletions applications/Chat/coati/models/base/critic.py
@@ -1,10 +1,7 @@
from typing import Optional

import torch
import torch.nn as nn

from ..lora import LoRAModule
from ..utils import masked_mean


class Critic(LoRAModule):
@@ -19,37 +16,19 @@ class Critic(LoRAModule):
"""

def __init__(
self,
model: nn.Module,
value_head: nn.Module,
lora_rank: int = 0,
lora_train_bias: str = "none",
use_action_mask: bool = False,
self, model: nn.Module, value_head: nn.Module, lora_rank: int = 0, lora_train_bias: str = "none"
) -> None:
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
self.model = model
self.value_head = value_head
self.use_action_mask = use_action_mask
self.convert_to_lora()

def forward(
self,
sequences: torch.LongTensor,
action_mask: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]

values = self.value_head(last_hidden_states).squeeze(-1)

if action_mask is not None and self.use_action_mask:
num_actions = action_mask.size(1)
prompt_mask = attention_mask[:, :-num_actions]
values = values[:, :-num_actions]
value = masked_mean(values, prompt_mask, dim=1)
return value

values = values[:, :-1]
value = values.mean(dim=1)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values
11 changes: 7 additions & 4 deletions applications/Chat/coati/models/base/reward_model.py
@@ -35,9 +35,12 @@ def __init__(
else:
self.value_head = nn.Linear(model.config.n_embd, 1)

def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
def forward(self, sequences: torch.LongTensor, attention_mask: torch.Tensor) -> torch.Tensor:
outputs = self.model(sequences, attention_mask=attention_mask)
last_hidden_states = outputs["last_hidden_state"]
values = self.value_head(last_hidden_states)[:, :-1]
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
return value
sequence_lengths = torch.max(attention_mask * torch.arange(sequences.size(1), device=sequences.device), dim=1)[
0
]
sequence_hidden_states = last_hidden_states[torch.arange(last_hidden_states.size(0)), sequence_lengths]
values = self.value_head(sequence_hidden_states).squeeze(1) # ensure shape is (B, )
return values
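Both new forward bodies use the same indexing trick: attention_mask * arange peaks at each row's last attended position, which works for left or right padding. A small sketch with dummy tensors:

import torch

batch, seq_len, hidden = 2, 5, 4
last_hidden_states = torch.randn(batch, seq_len, hidden)
attention_mask = torch.tensor([[1, 1, 1, 1, 0],   # row 0: last real token at index 3
                               [1, 1, 1, 1, 1]])  # row 1: last real token at index 4

# arange * mask is largest at the final unmasked position of each row
sequence_lengths = torch.max(attention_mask * torch.arange(seq_len), dim=1)[0]
print(sequence_lengths)                            # tensor([3, 4])
sequence_hidden_states = last_hidden_states[torch.arange(batch), sequence_lengths]
print(sequence_hidden_states.shape)                # torch.Size([2, 4]): one vector per sequence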
16 changes: 7 additions & 9 deletions applications/Chat/coati/models/generation.py
@@ -2,6 +2,7 @@

import torch
import torch.distributed as dist
from transformers import PreTrainedTokenizer

from .base import Actor

@@ -63,17 +64,16 @@ def _sample(
)
outputs = model(**model_inputs)

# NOTE: this is correct only in left padding mode
next_token_logits = outputs["logits"][:, -1, :]
# pre-process distribution
next_token_logits = logits_processor(input_ids, next_token_logits)
# sample
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

# finished sentences should have their next token be a padding token
if eos_token_id is not None:
if pad_token_id is None:
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
assert pad_token_id is not None, "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

# update generated ids, model inputs for next step
@@ -96,12 +96,11 @@
def generate(
model: Actor,
input_ids: torch.Tensor,
tokenizer: PreTrainedTokenizer,
max_length: int,
num_beams: int = 1,
do_sample: bool = True,
early_stopping: bool = False,
eos_token_id: Optional[int] = None,
pad_token_id: Optional[int] = None,
top_k: Optional[int] = None,
top_p: Optional[float] = None,
temperature: Optional[float] = None,
@@ -118,14 +117,13 @@
num_beams (int, optional): number of beams. Defaults to 1.
do_sample (bool, optional): whether to do sample. Defaults to True.
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
temperature (Optional[float], optional): The value used to modulate the next token probabilities. Defaults to None.
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
"""
assert tokenizer.padding_side == "left", "Current generation only supports left padding."
is_greedy_gen_mode = (num_beams == 1) and do_sample is False
is_sample_gen_mode = (num_beams == 1) and do_sample is True
is_beam_gen_mode = (num_beams > 1) and do_sample is False
@@ -139,8 +137,8 @@
input_ids,
max_length,
early_stopping=early_stopping,
eos_token_id=eos_token_id,
pad_token_id=pad_token_id,
eos_token_id=tokenizer.eos_token_id,
pad_token_id=tokenizer.pad_token_id,
top_k=top_k,
top_p=top_p,
temperature=temperature,
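A minimal call sketch for the updated signature (the actor and the left-padded input_ids batch are placeholders):

from coati.models.generation import generate

# tokenizer.padding_side must be "left", per the assert in generate()
sequences = generate(
    actor,                 # coati Actor
    input_ids,             # left-padded prompts, shape (B, S)
    tokenizer,
    max_length=128,
    do_sample=True,
    top_k=50,
    temperature=1.0,
)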
1 change: 1 addition & 0 deletions applications/Chat/coati/models/loss.py
@@ -13,6 +13,7 @@ class GPTLMLoss(nn.Module):

def __init__(self):
super().__init__()
# NOTE: default ignore_index is -100, which is equal to IGNORE_INDEX in sft_dataset.py
self.loss = nn.CrossEntropyLoss()

def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
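The new comment holds because PyTorch's nn.CrossEntropyLoss defaults to ignore_index=-100, the same value as IGNORE_INDEX in sft_dataset.py. A quick check:

import torch
import torch.nn as nn

loss_fn = nn.CrossEntropyLoss()            # ignore_index defaults to -100
logits = torch.randn(4, 10)
labels = torch.tensor([1, -100, 3, -100])  # -100 positions contribute no loss
print(loss_fn(logits, labels))             # mean over the two unmasked labels only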
5 changes: 2 additions & 3 deletions applications/Chat/coati/models/utils.py
@@ -46,18 +46,17 @@ def _log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
return log_probs_labels.squeeze(-1)


def calc_action_log_probs(output: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
"""Calculate action log probs.

Args:
output (torch.Tensor): Output tensor of Actor.forward.
logits (torch.Tensor): Output logits of Actor.forward, i.e. output["logits"].
sequences (torch.LongTensor): Input sequences.
num_actions (int): Number of actions.

Returns:
torch.Tensor: Action log probs.
"""
logits = output["logits"]
log_probs = _log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
return log_probs[:, -num_actions:]

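Since calc_action_log_probs now receives raw logits (callers index ["logits"] themselves, as in the experience-maker diff above), the shift-by-one alignment looks like this minimal sketch (dummy shapes and values):

import torch
import torch.nn.functional as F

batch, seq_len, vocab, num_actions = 1, 6, 10, 3
logits = torch.randn(batch, seq_len, vocab)     # e.g. actor(sequences, mask)["logits"]
sequences = torch.randint(vocab, (batch, seq_len))

# the token at position t is predicted by the logits at position t - 1, hence the shift
log_probs = torch.gather(
    F.log_softmax(logits[:, :-1, :], dim=-1), 2, sequences[:, 1:].unsqueeze(-1)
).squeeze(-1)
action_log_probs = log_probs[:, -num_actions:]  # keep only the generated (action) tokens
print(action_log_probs.shape)                   # torch.Size([1, 3])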