diff --git a/nemo/collections/common/tokenizers/__init__.py b/nemo/collections/common/tokenizers/__init__.py
index 6a71920bf6d4..4ba946cf9f76 100644
--- a/nemo/collections/common/tokenizers/__init__.py
+++ b/nemo/collections/common/tokenizers/__init__.py
@@ -19,6 +19,7 @@
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer
 from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer
+from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
 from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
 from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
diff --git a/nemo/collections/common/tokenizers/tiktoken_tokenizer.py b/nemo/collections/common/tokenizers/tiktoken_tokenizer.py
new file mode 100644
index 000000000000..4b1847051cdc
--- /dev/null
+++ b/nemo/collections/common/tokenizers/tiktoken_tokenizer.py
@@ -0,0 +1,200 @@
+# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import json
+import os
+from pathlib import Path
+from typing import Dict, List, Optional
+
+try:
+    import tiktoken
+except ImportError:
+    pass
+
+from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec
+
+__all__ = ['TiktokenTokenizer']
+
+
+def reload_mergeable_ranks(
+    path: str,
+    max_vocab: Optional[int] = None,
+) -> Dict[bytes, int]:
+    """
+    Reload the tokenizer JSON file and convert it to Tiktoken format.
+    """
+    assert path.endswith(".json")
+
+    # reload vocab
+    with open(path, "r") as f:
+        vocab = json.load(f)
+    assert isinstance(vocab, list)
+    print(f"Vocab size: {len(vocab)}")
+    if max_vocab is not None:
+        vocab = vocab[:max_vocab]
+        print(f"Cutting vocab to first {len(vocab)} tokens.")
+
+    # build ranks
+    ranks: Dict[bytes, int] = {}
+    for i, x in enumerate(vocab):
+        assert x.keys() == {"rank", "token_bytes", "token_str"}
+        assert x["rank"] == i
+        merge = base64.b64decode(x["token_bytes"])
+        assert i >= 256 or merge == bytes([i])
+        ranks[merge] = x["rank"]
+
+    # sanity check
+    assert len(ranks) == len(vocab)
+    assert set(ranks.values()) == set(range(len(ranks)))
+
+    return ranks
+
+
+PATTERN_TIKTOKEN = "[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]*[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]+|[^\\r\\n\\p{L}\\p{N}]?[\\p{Lu}\\p{Lt}\\p{Lm}\\p{Lo}\\p{M}]+[\\p{Ll}\\p{Lm}\\p{Lo}\\p{M}]*|\\p{N}| ?[^\\s\\p{L}\\p{N}]+[\\r\\n/]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+DEFAULT_TIKTOKEN_MAX_VOCAB = 2**17  # 131072
+SPECIAL_TOKENS = ["<unk>", "<s>", "</s>"]
+SPECIAL_TOKEN_TEMPLATE = "<SPECIAL_{id}>"
+
+
+class TiktokenTokenizer(TokenizerSpec):
+    """
+    TiktokenTokenizer https://github.com/openai/tiktoken.
+
+    Args:
+        vocab_file: path to the tokenizer vocabulary (JSON file)
+        pattern: regex pattern used to split the text
+        vocab_size: total vocabulary size, including special tokens
+        num_special_tokens: number of special tokens to generate
+        special_tokens: list of user-defined special tokens
+    """
+
+    def __init__(
+        self,
+        vocab_file: str,
+        pattern: str = PATTERN_TIKTOKEN,
+        vocab_size: int = DEFAULT_TIKTOKEN_MAX_VOCAB,  # 131072
+        num_special_tokens: int = 1000,
+        special_tokens: Optional[List[str]] = None,
+    ):
+        if not vocab_file or not os.path.exists(vocab_file):
+            raise ValueError(f"vocab_file: {vocab_file} is invalid")
+
+        if special_tokens is None:
+            special_tokens = SPECIAL_TOKENS.copy()
+
+        assert len(special_tokens) == len(set(special_tokens)), f"Special tokens should be unique: {special_tokens}"
+        assert len(special_tokens) <= num_special_tokens < vocab_size
+        assert set(SPECIAL_TOKENS) <= set(special_tokens), f"Custom special tokens should include {SPECIAL_TOKENS}"
+
+        self._unk_id = special_tokens.index("<unk>")
+        self._bos_id = special_tokens.index("<s>")
+        self._eos_id = special_tokens.index("</s>")
+
+        self._vocab_size = vocab_size
+        print(f'{self._vocab_size = }')
+        self.num_special_tokens = num_special_tokens
+        special_filler = [SPECIAL_TOKEN_TEMPLATE.format(id=i) for i in range(len(special_tokens), num_special_tokens)]
+        if special_filler:
+            print(f"Adding special tokens {special_filler[0]}, ..., {special_filler[-1]}")
+        self.special_tokens = special_tokens + special_filler
+        assert len(set(self.special_tokens)) == len(self.special_tokens) == num_special_tokens, self.special_tokens
+        self.inner_vocab_size = vocab_size - num_special_tokens
+
+        # reload vocab
+        self.token2id = reload_mergeable_ranks(vocab_file, max_vocab=self.inner_vocab_size)
+        self.id2token = {v: k for k, v in self.token2id.items()}
+        assert set(range(self.inner_vocab_size)) == set(self.id2token.keys())
+
+        self.shifted_id2token = {i: tok for i, tok in enumerate(self.special_tokens)}
+        for key, value in self.id2token.items():
+            self.shifted_id2token[key + self.num_special_tokens] = value
+
+        self.tokenizer = tiktoken.Encoding(
+            name=Path(vocab_file).parent.name,
+            pat_str=pattern,
+            mergeable_ranks=self.token2id,
+            special_tokens={},  # special tokens are handled manually
+        )
+
+    def text_to_tokens(self, text: str):
+        token_ids = self.tokenizer.encode(text)
+        return [self.tokenizer.decode_single_token_bytes(token) for token in token_ids]
+
+    def tokens_to_text(self, tokens: List[bytes]):
+        token_ids = [self.tokenizer.encode_single_token(token) for token in tokens]
+        return self.tokenizer.decode(token_ids)
+
+    def token_to_id(self, token):
+        return self.tokenizer.encode_single_token(token)
+
+    def tokens_to_ids(self, tokens):
+        return [self.tokenizer.encode_single_token(token) for token in tokens]
+
+    def ids_to_tokens(self, token_ids):
+        tokens = []
+        for token_id in token_ids:
+            if token_id < self.num_special_tokens:
+                tokens.append(self.special_tokens[token_id])
+            else:
+                token_id -= self.num_special_tokens
+                token_bytes = self.tokenizer.decode_single_token_bytes(token_id)
+                tokens.append(token_bytes.decode('utf-8', errors='replace'))
+        return tokens
+
+    def text_to_ids(self, text: str):
+        tokens = self.tokenizer.encode(text)
+        tokens = [t + self.num_special_tokens for t in tokens]
+        return tokens
+
+    def ids_to_text(self, tokens: List[int]):
+        # Filter out special tokens and shift the remaining tokens back to the inner vocab
+        adjusted_tokens = [
+            t - self.num_special_tokens
+            for t in tokens
+            if t not in {self.bos_id, self.eos_id} and t >= self.num_special_tokens
+        ]
+
+        # Decode only if there are tokens left after filtering
+        if adjusted_tokens:
+            return self.tokenizer.decode(adjusted_tokens)
+        else:
+            return ""  # Return an empty string if all tokens were filtered out
+
+    @property
+    def bos_id(self):
+        return self._bos_id
+
+    @property
+    def eos_id(self):
+        return self._eos_id
+
+    @property
+    def unk_id(self):
+        return self._unk_id
+
+    @property
+    def vocab(self):
+        return self.token2id
+
+    @property
+    def decoder(self):
+        return self.shifted_id2token
+
+    @property
+    def encoder(self):
+        return self.vocab
+
+    @property
+    def vocab_size(self) -> int:
+        return self._vocab_size
diff --git a/nemo/collections/nlp/modules/common/tokenizer_utils.py b/nemo/collections/nlp/modules/common/tokenizer_utils.py
index d3ee69f75b25..4cbadd87fe52 100644
--- a/nemo/collections/nlp/modules/common/tokenizer_utils.py
+++ b/nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -22,6 +22,7 @@
 from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
 from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer
 from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
+from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
 from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
 from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list
 from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list
@@ -122,6 +123,8 @@ def get_tokenizer(
             legacy=True,
             chat_template=chat_template,
         )
+    elif tokenizer_name == 'tiktoken':
+        return TiktokenTokenizer(vocab_file=vocab_file)
     elif tokenizer_name == 'word':
         return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict)
     elif tokenizer_name == 'char':
@@ -221,6 +224,8 @@ def get_nmt_tokenizer(
         )
     elif library == 'tabular':
         return TabularTokenizer(vocab_file, delimiter=delimiter)
+    elif library == 'tiktoken':
+        return TiktokenTokenizer(vocab_file=vocab_file)
     else:
         raise NotImplementedError(
             'Currently we only support "huggingface", "sentencepiece", "megatron", and "byte-level" tokenizer'
diff --git a/nemo/export/multimodal/run.py b/nemo/export/multimodal/run.py
index 1809a6fc8ce7..86bcc716af79 100644
--- a/nemo/export/multimodal/run.py
+++ b/nemo/export/multimodal/run.py
@@ -80,7 +80,6 @@ def init_tokenizer(self, llm_engine_dir):
         self.tokenizer = AutoTokenizer.from_pretrained(os.path.join(llm_engine_dir, 'huggingface_tokenizer'))
         self.tokenizer.pad_token = self.tokenizer.eos_token
-
         if self.model_type == 'vita':
             self.tokenizer.im_start_id = self.tokenizer.convert_tokens_to_ids("")
             self.tokenizer.im_end_id = self.tokenizer.convert_tokens_to_ids("")
diff --git a/requirements/requirements_nlp.txt b/requirements/requirements_nlp.txt
index a1dad5b64a8a..f98f7c318c56 100644
--- a/requirements/requirements_nlp.txt
+++ b/requirements/requirements_nlp.txt
@@ -20,4 +20,5 @@ rouge_score
 sacrebleu # manually install sacrebleu[ja] for Japanese support; MeCab is unsupported in Python 3.11+
 sentence_transformers
 tensorstore<0.1.46
+tiktoken==0.7.0
 zarr
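
Usage note (illustrative only, not part of the patch): a minimal sketch of how the new tokenizer could be exercised once this change is applied. The vocabulary path below is a placeholder; the file is expected to be a Tiktoken-style JSON list of {"rank", "token_bytes", "token_str"} entries, as consumed by reload_mergeable_ranks above.

# Illustrative sketch; "/path/to/tiktoken_vocab.json" is a placeholder path.
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer

tokenizer = TiktokenTokenizer(vocab_file="/path/to/tiktoken_vocab.json")

ids = tokenizer.text_to_ids("Hello world!")  # ids are shifted up by num_special_tokens (default 1000)
print(tokenizer.ids_to_text(ids))            # special-token ids are filtered out before decoding
print(tokenizer.vocab_size)                  # defaults to DEFAULT_TIKTOKEN_MAX_VOCAB == 2**17

# The same tokenizer can also be constructed via the helper extended in this patch:
# get_nmt_tokenizer(library='tiktoken', vocab_file='/path/to/tiktoken_vocab.json')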