From 98afb3e804210fdf8268e46256c623af34d2bf87 Mon Sep 17 00:00:00 2001
From: Michael Mi
Date: Tue, 7 Nov 2023 00:55:44 -0500
Subject: [PATCH] only prepend bos token for llama2.

---
 src/model_loader/model_loader.cpp              |  5 +++--
 src/tokenizer/sentencepiece_tokenizer.cpp      | 12 ++++++++----
 src/tokenizer/sentencepiece_tokenizer.h        |  4 +++-
 src/tokenizer/sentencepiece_tokenizer_test.cpp |  4 ++--
 4 files changed, 16 insertions(+), 9 deletions(-)

diff --git a/src/model_loader/model_loader.cpp b/src/model_loader/model_loader.cpp
index 72fc2ce0..61cf93de 100644
--- a/src/model_loader/model_loader.cpp
+++ b/src/model_loader/model_loader.cpp
@@ -78,7 +78,7 @@ std::unique_ptr<Tokenizer> PTModelLoader::tokenizer() const {
     GLOG(ERROR) << "Failed to find tokenizer file: " << tokenizer_path;
     return nullptr;
   }
-  return std::make_unique<SentencePieceTokenizer>(tokenizer_path);
+  return std::make_unique<SentencePieceTokenizer>(tokenizer_path, /*prepend_bos=*/true);
 }
 
 PTModelLoader::PTModelLoader(const std::string& model_weights_path)
@@ -150,12 +150,13 @@ std::unique_ptr<Tokenizer> HFModelLoader::tokenizer() const {
     return HFTokenizer::from_file(tokenizer_path);
   }
 
+  // fallback to tokenizer.model if tokenizer.json does not exist
   const std::string vocab_path = model_weights_path_ + "/tokenizer.model";
   if (std::filesystem::exists(vocab_path)) {
     GLOG(WARNING) << "Failed to find tokenizer.json, use tokenizer.model "
                      "instead. Please consider to convert the tokenizer.model "
                      "to tokenizer.json for better performance.";
-    return std::make_unique<SentencePieceTokenizer>(vocab_path);
+    return std::make_unique<SentencePieceTokenizer>(vocab_path, /*prepend_bos=*/false);
   }
 
   GLOG(ERROR)
diff --git a/src/tokenizer/sentencepiece_tokenizer.cpp b/src/tokenizer/sentencepiece_tokenizer.cpp
index b29e969a..b41c3c3e 100644
--- a/src/tokenizer/sentencepiece_tokenizer.cpp
+++ b/src/tokenizer/sentencepiece_tokenizer.cpp
@@ -6,8 +6,9 @@
 namespace llm {
 
 SentencePieceTokenizer::SentencePieceTokenizer(
-    const std::string& vocab_file_path)
-    : vocab_file_path_(vocab_file_path) {
+    const std::string& vocab_file_path,
+    bool prepend_bos)
+    : vocab_file_path_(vocab_file_path), prepend_bos_(prepend_bos) {
   const auto status = sp_processor_.Load(vocab_file_path);
   if (!status.ok()) {
     GLOG(FATAL) << "Failed to load SentencePiece model from " << vocab_file_path
@@ -16,7 +17,8 @@ SentencePieceTokenizer::SentencePieceTokenizer(
 }
 
 std::unique_ptr<Tokenizer> SentencePieceTokenizer::clone() const {
-  return std::make_unique<SentencePieceTokenizer>(this->vocab_file_path_);
+  return std::make_unique<SentencePieceTokenizer>(this->vocab_file_path_,
+                                                  this->prepend_bos_);
 }
 
 bool SentencePieceTokenizer::encode(const std::string_view& text,
@@ -28,7 +30,9 @@ bool SentencePieceTokenizer::encode(const std::string_view& text,
     return false;
   }
   // prepend bos token
-  ids->insert(ids->begin(), sp_processor_.bos_id());
+  if (prepend_bos_) {
+    ids->insert(ids->begin(), sp_processor_.bos_id());
+  }
   return true;
 }
 
diff --git a/src/tokenizer/sentencepiece_tokenizer.h b/src/tokenizer/sentencepiece_tokenizer.h
index 3f349b3c..2ff366f8 100644
--- a/src/tokenizer/sentencepiece_tokenizer.h
+++ b/src/tokenizer/sentencepiece_tokenizer.h
@@ -8,7 +8,7 @@ namespace llm {
 // a tokenizer that uses google/SentencePiece
 class SentencePieceTokenizer : public Tokenizer {
  public:
-  explicit SentencePieceTokenizer(const std::string& vocab_file_path);
+  explicit SentencePieceTokenizer(const std::string& vocab_file_path, bool prepend_bos);
 
   bool encode(const std::string_view& text,
               std::vector<int32_t>* ids) const override;
@@ -23,6 +23,8 @@ class SentencePieceTokenizer : public Tokenizer {
   std::string vocab_file_path_;
 
   sentencepiece::SentencePieceProcessor sp_processor_;
+
+  bool prepend_bos_ = false;
 };
 
 }  // namespace llm
diff --git a/src/tokenizer/sentencepiece_tokenizer_test.cpp b/src/tokenizer/sentencepiece_tokenizer_test.cpp
index e88d5c70..1d19019e 100644
--- a/src/tokenizer/sentencepiece_tokenizer_test.cpp
+++ b/src/tokenizer/sentencepiece_tokenizer_test.cpp
@@ -5,7 +5,7 @@
 namespace llm {
 
 TEST(SentencePieceTokenizerTest, EncodeDecodeTest) {
-  SentencePieceTokenizer tokenizer("data/tokenizer.model");
+  SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
   std::vector<int32_t> ids;
   ASSERT_TRUE(tokenizer.encode("Hello, world!", &ids));
   const std::vector<int32_t> desired_ids = {1, 15043, 29892, 3186, 29991};
@@ -16,7 +16,7 @@ TEST(SentencePieceTokenizerTest, EncodeDecodeTest) {
 }
 
 TEST(SentencePieceTokenizerTest, CJKTest) {
-  SentencePieceTokenizer tokenizer("data/tokenizer.model");
+  SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
   const std::string test_text = "你好,世界!";
   std::vector<int32_t> ids;
   ASSERT_TRUE(tokenizer.encode(test_text, &ids));
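
Below is a minimal usage sketch, not part of the patch, illustrating what the new prepend_bos flag changes for callers. The include path, the standalone main(), and the std::vector<int32_t> id type are assumptions made for illustration; the expected ids for the prepend_bos=true case are taken from EncodeDecodeTest above.

#include <cstdint>
#include <iostream>
#include <vector>

#include "tokenizer/sentencepiece_tokenizer.h"  // assumed include path

int main() {
  // llama2-style loading (PTModelLoader path): the tokenizer itself prepends BOS.
  llm::SentencePieceTokenizer with_bos("data/tokenizer.model",
                                       /*prepend_bos=*/true);
  // HF fallback path: special tokens are left for the caller to add.
  llm::SentencePieceTokenizer without_bos("data/tokenizer.model",
                                          /*prepend_bos=*/false);

  std::vector<int32_t> a;
  std::vector<int32_t> b;
  with_bos.encode("Hello, world!", &a);     // expected: {1, 15043, 29892, 3186, 29991}
  without_bos.encode("Hello, world!", &b);  // expected: same ids without the leading bos id 1
  std::cout << "with bos: " << a.size() << " ids, without bos: " << b.size() << " ids\n";
  return 0;
}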