diff --git a/CMakeLists.txt b/CMakeLists.txt index 02da93b..db21541 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -150,6 +150,7 @@ set(mlxdata-src ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/ThreadController.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/ThreadPool.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Tokenizer.cpp + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/BPETokenizer.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Levenshtein.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/Utils.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/core/audio/Audio.cpp @@ -204,7 +205,8 @@ set(mlxdata-src ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Squeeze.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Tokenize.cpp ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/ImageTransform.cpp - ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp) + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/RemoveValue.cpp + ${CMAKE_CURRENT_LIST_DIR}/mlx/data/op/Replace.cpp) if(AWSSDK_FOUND) list(APPEND mlxdata-src diff --git a/mlx/data/Dataset.cpp b/mlx/data/Dataset.cpp index acd1bfb..ff3c7df 100644 --- a/mlx/data/Dataset.cpp +++ b/mlx/data/Dataset.cpp @@ -21,6 +21,7 @@ #include "mlx/data/op/ReadFromTAR.h" #include "mlx/data/op/RemoveValue.h" #include "mlx/data/op/RenameKey.h" +#include "mlx/data/op/Replace.h" #include "mlx/data/op/SampleTransform.h" #include "mlx/data/op/SaveImage.h" #include "mlx/data/op/Shape.h" @@ -633,6 +634,31 @@ T Dataset::remove_value_if( } } +template +T Dataset::replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) { + return transform_( + std::make_shared(key, old, replacement, count)); +} + +template +T Dataset::replace_if( + bool cond, + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) { + if (cond) { + return transform_( + std::make_shared(key, old, replacement, count)); + } else { + return T(self_); + } +} + template T Dataset::rename_key(const std::string& ikey, const std::string& okey) const { @@ -824,6 +850,31 @@ T Dataset::tokenize_if( } } +template +T Dataset::tokenize_bpe( + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey) const { + return transform_( + std::make_shared(ikey, symbols, merges, okey)); +} + +template +T Dataset::tokenize_bpe_if( + bool cond, + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey) const { + if (cond) { + return transform_( + std::make_shared(ikey, symbols, merges, okey)); + } else { + return T(self_); + } +} + // Implement Stream template <> Stream Dataset::transform_( diff --git a/mlx/data/Dataset.h b/mlx/data/Dataset.h index ce387ae..6f7ce1b 100644 --- a/mlx/data/Dataset.h +++ b/mlx/data/Dataset.h @@ -314,6 +314,18 @@ class Dataset { double value, double pad) const; + T replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count = -1); + T replace_if( + bool cond, + const std::string& key, + const std::string& old, + const std::string& replacement, + int count = -1); + T rename_key(const std::string& ikey, const std::string& okey) const; T rename_key_if(bool cond, const std::string& ikey, const std::string& okey) const; @@ -384,6 +396,17 @@ class Dataset { bool ignore_unk = false, const std::vector& trie_key_scores = {}, const std::string& okey = "") const; + T tokenize_bpe( + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey = "") const; + T tokenize_bpe_if( + bool cond, + const std::string& 
ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey = "") const; protected: std::shared_ptr self_; diff --git a/mlx/data/core/BPETokenizer.cpp b/mlx/data/core/BPETokenizer.cpp new file mode 100644 index 0000000..21a0089 --- /dev/null +++ b/mlx/data/core/BPETokenizer.cpp @@ -0,0 +1,178 @@ +// Copyright © 2024 Apple Inc. + +#include +#include + +#include "mlx/data/core/BPETokenizer.h" +#include "mlx/data/core/Trie.h" + +namespace mlx { +namespace data { +namespace core { + +void BPEMerges::add( + const std::string& left, + const std::string& right, + int64_t token) { + auto [left_s, left_inserted] = strings_.insert(left); + auto [right_s, right_inserted] = strings_.insert(right); + + std::string_view left_v(*left_s); + std::string_view right_v(*right_s); + + auto left_it = merges_.find(left_v); + if (left_it == merges_.end()) { + merges_[left_v][right_v] = token; + } else { + auto right_it = left_it->second.find(right_v); + if (right_it == left_it->second.end()) { + left_it->second[right_v] = token; + } else { + right_it->second = std::min(token, right_it->second); + } + } +} + +std::pair BPEMerges::can_merge( + std::string_view left, + std::string_view right) const { + auto left_it = merges_.find(left); + if (left_it == merges_.end()) { + return {false, 0}; + } + auto right_it = left_it->second.find(right); + if (right_it == left_it->second.end()) { + return {false, 0}; + } + return {true, right_it->second}; +} + +BPETokenizer::BPETokenizer( + std::shared_ptr> symbols, + std::shared_ptr merges) + : symbols_(symbols), merges_(merges) {} + +std::vector BPETokenizer::tokenize(std::string_view input) const { + struct Symbol { + std::string_view value; + int left; + int right; + int64_t token; + }; + + struct Pair { + std::vector::iterator left; + std::vector::iterator right; + int64_t token; + std::string_view value; + + Pair( + std::vector::iterator left, + std::vector::iterator right, + int64_t token) + : left(left), + right(right), + token(token), + value(left->value.data(), left->value.size() + right->value.size()) {} + + bool operator<(const Pair& right) const { + return token >= right.token; + }; + }; + + // Transform the input to a sequence of basic symbols that will subsequently + // be merged. + std::vector symbols; + symbols.reserve(input.size()); + for (auto it = input.begin(); it != input.end(); it++) { + auto [node, length] = symbols_->search_longest_prefix(it, input.end()); + if (length == 0) { + std::ostringstream msg; + msg << "BPETokenizer: Unknown symbol '" << *it << "'"; + throw std::runtime_error(msg.str()); + } + symbols.push_back(Symbol{ + std::string_view(&*it, length), + static_cast(symbols.size() - 1), + static_cast(symbols.size() + 1), + node->id}); + it += length - 1; + } + + std::priority_queue merge_queue; + + // Initialize the merge queue + auto left = symbols.begin(); + auto right = std::next(left); + while (right != symbols.end()) { + auto [can_merge, token] = merges_->can_merge(left->value, right->value); + if (can_merge) { + merge_queue.emplace(left, right, token); + } + left++; + right++; + } + + while (!merge_queue.empty()) { + Pair top = std::move(merge_queue.top()); + merge_queue.pop(); + + // Skip invalidated pairs + if (top.left->token < 0 || top.right->token < 0) { + continue; + } + if (top.value.size() != top.left->value.size() + top.right->value.size()) { + continue; + } + if (top.value.data() != top.left->value.data()) { + continue; + } + + // Yay! Valid pair, let's merge into the left one. 
+ top.left->token = top.token; + top.left->value = top.value; + + // Invalidate our neighbor which we just merged into ourselves. + top.right->token = -1; + + // Adjust the pointers to neighboring symbols + top.left->right = top.right->right; + if (top.right->right < symbols.size()) { + symbols[top.right->right].left = top.right->left; + } + + // Check for a possible merge to the left. + if (top.left != symbols.begin()) { + auto neighbor = symbols.begin() + top.left->left; + auto [can_merge, token] = + merges_->can_merge(neighbor->value, top.left->value); + if (can_merge) { + merge_queue.emplace(neighbor, top.left, token); + } + } + + // Do the same to our right. + if (top.left->right < symbols.size()) { + auto neighbor = symbols.begin() + top.left->right; + auto [can_merge, token] = + merges_->can_merge(top.left->value, neighbor->value); + if (can_merge) { + merge_queue.emplace(top.left, neighbor, token); + } + } + } + + // Gather the final result in a vector + std::vector tokens; + for (auto& symbol : symbols) { + if (symbol.token >= 0) { + tokens.push_back(symbol.token); + } + } + + return tokens; +} + +} // namespace core +} // namespace data +} // namespace mlx diff --git a/mlx/data/core/BPETokenizer.h b/mlx/data/core/BPETokenizer.h new file mode 100644 index 0000000..343a3fc --- /dev/null +++ b/mlx/data/core/BPETokenizer.h @@ -0,0 +1,53 @@ +// Copyright © 2024 Apple Inc. + +#pragma once + +#include +#include + +#include "mlx/data/core/Trie.h" + +namespace mlx { +namespace data { +namespace core { + +class BPEMerges { + public: + void add(const std::string& left, const std::string& right, int64_t token); + std::pair can_merge( + std::string_view left, + std::string_view right) const; + + template + std::pair + can_merge(iterator_type left, iterator_type middle, iterator_type end) const { + // switch to std::string_view(left, middle) when in C++20 + return can_merge( + std::string_view(&(*left), std::distance(left, middle)), + std::string_view(&(*middle), std::distance(middle, end))); + } + + private: + std::unordered_set strings_; + std::unordered_map< + std::string_view, + std::unordered_map> + merges_; +}; + +class BPETokenizer { + public: + BPETokenizer( + std::shared_ptr> symbols, + std::shared_ptr merges); + + std::vector tokenize(std::string_view input) const; + + private: + std::shared_ptr> symbols_; + std::shared_ptr merges_; +}; + +} // namespace core +} // namespace data +} // namespace mlx diff --git a/mlx/data/core/Trie.h b/mlx/data/core/Trie.h index 21e2f0f..5924631 100644 --- a/mlx/data/core/Trie.h +++ b/mlx/data/core/Trie.h @@ -4,6 +4,7 @@ #include #include +#include #include #include #include @@ -32,55 +33,96 @@ class Trie { nodes_.resize(1); nodes_.back().id = -1; // uid is 0 }; - const TrieNode* insert(const std::vector& key) { - TrieNode* node; - int64_t i; - std::tie(node, i) = partial_search_(key); - for (; i < key.size(); i++) { + + template + std::tuple*, int64_t> search_longest_prefix( + iterator_type it, + iterator_type end) const { + auto node = root_(); + int64_t i = 0; + auto valid_node = node; + int64_t valid_i = i; + while (it != end) { + auto kv = node->children.find(*it); + if (kv == node->children.end()) { + break; + } else { + node = kv->second; + i++; + it++; + if (node->accepts()) { + valid_node = node; + valid_i = i; + } + } + } + return std::make_tuple(valid_node, valid_i); + } + + template + const TrieNode* + insert(iterator_type begin, iterator_type end, int64_t id = -1) { + id = (id < 0) ? 
keys_.size() : id; + auto it = begin; + auto [node, i] = partial_search(it, end); + std::advance(it, i); // it += i but also supports sequential iterators + while (it != end) { nodes_.resize(nodes_.size() + 1); TrieNode* new_node = &nodes_.back(); new_node->uid = nodes_.size() - 1; new_node->id = -1; - node->children[key[i]] = new_node; + node->children[*it] = new_node; node = new_node; + it++; } if (!node->accepts()) { - node->id = keys_.size(); - keys_.push_back(key); + node->id = id; + keys_.emplace(id, std::vector(begin, end)); } return node; - }; - const TrieNode* search(const std::vector& key) { - auto res = partial_search_(key); - if (std::get<1>(res) != key.size()) { + } + + template + const TrieNode* search(iterator_type it, iterator_type end) { + auto [node, i] = partial_search(it, end); + if (i != std::distance(it, end) || !node->accepts()) { return nullptr; - } else { - auto node = std::get<0>(res); - return (node->accepts() ? node : nullptr); } - }; + return node; + } + + const TrieNode* insert(const std::vector& key, int64_t id = -1) { + return insert(key.begin(), key.end(), id); + } + + const TrieNode* search(const std::vector& key) { + return search(key.begin(), key.end()); + } + const TrieNode* root() const { return &nodes_.front(); } + int64_t num_keys() const { return keys_.size(); } + const std::vector& key(int64_t id) const { return keys_.at(id); } - // helper for strings + // helpers for strings template < typename U = T, std::enable_if_t::value, char> = false> - const TrieNode* insert(const std::string& key) { - return insert(std::vector(key.begin(), key.end())); + const TrieNode* insert(const std::string& key, int64_t id = -1) { + return insert(key.begin(), key.end(), id); }; template < typename U = T, std::enable_if_t::value, char> = false> const TrieNode* search(const std::string& key) { - return search(std::vector(key.begin(), key.end())); + return search(key.begin(), key.end()); }; template < typename U = T, @@ -94,21 +136,32 @@ class Trie { TrieNode* root_() { return &nodes_.front(); } - std::tuple*, int64_t> partial_search_(const std::vector& key) { + + const TrieNode* root_() const { + return &nodes_.front(); + } + + template + std::tuple*, int64_t> partial_search( + iterator_type it, + iterator_type end) { auto node = root_(); int64_t i = 0; - for (; i < key.size(); i++) { - auto kv = node->children.find(key[i]); + while (it != end) { + auto kv = node->children.find(*it); if (kv == node->children.end()) { break; } else { node = kv->second; + i++; + it++; } } return std::make_tuple(node, i); } + std::deque> nodes_; - std::vector> keys_; + std::unordered_map> keys_; }; } // namespace core diff --git a/mlx/data/core/Utils.cpp b/mlx/data/core/Utils.cpp index daadb26..30f7b8b 100644 --- a/mlx/data/core/Utils.cpp +++ b/mlx/data/core/Utils.cpp @@ -1,7 +1,6 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include "mlx/data/core/Utils.h" -#include namespace { @@ -57,6 +56,7 @@ void uniq_t( } } } + template void remove_t( std::shared_ptr dst, @@ -102,6 +102,65 @@ void remove_t( } } } + +template +void replace_t( + std::shared_ptr& result, + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count) { + int64_t src_size = src->size(); + int64_t old_size = old->size(); + int64_t replacement_size = replacement->size(); + + T* src_buffer = src->data(); + T* old_buffer = old->data(); + T* replacement_buffer = replacement->data(); + + // Calculate the result size. 
If this ends up being slow we can try + // a single pass algorithm that grows the buffer using realloc. We can also + // try a better search algorithm because this has a worst case complexity + // O(src_size old_size). + int64_t result_size = src_size; + int matches = 0; + if (old_size != replacement_size) { + for (int64_t i = 0; i < src_size; i++) { + if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) { + i += old_size - 1; + result_size += replacement_size - old_size; + matches++; + } + if (matches == count) { + break; + } + } + } + + result = std::make_shared(src->type(), result_size); + T* result_buffer = result->data(); + + matches = 0; + for (int64_t i = 0, j = 0; i < src_size; i++, j++) { + if (std::equal(old_buffer, old_buffer + old_size, src_buffer + i)) { + std::copy( + replacement_buffer, + replacement_buffer + replacement_size, + result_buffer + j); + i += old_size - 1; + j += replacement_size - 1; + matches++; + } else { + result_buffer[j] = src_buffer[i]; + } + if (matches == count) { + std::copy( + src_buffer + i + 1, src_buffer + src_size, result_buffer + j + 1); + break; + } + } +} + } // namespace namespace mlx { namespace data { @@ -192,6 +251,16 @@ Sample merge_batch( return sample_batch; } +std::shared_ptr replace( + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count) { + std::shared_ptr result; + ARRAY_DISPATCH(src, replace_t, result, src, old, replacement, count); + return result; +} + } // namespace core } // namespace data } // namespace mlx diff --git a/mlx/data/core/Utils.h b/mlx/data/core/Utils.h index 8631bad..e3d3dca 100644 --- a/mlx/data/core/Utils.h +++ b/mlx/data/core/Utils.h @@ -1,4 +1,4 @@ -// Copyright © 2023 Apple Inc. +// Copyright © 2023-2024 Apple Inc. #include "mlx/data/Array.h" #include "mlx/data/Sample.h" @@ -20,6 +20,12 @@ std::pair, std::shared_ptr> remove( double value, double pad); +std::shared_ptr replace( + const std::shared_ptr src, + const std::shared_ptr old, + const std::shared_ptr replacement, + int count); + Sample merge_batch( const std::vector& samples, const std::unordered_map& pad_values = {}, diff --git a/mlx/data/op/Replace.cpp b/mlx/data/op/Replace.cpp new file mode 100644 index 0000000..683f113 --- /dev/null +++ b/mlx/data/op/Replace.cpp @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. + +#include "mlx/data/op/Replace.h" +#include "mlx/data/core/Utils.h" + +namespace mlx { +namespace data { +namespace op { + +Replace::Replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count) + : key_(key), + old_(std::make_shared(old)), + replacement_(std::make_shared(replacement)), + count_(count) {} + +Sample Replace::apply(const Sample& sample) const { + auto value = sample::check_key(sample, key_, old_->type()); + value = core::replace(value, old_, replacement_, count_); + auto new_sample = sample; + new_sample[key_] = value; + return new_sample; +} + +} // namespace op +} // namespace data +} // namespace mlx diff --git a/mlx/data/op/Replace.h b/mlx/data/op/Replace.h new file mode 100644 index 0000000..9d181d1 --- /dev/null +++ b/mlx/data/op/Replace.h @@ -0,0 +1,30 @@ +// Copyright © 2024 Apple Inc. 
+ +#pragma once + +#include "mlx/data/op/Op.h" + +namespace mlx { +namespace data { +namespace op { + +class Replace : public Op { + public: + Replace( + const std::string& key, + const std::string& old, + const std::string& replacement, + int count); + + virtual Sample apply(const Sample& sample) const override; + + private: + std::string key_; + std::shared_ptr old_; + std::shared_ptr replacement_; + int count_; +}; + +} // namespace op +} // namespace data +} // namespace mlx diff --git a/mlx/data/op/Tokenize.cpp b/mlx/data/op/Tokenize.cpp index 5255aae..26d9ccc 100644 --- a/mlx/data/op/Tokenize.cpp +++ b/mlx/data/op/Tokenize.cpp @@ -5,6 +5,7 @@ namespace mlx { namespace data { namespace op { + Tokenize::Tokenize( const std::string& ikey, std::shared_ptr> trie, @@ -15,6 +16,7 @@ Tokenize::Tokenize( : KeyTransformOp(ikey, okey), tokenizer_(trie, ignore_unk, trie_key_scores), mode_(mode) {} + std::shared_ptr Tokenize::apply_key( const std::shared_ptr& src) const { std::string str( @@ -34,6 +36,21 @@ std::shared_ptr Tokenize::apply_key( return std::make_shared(tokens); } + +BPETokenize::BPETokenize( + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey) + : KeyTransformOp(ikey, okey), tokenizer_(symbols, merges) {} + +std::shared_ptr BPETokenize::apply_key( + const std::shared_ptr& src) const { + auto tokens = tokenizer_.tokenize(std::string_view( + reinterpret_cast(src->data()), src->size() * src->itemsize())); + return std::make_shared(tokens); +} + } // namespace op } // namespace data } // namespace mlx diff --git a/mlx/data/op/Tokenize.h b/mlx/data/op/Tokenize.h index 8b14a4c..5c5346b 100644 --- a/mlx/data/op/Tokenize.h +++ b/mlx/data/op/Tokenize.h @@ -2,6 +2,7 @@ #pragma once +#include "mlx/data/core/BPETokenizer.h" #include "mlx/data/core/Tokenizer.h" #include "mlx/data/core/Trie.h" #include "mlx/data/op/KeyTransform.h" @@ -30,6 +31,21 @@ class Tokenize : public KeyTransformOp { TokenizeMode mode_; }; +class BPETokenize : public KeyTransformOp { + public: + BPETokenize( + const std::string& ikey, + std::shared_ptr> symbols, + std::shared_ptr merges, + const std::string& okey = ""); + + virtual std::shared_ptr apply_key( + const std::shared_ptr& src) const override; + + private: + core::BPETokenizer tokenizer_; +}; + } // namespace op } // namespace data } // namespace mlx diff --git a/python/mlx/data/tokenizer_helpers.py b/python/mlx/data/tokenizer_helpers.py index 626ebc8..13ef89b 100644 --- a/python/mlx/data/tokenizer_helpers.py +++ b/python/mlx/data/tokenizer_helpers.py @@ -1,5 +1,6 @@ -# Copyright © 2023 Apple Inc. +# Copyright © 2024 Apple Inc. 
+import math import re from pathlib import Path @@ -8,7 +9,31 @@ except ImportError: SentencePieceProcessor = None -from .core import CharTrie +from .core import BPEMerges, CharTrie + + +def _iterate_spm_tokens(spm_file): + if spm_file.endswith(".model"): + if SentencePieceProcessor is None: + raise RuntimeError( + "sentencepiece must be installed to read directly from a binary model" + ) + + spm_tok = SentencePieceProcessor(spm_file) + for i in range(spm_tok.vocab_size()): + yield spm_tok.id_to_piece(i).encode("utf-8"), spm_tok.get_score(i) + + elif spm_file.endswith(".vocab"): + f = open(spm_file, "rb") + for line in f: + line = line.rstrip() + token, score = line.split(b"\t") + yield token, float(score) + + else: + raise ValueError( + f"Sentencepiece file extenstion must be in [.vocab, .model] but it was {spm_file}" + ) def read_trie_from_spm(spm_file): @@ -18,6 +43,17 @@ def read_trie_from_spm(spm_file): however if the vocabulary and the scores are exported the file can be read without installing sentencepiece. + .. note:: + + Sentencepiece models are almost always BPE models with scores being the + associated log likelihood of from a unigram language model. Using the + :class:`mlx.data.core.CharTrie` and the loaded scores will provide the + shortest possible tokenization with the highest possible log likelihood + but it can be slightly different than the BPE one. + + Use :func:`read_bpe_from_spm` to load the model to be used with a + :class:`mlx.data.core.BPETokenizer`. + Args: spm_file (str): Either a sentencepiece model file or a vocab file extracted from a sentencepiece model. @@ -27,90 +63,56 @@ def read_trie_from_spm(spm_file): corresponding weights from the SPM mdoel. """ - def iterate_tokens(spm_file): - if spm_file.endswith(".model"): - if SentencePieceProcessor is None: - raise RuntimeError( - "sentencepiece must be installed to read directly from a binary model" - ) - - spm_tok = SentencePieceProcessor(spm_file) - for i in range(spm_tok.vocab_size()): - yield spm_tok.id_to_piece(i).encode("utf-8"), spm_tok.get_score(i) - - elif spm_file.endswith(".vocab"): - f = open(spm_file, "rb") - for line in f: - line = line.rstrip() - token, score = line.split(b"\t") - yield token, float(score) - - else: - raise ValueError( - f"Sentencepiece file extenstion must be in [.vocab, .model] but it was {spm_file}" - ) - def to_special_token(token): return b"<0x" + token.hex().encode() + b">" - sep = "\u2581".encode("utf-8") - # We parse the model in two passes. First we save the tokens in tmp_tokens - # and tmp_scores and go back and replace special tokens that already exist - # to a special token representation. This happens so we can keep the same - # ids as the original sentencepiece model. + # and go back and replace special tokens that already exist or tokens that + # have a better score to a special token representation. This happens so we + # can keep the same ids as the original sentencepiece model. tokenmap = {} tmp_tokens = [] - tmp_scores = [] - max_scores = set() - for token, score in iterate_tokens(spm_file): - score = -score - + trie_key_scores = [] + for token, score in _iterate_spm_tokens(spm_file): if re.match(b"^<.*>$", token): - # Make sure to set the max score for all special tokens - max_scores.add(len(tmp_scores)) - hex_byte = re.match(b"^<0x(..)>$", token) if hex_byte: (token,) = hex_byte.groups() token = bytes.fromhex(token.decode()) - token = token.replace(sep, b" ") - # Token already exists so we should choose either the previous one or # this one. 
if token in tokenmap: existing_token_id = tokenmap[token] - existing_token_score = tmp_scores[existing_token_id] + existing_token_score = trie_key_scores[existing_token_id] # We should replace that token with our token if score < existing_token_score: tmp_tokens[existing_token_id] = to_special_token(token) - max_scores.add(existing_token_id) tmp_tokens.append(token) - tmp_scores.append(score) + trie_key_scores.append(score) tokenmap[token] = len(tmp_tokens) - 1 # We should ignore this token else: tmp_tokens.append(to_special_token(token)) - tmp_scores.append(score) - max_scores.add(len(tmp_tokens) - 1) + trie_key_scores.append(score) # Token doesn't exist so add it else: tmp_tokens.append(token) - tmp_scores.append(score) + trie_key_scores.append(score) tokenmap[token] = len(tmp_tokens) - 1 - # Set the max score to duplicates - max_score = max(tmp_scores) + 1 - for token_id in max_scores: - tmp_scores[token_id] = max_score + # SPM is a BPE tokenizer so it doesn't exactly work like the MLX tokenizer. + # Favoring the shortest sequence and taking into account the scores at the + # same time yields the closest tokenization. + min_score = min(trie_key_scores) + for i in range(len(trie_key_scores)): + trie_key_scores[i] = -min_score - trie_key_scores[i] - # Build the trie and the scores + # Build the trie trie = CharTrie() - trie_key_scores = tmp_scores for token in tmp_tokens: if trie.search(token): raise RuntimeError(f"Token {token} found twice") @@ -119,6 +121,76 @@ def to_special_token(token): return trie, trie_key_scores +def read_bpe_from_spm(spm_file): + """Read a sentencepiece file and decompose it to a symbol trie and BPE + merges for use with :class:`mlx.data.core.BPETokenizer`. + + Because it isn't straightforward to extract the merges from the SPM file, + we create a trie of basic symbols by considering all single unicode + character tokens as basic symbols as well as any special tokens provided. + + To extract the merges we run the BPE algorithm on the tokens in order of + probability as suggested in https://github.com/openai/tiktoken/issues/60 + for exporting an SPM model to huggingface tokenizers. + + Args: + spm_file (str): Either a sentencepiece model file or a vocab file + extracted from a sentencepiece model. + + Returns: + tuple[:class:`mlx.data.core.CharTrie`, :class:`mlx.data.core.BPEMerges`]: The + trie and the corresponding BPE merges from the SPM mdoel. 
+ """ + symbols = [] + merged = [] + tokenmap = {} + for token_id, (token, score) in enumerate(_iterate_spm_tokens(spm_file)): + if re.match(b"^<.*>$", token): + hex_byte = re.match(b"^<0x(..)>$", token) + if hex_byte: + (token,) = hex_byte.groups() + token = bytes.fromhex(token.decode()) + + if len(token) == 1 or score == 0 or len(token.decode(errors="ignore")) == 1: + symbols.append(token) + else: + merged.append(token) + + tokenmap[token] = token_id + + trie = CharTrie() + for s in symbols: + trie.insert(s, tokenmap[s]) + + merges = BPEMerges() + + def bpe(tokenmap, token, max_rank): + parts = list(token) + while True: + min_idx = None + min_rank = None + for i, pair in enumerate(zip(parts[:-1], parts[1:])): + rank = tokenmap.get((pair[0] + pair[1]).encode()) + if rank is not None and (min_rank is None or rank < min_rank): + min_idx = i + min_rank = rank + if min_rank is None or (max_rank is not None and min_rank >= max_rank): + break + assert min_idx is not None + parts = ( + parts[:min_idx] + + [parts[min_idx] + parts[min_idx + 1]] + + parts[min_idx + 2 :] + ) + return parts + + for t in merged: + left, right = bpe(tokenmap, t.decode(), tokenmap[t]) + merges.add(left.encode(), right.encode(), tokenmap[t]) + + return trie, merges + + def read_trie_from_vocab(vocab_file): """Read an :class:`mlx.data.core.CharTrie` from a file with one token per line. diff --git a/python/src/wrap_core.cpp b/python/src/wrap_core.cpp index 5207f36..3ca5574 100644 --- a/python/src/wrap_core.cpp +++ b/python/src/wrap_core.cpp @@ -8,6 +8,7 @@ #include "mlx/data/core/AWSFileFetcher.h" #endif +#include "mlx/data/core/BPETokenizer.h" #include "mlx/data/core/FileFetcher.h" #include "mlx/data/core/Graph.h" #include "mlx/data/core/Levenshtein.h" @@ -133,21 +134,26 @@ void init_mlx_data_core(py::module& m) { .def( "insert", [](std::shared_ptr> trie, - std::variant> token) { + std::variant> token, + int64_t id) { if (std::holds_alternative(token)) { - return trie->insert(std::get(token)); + return trie->insert(std::get(token), id); } else { - return trie->insert(std::get>(token)); + return trie->insert(std::get>(token), id); } }, py::return_value_policy::reference_internal, py::arg("token"), + py::arg("id") = -1, R"pbcopy( Insert a token in the trie making a new token if it doesn't already exist. Args: token (str or list[char]): The new token to be inserted given either as a string or a list of characters. + id (int, optional): The id to assign to the new token to be + inserted. If negative then use ``num_keys()`` as default. + Default: ``-1``. )pbcopy") .def( "search", @@ -275,6 +281,86 @@ void init_mlx_data_core(py::module& m) { input (str): The input string to be tokenized. )pbcopy"); + py::class_>( + m, + "BPEMerges", + R"pbcopy( + A datastructure that holds all possible merges and allows querying + whether two strings can be merged in O(1) time. + )pbcopy") + .def(py::init<>()) + .def( + "add", + &BPEMerges::add, + py::arg("left"), + py::arg("right"), + py::arg("token"), + R"pbcopy( + Add two strings as a possible merge that results in ``token``. + + Args: + left (str): The left side to be merged. + right (str): The right side to be merged. + token (int): The resulting token. 
+ )pbcopy") + .def( + "can_merge", + [](std::shared_ptr& merges, + const std::string& left, + const std::string& right) -> std::optional { + auto [can_merge, token] = merges->can_merge( + std::string_view(left.data(), left.size()), + std::string_view(right.data(), right.size())); + + if (!can_merge) { + return {}; + } + + return token; + }, + py::arg("left"), + py::arg("right"), + R"pbcopy( + Check if ``left`` and ``right`` can be merged to one token. + + Args: + left (str): The left side of the possible token. + right (str): The right side of the possible token. + + Returns: + The token id is returned or None if ``left`` and ``right`` + couldn't be merged. + )pbcopy"); + + py::class_>( + m, + "BPETokenizer", + R"pbcopy( + A tokenizer that uses the BPE algorithm to tokenize strings. + + Args: + symbol_trie (mlx.data.core.CharTrie): The trie containing the basic + symbols that all merges start from. + merges (mlx.data.core.BPEMerges): The datastructure holding the bpe + merges. + )pbcopy") + .def( + py::init< + std::shared_ptr>, + std::shared_ptr>(), + py::arg("symbols"), + py::arg("merges")) + .def( + "tokenize", + &BPETokenizer::tokenize, + py::arg("input"), + R"pbcopy( + Tokenize the input according to the symbols and merges. + + Args: + input (str): The input string to be tokenized. + )pbcopy"); + py::class_>( m, "FileFetcherHandle"); diff --git a/python/src/wrap_dataset.h b/python/src/wrap_dataset.h index 6b09bb8..7d38e74 100644 --- a/python/src/wrap_dataset.h +++ b/python/src/wrap_dataset.h @@ -945,6 +945,42 @@ void mlx_data_export_dataset(py::class_& base) { py::arg("pad") = 0, "Conditional :meth:`Buffer.remove_value`."); + base.def( + "replace", + &T::replace, + py::call_guard(), + py::arg("key"), + py::arg("old"), + py::arg("replacement"), + py::arg("count") = -1, + R"pbdoc( + Replace ``old`` with ``replacement`` in the array at ``key``. + + Example: + + .. code-block:: python + + # Replace ' ' with '▁' to prepare for SPM tokenization. + dset = dset.replace("text", " ", "\u2581") + + Args: + key (str): The sample key that contains the array we are operating on. + old (str): The character sequence that we are replacing. + replacement (str): The character sequence that we are replacing with. + count (int): Perform at most ``count`` replacements. Ignore if negative. + Default: ``-1``. + )pbdoc"); + base.def( + "replace_if", + &T::replace_if, + py::call_guard(), + py::arg("cond"), + py::arg("key"), + py::arg("old"), + py::arg("replacement"), + py::arg("count") = -1, + "Conditional :meth:`Buffer.replace`."); + base.def( "rename_key", &T::rename_key, @@ -1195,5 +1231,40 @@ void mlx_data_export_dataset(py::class_& base) { py::arg("trie_key_scores") = std::vector({}), py::arg("output_key") = "", "Conditional :meth:`Buffer.tokenize`."); + + base.def( + "tokenize_bpe", + &T::tokenize_bpe, + py::call_guard(), + py::arg("key"), + py::arg("symbols"), + py::arg("merges"), + py::arg("output_key") = "", + R"pbcopy( + Tokenize the the contents of the array at ``key`` using the BPE merging + algorithm. + + For instance this can be used to match the tokenization of the + Sentencepiece tokenizers. + + Args: + key (str): The sample key that contains the array we are operating on. + symbols (mlx.data.core.CharTrie): A trie containing the basic symbols + to use for the tokenization. + merges (mlx.data.core.BPEMerges): A datastructure containing the + merges of the basic symbols in order of priority. + output_key (str): If it is not empty then write the result to this + key instead of overwriting ``key``. 
(default: '') + )pbcopy"); + base.def( + "tokenize_bpe_if", + &T::tokenize_bpe_if, + py::call_guard(), + py::arg("cond"), + py::arg("key"), + py::arg("symbols"), + py::arg("merges"), + py::arg("output_key") = "", + "Conditional :meth:`Buffer.tokenize_bpe`."); } } // namespace diff --git a/python/tests/test_bpe.py b/python/tests/test_bpe.py new file mode 100644 index 0000000..f76a222 --- /dev/null +++ b/python/tests/test_bpe.py @@ -0,0 +1,31 @@ +# Copyright © 2024 Apple Inc. + +import string +import unittest + +from mlx.data.core import BPEMerges, BPETokenizer, CharTrie + + +class TestBpe(unittest.TestCase): + def test_bpe(self): + symbols = CharTrie() + symbols.insert(" ") + for s in string.ascii_letters: + symbols.insert(s) + n = symbols.num_keys() + merges = BPEMerges() + + tokenizer = BPETokenizer(symbols, merges) + + self.assertEqual(tokenizer.tokenize("abcd"), [1, 2, 3, 4]) + + merges.add("a", "b", n + 1) + self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, 3, 4]) + + merges.add("c", "d", n + 2) + merges.add("b", "cd", n + 3) + self.assertEqual(tokenizer.tokenize("abcd"), [n + 1, n + 2]) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_buffer.py b/python/tests/test_buffer.py index df8fcda..1f5da17 100644 --- a/python/tests/test_buffer.py +++ b/python/tests/test_buffer.py @@ -1,11 +1,11 @@ # Copyright © 2024 Apple Inc. -from unittest import TestCase +import unittest import mlx.data as dx -class TestBuffer(TestCase): +class TestBuffer(unittest.TestCase): def test__getitem__(self): n = 5 b = dx.buffer_from_vector(list(dict(i=i) for i in range(n))) @@ -18,3 +18,7 @@ def test__getitem__(self): _ = b[n] with self.assertRaises(IndexError): _ = b[-(n + 1)] + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/test_replace.py b/python/tests/test_replace.py new file mode 100644 index 0000000..eb49f3b --- /dev/null +++ b/python/tests/test_replace.py @@ -0,0 +1,24 @@ +# Copyright © 2024 Apple Inc. + +import unittest + +import mlx.data as dx + + +class TestReplace(unittest.TestCase): + def test_replace(self): + s = "Hello world".encode() + dset = dx.buffer_from_vector([dict(text=s)]) + + ds = dset.replace("text", "world", "everybody!") + self.assertEqual(bytes(ds[0]["text"]), b"Hello everybody!") + + ds = dset.replace("text", "l", "b") + self.assertEqual(bytes(ds[0]["text"]), b"Hebbo worbd") + + ds = dset.replace("text", "l", "b", 2) + self.assertEqual(bytes(ds[0]["text"]), b"Hebbo world") + + +if __name__ == "__main__": + unittest.main()
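For reference, a minimal end-to-end sketch of how the pieces added above fit together: ``read_bpe_from_spm`` builds the symbol trie and merges from a sentencepiece file, ``Buffer.replace`` applies the whitespace substitution suggested in its docstring, and ``Buffer.tokenize_bpe`` runs the BPE merging algorithm. The ``tokenizer.model`` path is hypothetical, and reading a binary ``.model`` file requires the sentencepiece package.

.. code-block:: python

    import mlx.data as dx
    from mlx.data.tokenizer_helpers import read_bpe_from_spm

    # Hypothetical sentencepiece model file.
    symbols, merges = read_bpe_from_spm("tokenizer.model")

    dset = (
        dx.buffer_from_vector([dict(text=b"Hello world")])
        # Replace ' ' with '\u2581' to mimic sentencepiece pre-processing,
        # as shown in the Buffer.replace docstring.
        .replace("text", " ", "\u2581")
        # Tokenize using the symbol trie and the BPE merges.
        .tokenize_bpe("text", symbols, merges)
    )

    print(dset[0]["text"])  # int64 token ids for the sample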
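And a small sketch of the lower-level bindings, useful when the symbols and merges come from somewhere other than sentencepiece: ``CharTrie.insert`` now accepts an explicit id, and ``BPEMerges.can_merge`` returns the merged token id or ``None``. The toy vocabulary below is made up for illustration.

.. code-block:: python

    from mlx.data.core import BPEMerges, BPETokenizer, CharTrie

    symbols = CharTrie()
    for i, piece in enumerate([" ", "a", "b", "c"]):
        symbols.insert(piece, i)  # explicit ids, e.g. to match an external vocab

    merges = BPEMerges()
    merges.add("a", "b", 4)   # "a" + "b" -> token 4
    merges.add("ab", "c", 5)  # "ab" + "c" -> token 5

    print(merges.can_merge("a", "b"))  # 4
    print(merges.can_merge("b", "c"))  # None

    tokenizer = BPETokenizer(symbols, merges)
    print(tokenizer.tokenize("abc ab"))  # [5, 0, 4]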