Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chat] Support overall loss, update KTO logging #5962

Merged
merged 1 commit into from
Aug 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 17 additions & 2 deletions applications/ColossalChat/coati/dataset/tokenization_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,10 @@ def tokenize_sft(

messages = data_point["messages"]
template = deepcopy(conversation_template)

if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)
template.messages = []
for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
Expand Down Expand Up @@ -148,11 +152,14 @@ def tokenize_prompt(
template = deepcopy(conversation_template)
template.messages = []

if messages[0]["from"] == "system":
template.system_message = str(messages[0]["content"])
messages.pop(0)
YeAnbang marked this conversation as resolved.
Show resolved Hide resolved

for idx, mess in enumerate(messages):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
f"Message should iterate between user and assistant and starts with a \
line from the user. Got the following data:\n{messages}"
f"Message should iterate between user and assistant and starts with a line from the user. Got the following data:\n{messages}"
)
template.append_message(mess["from"], mess["content"])

Expand Down Expand Up @@ -225,6 +232,10 @@ def tokenize_rlhf(
template = deepcopy(conversation_template)
template.clear()

if context[0]["from"] == "system":
template.system_message = str(context[0]["content"])
context.pop(0)

for idx, mess in enumerate(context):
if mess["from"] != template.roles[idx % 2]:
raise ValueError(
Expand Down Expand Up @@ -345,6 +356,10 @@ def tokenize_kto(
template = deepcopy(conversation_template)
template.clear()

if prompt[0]["from"] == "system":
template.system_message = str(prompt[0]["content"])
prompt.pop(0)

if prompt[0].get("from", None) != "user":
raise ValueError("conversation should start with user")
if completion.get("from", None) != "assistant":
Expand Down
16 changes: 12 additions & 4 deletions applications/ColossalChat/coati/models/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,10 @@ def forward(
action_mask: Optional[torch.Tensor] = None,
) -> torch.Tensor:
skip = False
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()
if action_mask is None:
ratio_ = (log_probs - old_log_probs).exp()
else:
ratio_ = ((log_probs - old_log_probs) * action_mask).exp()

# note that if dropout is disabled (recommanded), ratio will always be 1.
if ratio_.mean() > self.skip_threshold:
Expand All @@ -56,7 +59,10 @@ def forward(
surr1 = ratio * advantages
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
loss = -torch.min(surr1, surr2)
loss = masked_mean(loss, action_mask)
if action_mask is not None:
loss = masked_mean(loss, action_mask)
else:
loss = loss.mean(dim=1)
loss = loss.mean()
return loss, skip, ratio_.max()

Expand All @@ -81,8 +87,10 @@ def forward(
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
surr1 = (values_clipped - returns) ** 2
surr2 = (values - returns) ** 2
loss = torch.max(surr1, surr2) / torch.sum(action_mask)
loss = torch.sum(loss * action_mask)
if action_mask is not None:
loss = torch.sum(torch.max(surr1, surr2) / torch.sum(action_mask) * action_mask)
else:
loss = torch.mean(torch.max(surr1, surr2))
return 0.5 * loss


Expand Down
9 changes: 9 additions & 0 deletions applications/ColossalChat/coati/trainer/dpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
beta: float = 0.1,
gamma: float = 0.0,
length_normalization: bool = False,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
Expand All @@ -67,6 +68,7 @@ def __init__(
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.actor_loss_fn = DpoLoss(beta, gamma)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
Expand Down Expand Up @@ -135,6 +137,10 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]

actor_all_logits = self.model(
Expand Down Expand Up @@ -284,6 +290,9 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)
if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]

Expand Down
37 changes: 34 additions & 3 deletions applications/ColossalChat/coati/trainer/kto.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Any, Optional

import torch
import torch.distributed
import torch.distributed as dist
from coati.models.loss import KTOLoss
from coati.models.utils import calc_masked_log_probs
from coati.trainer.utils import all_reduce_mean
Expand Down Expand Up @@ -59,6 +59,7 @@ def __init__(
beta: float = 0.1,
desirable_weight: float = 1.0,
undesirable_weight: float = 1.0,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
Expand All @@ -70,6 +71,7 @@ def __init__(
self.actor_scheduler = actor_lr_scheduler
self.tokenizer = tokenizer
self.kto_loss = KTOLoss(beta=beta, desirable_weight=desirable_weight, undesirable_weight=undesirable_weight)
self.apply_loss_mask = apply_loss_mask
self.save_interval = save_interval
self.coordinator = coordinator
self.save_dir = save_dir
Expand Down Expand Up @@ -134,6 +136,10 @@ def _train(self, epoch: int):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)
if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)

batch_size = input_ids.size()[0]

# actor logits
Expand Down Expand Up @@ -182,8 +188,28 @@ def _train(self, epoch: int):

# sync
loss_mean = all_reduce_mean(tensor=loss)
chosen_rewards_mean = all_reduce_mean(tensor=chosen_rewards.mean())
rejected_rewards_mean = all_reduce_mean(tensor=rejected_rewards.mean())
chosen_reward_mean = chosen_rewards.mean()
chosen_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(chosen_rewards_list, chosen_reward_mean)
rejected_reward_mean = rejected_rewards.mean()
rejected_rewards_list = [
torch.tensor(0, dtype=loss.dtype, device=loss.device) for _ in range(dist.get_world_size())
]
dist.all_gather(rejected_rewards_list, rejected_reward_mean)
chosen_rewards_list = [i for i in chosen_rewards_list if not i.isnan()]
rejected_rewards_list = [i for i in rejected_rewards_list if not i.isnan()]
chosen_rewards_mean = (
torch.stack(chosen_rewards_list).mean()
if len(chosen_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
rejected_rewards_mean = (
torch.stack(rejected_rewards_list).mean()
if len(rejected_rewards_list) > 0
else torch.tensor(torch.nan, dtype=loss.dtype, device=loss.device)
)
self.accumulative_meter.add("chosen_rewards", chosen_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("rejected_rewards", rejected_rewards_mean.to(torch.float16).mean().item())
self.accumulative_meter.add("loss", loss_mean.to(torch.float16).detach().item())
Expand Down Expand Up @@ -256,6 +282,11 @@ def _eval(self, epoch: int):
batch["kl_attention_mask"],
batch["kl_loss_mask"],
)

if not self.apply_loss_mask:
loss_mask = loss_mask.fill_(1.0)
kl_loss_mask = kl_loss_mask.fill_(1.0)

batch_size = input_ids.size()[0]

# actor logits
Expand Down
12 changes: 12 additions & 0 deletions applications/ColossalChat/coati/trainer/orpo.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(
tokenizer: PreTrainedTokenizerBase,
max_epochs: int = 1,
lam: float = 0.1,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
start_epoch: int = 0,
save_interval: int = 0,
Expand All @@ -67,6 +68,7 @@ def __init__(
self.save_dir = save_dir
self.num_train_step = 0
self.lam = lam
self.apply_loss_mask = apply_loss_mask
self.accumulation_steps = accumulation_steps
self.device = get_current_device()
self.accumulative_meter = AccumulativeMeanMeter()
Expand Down Expand Up @@ -130,6 +132,11 @@ def _train(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)

if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
Expand Down Expand Up @@ -263,6 +270,11 @@ def _eval(self, epoch: int):
batch["reject_attention_mask"],
batch["reject_loss_mask"],
)

if not self.apply_loss_mask:
chosen_loss_mask = chosen_loss_mask.fill_(1.0)
reject_loss_mask = reject_loss_mask.fill_(1.0)

batch_size = chosen_input_ids.size()[0]
actor_out = self.model(
input_ids=torch.cat([chosen_input_ids, reject_input_ids]),
Expand Down
12 changes: 10 additions & 2 deletions applications/ColossalChat/coati/trainer/ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,7 @@ def __init__(
sample_buffer: bool = False,
dataloader_pin_memory: bool = True,
offload_inference_models: bool = True,
apply_loss_mask: bool = True,
accumulation_steps: int = 1,
save_interval: int = 0,
save_dir: str = None,
Expand Down Expand Up @@ -140,6 +141,7 @@ def __init__(
self.actor_optim = actor_optim
self.critic_optim = critic_optim
self.save_interval = save_interval
self.apply_loss_mask = apply_loss_mask
self.coordinator = coordinator
self.actor_save_dir = os.path.join(save_dir, "actor")
self.critic_save_dir = os.path.join(save_dir, "critic")
Expand Down Expand Up @@ -229,7 +231,10 @@ def _training_step(self, experience: Experience):
action_log_probs = calc_action_log_probs(actor_logits, experience.sequences, num_actions)

actor_loss, to_skip, max_ratio = self.actor_loss_fn(
action_log_probs, experience.action_log_probs, experience.advantages, action_mask=experience.action_mask
action_log_probs,
experience.action_log_probs,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
actor_loss = (1 - self.ptx_coef) * actor_loss
if not to_skip:
Expand All @@ -249,7 +254,10 @@ def _training_step(self, experience: Experience):
input_ids=experience.sequences, attention_mask=experience.attention_mask
) # [batch size, prompt_length + response_length]
critic_loss = self.critic_loss_fn(
values[:, -num_actions:], experience.values, experience.advantages, action_mask=experience.action_mask
values[:, -num_actions:],
experience.values,
experience.advantages,
action_mask=experience.action_mask if self.apply_loss_mask else None,
)
critic_loss = critic_loss * self.vf_coef
self.critic_booster.backward(loss=critic_loss, optimizer=self.critic_optim)
Expand Down
14 changes: 12 additions & 2 deletions applications/ColossalChat/coati/trainer/sft.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ def __init__(
lr_scheduler: _LRScheduler,
max_epochs: int = 2,
accumulation_steps: int = 8,
apply_loss_mask: bool = True,
start_epoch=0,
save_interval: int = None,
save_dir: str = None,
Expand All @@ -55,6 +56,7 @@ def __init__(
self.coordinator = coordinator
self.num_train_step = 0
self.num_eval_step = 0
self.apply_loss_mask = apply_loss_mask
self.accumulative_meter = AccumulativeMeanMeter()

def _before_fit(
Expand Down Expand Up @@ -100,7 +102,11 @@ def _train(self, epoch: int):
for i, batch in enumerate(self.train_dataloader):
batch = to_device(batch, torch.cuda.current_device())
batch_size = batch["input_ids"].size(0)
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss = outputs.loss

self.booster.backward(loss=loss, optimizer=self.optimizer)
Expand Down Expand Up @@ -158,7 +164,11 @@ def _eval(self, epoch: int):
)
for batch in self.eval_dataloader:
batch = to_device(batch, torch.cuda.current_device())
outputs = self.model(batch["input_ids"], attention_mask=batch["attention_mask"], labels=batch["labels"])
outputs = self.model(
batch["input_ids"],
attention_mask=batch["attention_mask"],
labels=batch["labels"] if self.apply_loss_mask else batch["input_ids"],
)
loss_mean = all_reduce_mean(tensor=outputs.loss)
self.accumulative_meter.add("loss", loss_mean.item(), count_update=batch["input_ids"].size(0))
step_bar.update()
Expand Down
1 change: 1 addition & 0 deletions applications/ColossalChat/examples/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -387,6 +387,7 @@ colossalai run --nproc_per_node 4 --master_port 28534 --hostfile ./hostfile trai
- save_dir: path to store the model checkpoints.
- max_length: input will be padded/truncated to max_length before feeding to the model.
- max_epochs: number of epochs to train.
- disable_loss_mask: whether to use the loss mask to mask the loss or not. For example, in SFT, if the loss mask is disabled, the model will compute the loss across all tokens in the sequence, if the loss mask is applied, only tokens correspond to the assistant responses will contribute to the final loss.
- batch_size: training batch size.
- mixed_precision: precision to use in training. Support 'fp16' and 'bf16'. Note that some devices may not support the 'bf16' option, please refer to [Nvidia](https://developer.nvidia.com/) to check compatibility.
- save_interval: save the model weights as well as optimizer/scheduler states every save_interval steps/episodes.
Expand Down
4 changes: 2 additions & 2 deletions applications/ColossalChat/examples/inference/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@ def load_model_and_tokenizer(model_path, tokenizer_path, device="cuda", **kwargs
tuple: A tuple containing the loaded model and tokenizer.
"""

model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path)
model = AutoModelForCausalLM.from_pretrained(model_path, **kwargs, trust_remote_code=True).to(torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
model.to(device)

Expand Down
Loading
Loading