only prepend bos token for llama2.
guocuimi committed Nov 7, 2023
1 parent 684d410 commit 98afb3e
Showing 4 changed files with 16 additions and 9 deletions.
5 changes: 3 additions & 2 deletions src/model_loader/model_loader.cpp
@@ -78,7 +78,7 @@ std::unique_ptr<Tokenizer> PTModelLoader::tokenizer() const {
     GLOG(ERROR) << "Failed to find tokenizer file: " << tokenizer_path;
     return nullptr;
   }
-  return std::make_unique<SentencePieceTokenizer>(tokenizer_path);
+  return std::make_unique<SentencePieceTokenizer>(tokenizer_path, /*prepend_bos=*/true);
 }
 
 PTModelLoader::PTModelLoader(const std::string& model_weights_path)
@@ -150,12 +150,13 @@ std::unique_ptr<Tokenizer> HFModelLoader::tokenizer() const {
     return HFTokenizer::from_file(tokenizer_path);
   }
 
+  // fallback to tokenizer.model if tokenizer.json does not exist
   const std::string vocab_path = model_weights_path_ + "/tokenizer.model";
   if (std::filesystem::exists(vocab_path)) {
     GLOG(WARNING) << "Failed to find tokenizer.json, use tokenizer.model "
                      "instead. Please consider to convert the tokenizer.model "
                      "to tokenizer.json for better performance.";
-    return std::make_unique<SentencePieceTokenizer>(vocab_path);
+    return std::make_unique<SentencePieceTokenizer>(vocab_path, /*prepend_bos=*/false);
   }
 
   GLOG(ERROR)
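For context, a minimal sketch (not part of the commit) of how the two loaders now hand out differently configured tokenizers. PTModelLoader's constructor and tokenizer() appear in the diff above; the llm:: namespace for the loaders, the HFModelLoader constructor, the include path, and the model directories are assumptions for illustration.

#include <memory>

#include "model_loader/model_loader.h"  // include path assumed

void load_tokenizers() {
  // Llama2-style .pt checkpoint: the SentencePiece tokenizer returned here
  // now prepends BOS on encode().
  llm::PTModelLoader pt_loader("/models/llama2-7b");  // path illustrative
  std::unique_ptr<llm::Tokenizer> pt_tokenizer = pt_loader.tokenizer();

  // HF checkpoint without tokenizer.json: the SentencePiece fallback returned
  // here no longer prepends BOS.
  llm::HFModelLoader hf_loader("/models/llama2-7b-hf");  // constructor assumed
  std::unique_ptr<llm::Tokenizer> hf_tokenizer = hf_loader.tokenizer();
}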
12 changes: 8 additions & 4 deletions src/tokenizer/sentencepiece_tokenizer.cpp
@@ -6,8 +6,9 @@
 namespace llm {
 
 SentencePieceTokenizer::SentencePieceTokenizer(
-    const std::string& vocab_file_path)
-    : vocab_file_path_(vocab_file_path) {
+    const std::string& vocab_file_path,
+    bool prepend_bos)
+    : vocab_file_path_(vocab_file_path), prepend_bos_(prepend_bos) {
   const auto status = sp_processor_.Load(vocab_file_path);
   if (!status.ok()) {
     GLOG(FATAL) << "Failed to load SentencePiece model from " << vocab_file_path
@@ -16,7 +17,8 @@ SentencePieceTokenizer::SentencePieceTokenizer(
 }
 
 std::unique_ptr<Tokenizer> SentencePieceTokenizer::clone() const {
-  return std::make_unique<SentencePieceTokenizer>(this->vocab_file_path_);
+  return std::make_unique<SentencePieceTokenizer>(this->vocab_file_path_,
+                                                  this->prepend_bos_);
 }
 
 bool SentencePieceTokenizer::encode(const std::string_view& text,
@@ -28,7 +30,9 @@ bool SentencePieceTokenizer::encode(const std::string_view& text,
     return false;
   }
   // prepend bos token
-  ids->insert(ids->begin(), sp_processor_.bos_id());
+  if (prepend_bos_) {
+    ids->insert(ids->begin(), sp_processor_.bos_id());
+  }
   return true;
 }
 
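As a usage illustration (a hedged sketch, not from the commit), the updated constructor and the guarded BOS insertion behave as follows; the vocab path and token ids come from the test file below, while the include path and main() scaffolding are assumptions.

#include <vector>

#include "tokenizer/sentencepiece_tokenizer.h"  // include path assumed

int main() {
  // prepend_bos=true: encode() inserts sp_processor_.bos_id() (1 for the
  // llama2 tokenizer) at the front, so "Hello, world!" becomes
  // {1, 15043, 29892, 3186, 29991} per the test below.
  llm::SentencePieceTokenizer with_bos("data/tokenizer.model", /*prepend_bos=*/true);
  std::vector<int32_t> ids;
  with_bos.encode("Hello, world!", &ids);

  // prepend_bos=false: the new guard skips the insertion, so the same text
  // should encode to {15043, 29892, 3186, 29991} (an assumption, not a test
  // from this commit).
  llm::SentencePieceTokenizer no_bos("data/tokenizer.model", /*prepend_bos=*/false);
  std::vector<int32_t> raw_ids;
  no_bos.encode("Hello, world!", &raw_ids);
  return 0;
}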
4 changes: 3 additions & 1 deletion src/tokenizer/sentencepiece_tokenizer.h
@@ -8,7 +8,7 @@ namespace llm {
 // a tokenizer that uses google/SentencePiece
 class SentencePieceTokenizer : public Tokenizer {
  public:
-  explicit SentencePieceTokenizer(const std::string& vocab_file_path);
+  explicit SentencePieceTokenizer(const std::string& vocab_file_path, bool prepend_bos);
 
   bool encode(const std::string_view& text,
               std::vector<int32_t>* ids) const override;
@@ -23,6 +23,8 @@ class SentencePieceTokenizer : public Tokenizer {
   std::string vocab_file_path_;
 
   sentencepiece::SentencePieceProcessor sp_processor_;
+
+  bool prepend_bos_ = false;
 };
 
 }  // namespace llm
4 changes: 2 additions & 2 deletions src/tokenizer/sentencepiece_tokenizer_test.cpp
@@ -5,7 +5,7 @@
 namespace llm {
 
 TEST(SentencePieceTokenizerTest, EncodeDecodeTest) {
-  SentencePieceTokenizer tokenizer("data/tokenizer.model");
+  SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
   std::vector<int> ids;
   ASSERT_TRUE(tokenizer.encode("Hello, world!", &ids));
   const std::vector<int> desired_ids = {1, 15043, 29892, 3186, 29991};
@@ -16,7 +16,7 @@
 }
 
 TEST(SentencePieceTokenizerTest, CJKTest) {
-  SentencePieceTokenizer tokenizer("data/tokenizer.model");
+  SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/true);
   const std::string test_text = "你好,世界!";
   std::vector<int> ids;
   ASSERT_TRUE(tokenizer.encode(test_text, &ids));
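A companion test for the prepend_bos=false path might look like the sketch below; it is not part of this commit, and the expected ids (the encoding above minus the leading BOS id 1) are an assumption.

TEST(SentencePieceTokenizerTest, EncodeWithoutBosTest) {
  // Hypothetical test: with prepend_bos=false the BOS id should not be
  // inserted at the front of the id list.
  SentencePieceTokenizer tokenizer("data/tokenizer.model", /*prepend_bos=*/false);
  std::vector<int> ids;
  ASSERT_TRUE(tokenizer.encode("Hello, world!", &ids));
  const std::vector<int> desired_ids = {15043, 29892, 3186, 29991};
  EXPECT_EQ(ids, desired_ids);
}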
