fix: chat_template masking due to truncation, consolidate turn build and keys within field #2123

Merged · 5 commits · Dec 9, 2024
30 changes: 27 additions & 3 deletions src/axolotl/prompt_strategies/bradley_terry/chat_template.py
@@ -28,6 +28,8 @@ def tokenize_prompt(self, prompt):
:return:
"""

max_length = self.prompter.max_length

self.messages = "chosen_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -39,6 +41,16 @@ def tokenize_prompt(self, prompt):
prompt[self.messages].append({"role": "assistant", "content": prompt["chosen"]})
chosen_tokenized = super().tokenize_prompt(prompt)

if len(chosen_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Chosen sequence exceeds max sequence length: {len(chosen_tokenized['input_ids'])}",
)

chosen_tokenized["input_ids"] = chosen_tokenized["input_ids"][:max_length]
chosen_tokenized["attention_mask"] = chosen_tokenized["attention_mask"][
:max_length
]

self.messages = "rejected_messages"
# pylint: disable=duplicate-code
prompt[self.messages] = []
@@ -52,6 +64,18 @@ def tokenize_prompt(self, prompt):
)
rejected_tokenized = super().tokenize_prompt(prompt)

if len(rejected_tokenized["input_ids"]) > max_length:
LOG.warning(
f"Rejected sequence exceeds max sequence length: {len(rejected_tokenized['input_ids'])}",
)

rejected_tokenized["input_ids"] = rejected_tokenized["input_ids"][
:max_length
]
rejected_tokenized["attention_mask"] = rejected_tokenized["attention_mask"][
:max_length
]

return {
"input_ids_chosen": chosen_tokenized["input_ids"],
"attention_mask_chosen": chosen_tokenized["attention_mask"],
@@ -80,9 +104,9 @@ def load(tokenizer, cfg, ds_cfg: Optional[Dict[str, Any]] = None):
"roles": ds_cfg.get("roles"),
"drop_system_message": ds_cfg.get("drop_system_message", False),
# we need to add one for detecting sequences exceeding the `sequence_len` limit.
"max_length": cfg.sequence_len + 1
if not cfg.reward_model
else cfg.sequence_len,
"max_length": (
cfg.sequence_len + 1 if not cfg.reward_model else cfg.sequence_len
),
}

strategy_params = {
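The two truncation hunks above apply the same guard to the chosen and the rejected sequence. A minimal, self-contained sketch of that guard (the helper name is hypothetical; the diff inlines the logic for each sequence):

import logging

LOG = logging.getLogger(__name__)


def truncate_to_max_length(tokenized: dict, max_length: int, label: str) -> dict:
    """Warn about and hard-truncate a tokenized sequence longer than max_length."""
    if len(tokenized["input_ids"]) > max_length:
        LOG.warning(
            f"{label} sequence exceeds max sequence length: "
            f"{len(tokenized['input_ids'])}"
        )
        # input_ids and attention_mask are index-aligned, so truncate both.
        tokenized["input_ids"] = tokenized["input_ids"][:max_length]
        tokenized["attention_mask"] = tokenized["attention_mask"][:max_length]
    return tokenized

The `load` hunk explains the off-by-one: `max_length` is `cfg.sequence_len + 1` for non-reward models so that overlong sequences can still be detected downstream, while reward models are capped at exactly `cfg.sequence_len`.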
158 changes: 87 additions & 71 deletions src/axolotl/prompt_strategies/chat_template.py
@@ -42,6 +42,7 @@ def __init__(
"gpt": "assistant",
"system": "system",
}

self.message_field_role = message_field_role
self.message_field_content = message_field_content
self.message_field_training = message_field_training
@@ -53,21 +54,9 @@ def __init__(
self.drop_system_message = drop_system_message

def build_prompt(self, conversation, add_generation_prompt=False, images=None):
turns = [
{
"role": self.roles[t[self.message_field_role]],
"content": t[self.message_field_content],
"training": t.get(self.message_field_training, None),
}
for t in conversation
]

if self.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

if self.processor:
text = self.processor.apply_chat_template(
turns,
conversation,
chat_template=self.chat_template,
tokenize=False,
add_generation_prompt=add_generation_prompt,
@@ -76,8 +65,6 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
text=text,
images=images,
return_tensors="pt",
truncation=True,
max_length=self.max_length,
)
# workaround since processor works in batches instead of single examples
for k, val in batch.items():
@@ -88,9 +75,7 @@ def build_prompt(self, conversation, add_generation_prompt=False, images=None):
return batch

return self.tokenizer.apply_chat_template(
turns,
truncation=True,
max_length=self.max_length,
conversation,
add_generation_prompt=add_generation_prompt,
chat_template=self.chat_template,
)
@@ -215,7 +200,14 @@ def __init__(
train_on_eos=None,
):
super().__init__(prompter, tokenizer, train_on_inputs, sequence_len)
self.roles_to_train = roles_to_train if roles_to_train is not None else []

self.roles_to_train = []
if roles_to_train:
# map roles if they exist in prompter.roles, else use the role as-is
self.roles_to_train = [
prompter.roles.get(role, role) for role in roles_to_train
]

self.train_on_eos = train_on_eos
self.images = "images"

@@ -262,37 +254,38 @@ def tokenize_prompt(self, prompt):

return tokenized_prompt

turns = prompt[self.messages]
turns = self.get_conversation_thread(prompt)
input_ids = self.prompter.build_prompt(turns)
labels = [IGNORE_TOKEN_ID] * len(input_ids)

last_eos_idx = -1
for index, turn in enumerate(turns):
role = turn.get(self.prompter.message_field_role)
content = turn.get(self.prompter.message_field_content)
train_turn = turn.get(self.prompter.message_field_training)
train_detail = turn.get(self.prompter.message_field_training_detail)
role = turn.get("role")
content = turn.get("content")
train_turn = turn.get("training")
train_detail = turn.get("training_detail")

LOG.debug(
f"Processing turn {index}: role={role}, content={content}, train_turn={train_turn}, train_detail={train_detail}"
)

should_train = (
train_turn
if train_turn is not None
else (
bool(train_detail is not None)
if train_detail is not None
else self.train_on_inputs or role in self.roles_to_train
)
)
should_train = None
if train_turn is not None:
should_train = train_turn
elif train_detail is not None:
should_train = bool(train_detail)
else:
should_train = self.train_on_inputs or role in self.roles_to_train

LOG.debug(f"Should train: {should_train}")

turn_start_idx, turn_end_idx = self.find_turn(
conversation_ids=input_ids, turn=index, turn_content=turn
)

if turn_start_idx == -1 or turn_end_idx == -1:
LOG.warning(f"Failed to find boundaries for turn {index}")

LOG.debug(f"Turn indices: start={turn_start_idx}, end={turn_end_idx}")

if should_train and turn_start_idx != -1 and turn_end_idx != -1:
@@ -313,7 +306,9 @@ def tokenize_prompt(self, prompt):
labels[turn_start_idx:turn_end_idx] = input_ids[
turn_start_idx:turn_end_idx
]
LOG.debug(f"Labels set for range {turn_start_idx}:{turn_end_idx}")
LOG.debug(
f"Set labels for training from {turn_start_idx} to {turn_end_idx}"
)

LOG.debug(f"Labels after processing turn {index}: {labels}")

@@ -351,52 +346,73 @@ def find_eos_token(self, input_ids, start_idx):
return i
return -1

def find_turn(self, conversation_ids, turn, turn_content):
def find_turn(self, conversation_ids: list[int], turn: int, turn_content: dict):
"""
Locate the starting and ending indices of the specified turn in a conversation.

Args:
conversation_ids (list[int]): Token IDs representing the conversation.
turn (int): The turn number to locate (based on EOS tokens).
turn_content (dict): Turn dict containing the content of the turn.

Returns:
tuple: (start_idx, end_idx) indices of the start and end of the turn content.
Returns (-1, -1) if the turn content is not found.
"""
content = turn_content.get(self.prompter.message_field_content, "")
content = turn_content.get("content")
content_ids = self.tokenizer.encode(content, add_special_tokens=False)

eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0

# Locate the starting index after the specified number of EOS tokens
for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn:
start_search_idx = (
i + 1
) # Start searching after the specified turn's EOS token
break

# Find the start index of the content within the conversation
start_idx = -1
for i in range(start_search_idx, len(conversation_ids) - len(content_ids) + 1):
if conversation_ids[i : i + len(content_ids)] == content_ids:
start_idx = i
break

if start_idx != -1:
end_idx = start_idx + len(content_ids)
LOG.debug(f"content_ids (length {len(content_ids)}): {content_ids}")

if not content_ids:
LOG.warning(f"Empty content for turn {turn}")
return -1, -1

# For first turn, start from beginning
if turn == 0:
start_search_idx = 0
else:
end_idx = -1
# For subsequent turns, find the previous EOS token
eos_token_id = self.tokenizer.eos_token_id
eos_count = 0
start_search_idx = 0

for i, token_id in enumerate(conversation_ids):
if token_id == eos_token_id:
eos_count += 1
if eos_count == turn: # Find the nth EOS token where n = turn
start_search_idx = i + 1
break

# we can optimize this to only search for a few tokens from start_search_idx
# but it would risk missing the content if it's not found within the first few tokens or
# if start_search_idx cannot be found above.
last_index = len(conversation_ids) - len(content_ids) + 1

if last_index < start_search_idx:
LOG.warning(
f"last_index to search is less than start_search_idx for turn {turn}"
)
return -1, -1

# Search for content starting from start_search_idx
first_elem = content_ids[0]
for i in range(start_search_idx, last_index):
# Quick check of first element before doing full comparison
if conversation_ids[i] == first_elem:
# Check if the rest of the content matches
if conversation_ids[i : i + len(content_ids)] == content_ids:
LOG.debug(f"Found turn {turn} content at position {i}")
return i, i + len(content_ids)

return start_idx, end_idx
return -1, -1
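A hedged walk-through of the search above, with made-up token ids (assuming eos_token_id is 2 and that tokenizer.encode reproduces the ids the chat template emitted for the turn content):

conversation_ids = [5, 6, 7, 2, 8, 9, 2, 10, 11, 2]
#   turn 0: [5, 6, 7]  eos  turn 1: [8, 9]  eos  turn 2: [10, 11]  eos
# For turn=1, the loop counts EOS tokens until eos_count == 1 (index 3),
# so start_search_idx becomes 4. With content_ids == [8, 9], the
# first-element fast path matches at i == 4, the full slice comparison
# succeeds, and find_turn returns (4, 6); labels[4:6] are then unmasked
# for training in tokenize_prompt.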

def get_conversation_thread(self, prompt):
return prompt[self.messages]
turns = [
{
"role": self.prompter.roles[t[self.prompter.message_field_role]],
"content": t[self.prompter.message_field_content],
"training": t.get(self.prompter.message_field_training),
"training_detail": t.get(self.prompter.message_field_training_detail),
}
for t in prompt[self.messages]
]

if self.prompter.drop_system_message and turns[0]["role"] == "system":
turns = turns[1:]

return turns
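To make the consolidation concrete, a sketch of what get_conversation_thread now produces, assuming hypothetical sharegpt-style field names (message_field_role="from", message_field_content="value") and the default role map:

prompt = {
    "messages": [
        {"from": "system", "value": "You are concise."},
        {"from": "human", "value": "Hi!"},
        {"from": "gpt", "value": "Hello!"},
    ]
}
# With roles={"human": "user", "gpt": "assistant", "system": "system"},
# get_conversation_thread would yield:
# [{"role": "system",    "content": "You are concise.", "training": None, "training_detail": None},
#  {"role": "user",      "content": "Hi!",              "training": None, "training_detail": None},
#  {"role": "assistant", "content": "Hello!",           "training": None, "training_detail": None}]
# With drop_system_message=True, the leading system turn is dropped before
# the list is handed to build_prompt.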

def get_images(self, prompt):
return prompt.get(self.images, None)