
Commit

fix: Adding tokens in special_tokens_dict
Signed-off-by: Abhishek <maurya.abhishek@ibm.com>
Abhishek-TAMU committed Sep 20, 2024
1 parent 146e9f1 commit 0022da3
Showing 1 changed file with 6 additions and 14 deletions.
20 changes: 6 additions & 14 deletions tuning/sft_trainer.py
@@ -218,23 +218,16 @@ def train(
     )
 
     # Add special tokens only when a custom tokenizer is not passed
+    special_tokens_dict = {}
     if not model_args.tokenizer_name_or_path:
         # TODO: understand if we need to hardcode these here or just use defaults in model
         if isinstance(tokenizer, (LlamaTokenizer, LlamaTokenizerFast)):
-            tokenizer.add_special_tokens(
-                {
-                    "bos_token": "<s>",
-                    "eos_token": "</s>",
-                    "unk_token": "<unk>",
-                    "pad_token": "<pad>",
-                }
-            )
+            special_tokens_dict["bos_token"] = "<s>"
+            special_tokens_dict["eos_token"] = "</s>"
+            special_tokens_dict["unk_token"] = "<unk>"
+            special_tokens_dict["pad_token"] = "<pad>"
         elif isinstance(tokenizer, (GPT2Tokenizer, GPTNeoXTokenizerFast)):
-            tokenizer.add_special_tokens(
-                {
-                    "pad_token": "<pad>",
-                }
-            )
+            special_tokens_dict["pad_token"] = "<pad>"
 
     max_seq_length = min(train_args.max_seq_length, tokenizer.model_max_length)
     logger.info("Max sequence length is %s", max_seq_length)
@@ -248,7 +241,6 @@ def train(
     )
 
     # add special tokens only when a custom tokenizer is not passed
-    special_tokens_dict = {}
     if not model_args.tokenizer_name_or_path:
         # TODO: we need to change this, perhaps follow what open instruct does?
         if tokenizer.pad_token is None:
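
The net effect of the commit: rather than mutating the tokenizer in place with a separate tokenizer.add_special_tokens() call per tokenizer class, the candidate tokens are collected into a single special_tokens_dict so they can be applied in one place later in train(). Below is a minimal sketch of that downstream pattern, assuming standard Hugging Face transformers tokenizer and model objects; the helper name apply_special_tokens is illustrative and not part of this diff.

def apply_special_tokens(tokenizer, model, special_tokens_dict):
    # tokenizer.add_special_tokens() accepts the whole dict at once and
    # returns the number of tokens that were actually new.
    num_added = tokenizer.add_special_tokens(special_tokens_dict)
    if num_added > 0:
        # New token ids index past the current embedding matrix, so the
        # model's embeddings must grow to match the tokenizer's vocabulary.
        model.resize_token_embeddings(len(tokenizer))
    return num_added

Collecting tokens first and applying them once means the embedding resize happens a single time, regardless of how many special tokens the branches above contribute.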
