Lazy import tokenizers (NVIDIA#10213)

* Move inflect to lazy import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Use lazy imports for tokenizer libraries

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* sacremoses lazy import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* fix cyclic import

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* import fix

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

* move pangu

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>

* Apply isort and black reformatting

Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>

---------

Signed-off-by: Alexandros Koumparoulis <akoumparouli@nvidia.com>
Signed-off-by: akoumpa <akoumpa@users.noreply.github.com>
Co-authored-by: akoumpa <akoumpa@users.noreply.github.com>
akoumpa authored Aug 28, 2024
1 parent f53600a commit e68f981
Showing 8 changed files with 91 additions and 56 deletions.
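The change follows one pattern across these files: heavy third-party imports (inflect, sacremoses, pangu, and the individual NeMo tokenizer classes) move from module level into the functions and constructors that actually use them, in some cases behind a `functools.cache` factory so the object is built once on first use. A minimal sketch of that cached-factory pattern, mirroring the `inflect_engine` helper introduced in the diff (the example call is illustrative, not part of the commit):

```python
from functools import cache


@cache
def inflect_engine():
    """Build the inflect engine on first use instead of at module import time."""
    import inflect  # deferred: only imported when the engine is first needed

    return inflect.engine()


# The first call pays the import/construction cost; later calls reuse the cached engine.
print(inflect_engine().number_to_words(42))  # -> "forty-two"
```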
@@ -46,8 +46,8 @@

import os
from argparse import ArgumentParser
from functools import cache

import inflect
import regex as re
from tqdm import tqdm

@@ -60,12 +60,21 @@
)
from nemo.utils import logging

engine = inflect.engine()

@cache
def inflect_engine():
import inflect

return inflect.engine()


# these are all words that can appear in a verbalized number, this list will be used later as a filter to detect numbers in verbalizations
number_verbalizations = list(range(0, 20)) + list(range(20, 100, 10))
number_verbalizations = (
[engine.number_to_words(x, zero="zero").replace("-", " ").replace(",", "") for x in number_verbalizations]
[
inflect_engine().number_to_words(x, zero="zero").replace("-", " ").replace(",", "")
for x in number_verbalizations
]
+ ["hundred", "thousand", "million", "billion", "trillion"]
+ ["point"]
)
@@ -85,7 +94,7 @@ def process_url(o):
"""

def flatten(l):
""" flatten a list of lists """
"""flatten a list of lists"""
return [item for sublist in l for item in sublist]

if o != '<self>' and '_letter' in o:
@@ -129,6 +138,7 @@ def convert2digits(digits: str):
Return:
res: number verbalization of the integer prefix of the input
"""
engine = inflect_engine()
res = []
for i, x in enumerate(digits):
if x in digit:
@@ -145,6 +155,7 @@ def convert2digits(digits: str):


def convert(example):
engine = inflect_engine()
cls, written, spoken = example

written = convert_fraction(written)
@@ -288,7 +299,7 @@ def convert(example):
def ignore(example):
"""
This function makes sure specific class types like 'PLAIN', 'ELECTRONIC' etc. are left unchanged.
Args:
example: data example
"""
@@ -300,7 +311,7 @@ def ignore(example):


def process_file(fp):
""" Reading the raw data from a file of NeMo format and preprocesses it. Write is out to the output directory.
"""Reading the raw data from a file of NeMo format and preprocesses it. Write is out to the output directory.
For more info about the data format, refer to the
`text_normalization doc <https://github.com/NVIDIA/NeMo/blob/main/docs/source/nlp/text_normalization.rst>`.
16 changes: 12 additions & 4 deletions nemo/collections/common/parts/preprocessing/cleaners.py
@@ -14,7 +14,6 @@

import re

import inflect
from text_unidecode import unidecode

from nemo.utils import logging
@@ -139,7 +138,14 @@
]


inflect = inflect.engine()
from functools import cache


@cache
def inflect_engine():
import inflect

return inflect.engine()


def clean_text(string, table, punctuation_to_replace, abbreviation_version=None):
@@ -194,11 +200,12 @@ def reset(self):
self.currency = None

def format_final_number(self, whole_num, decimal):
inflect = inflect_engine()
if self.currency:
return_string = inflect.number_to_words(whole_num)
return_string += " dollar" if whole_num == 1 else " dollars"
if decimal:
return_string += " and " + inflect.number_to_words(decimal)
return_string += " and " + inflect_engine().number_to_words(decimal)
return_string += " cent" if whole_num == decimal else " cents"
self.reset()
return return_string
@@ -210,11 +217,12 @@ def format_final_number(self, whole_num, decimal):
else:
# Check if there are non-numbers
def convert_to_word(match):
return " " + inflect.number_to_words(match.group(0)) + " "
return " " + inflect_engine().number_to_words(match.group(0)) + " "

return re.sub(r'[0-9,]+', convert_to_word, whole_num)

def clean(self, match):
inflect = inflect_engine()
ws = match.group(2)
number = match.group(3)
_proceeding_symbol = match.group(7)
7 changes: 4 additions & 3 deletions nemo/collections/common/tokenizers/en_ja_tokenizers.py
@@ -14,9 +14,6 @@
import re
from typing import List

from pangu import spacing
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer

try:
import ipadic
import MeCab
@@ -36,6 +33,8 @@ class EnJaProcessor:
"""

def __init__(self, lang_id: str):
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer

self.lang_id = lang_id
self.moses_tokenizer = MosesTokenizer(lang=lang_id)
self.moses_detokenizer = MosesDetokenizer(lang=lang_id)
@@ -81,6 +80,8 @@ def __init__(self):
self.mecab_tokenizer = MeCab.Tagger(ipadic.MECAB_ARGS + " -Owakati")

def detokenize(self, text: List[str]) -> str:
from pangu import spacing

RE_WS_IN_FW = re.compile(
r'([\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])\s+(?=[\u2018\u2019\u201c\u201d\u2e80-\u312f\u3200-\u32ff\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff\uff00-\uffef])'
)
4 changes: 2 additions & 2 deletions nemo/collections/common/tokenizers/indic_tokenizers.py
@@ -14,8 +14,6 @@

from typing import List

from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer


class IndicProcessor:
"""
@@ -26,6 +24,8 @@ class IndicProcessor:
def __init__(self, lang_id: str):
if lang_id != 'hi':
raise NotImplementedError
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer

self.moses_tokenizer = MosesTokenizer(lang=lang_id)
self.moses_detokenizer = MosesDetokenizer(lang=lang_id)
self.normalizer = MosesPunctNormalizer(lang=lang_id)
4 changes: 2 additions & 2 deletions nemo/collections/common/tokenizers/moses_tokenizers.py
@@ -14,15 +14,15 @@

from typing import List

from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer


class MosesProcessor:
"""
Tokenizer, Detokenizer and Normalizer utilities in Moses
"""

def __init__(self, lang_id: str):
from sacremoses import MosesDetokenizer, MosesPunctNormalizer, MosesTokenizer

self.moses_tokenizer = MosesTokenizer(lang=lang_id)
self.moses_detokenizer = MosesDetokenizer(lang=lang_id)
self.normalizer = MosesPunctNormalizer(lang=lang_id)
73 changes: 40 additions & 33 deletions nemo/collections/nlp/modules/common/tokenizer_utils.py
@@ -16,28 +16,8 @@
from dataclasses import MISSING, dataclass
from typing import Dict, List, Optional

import nemo
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer
from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer
from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer
from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer
from nemo.collections.nlp.modules.common.huggingface.huggingface_utils import get_huggingface_pretrained_lm_models_list
from nemo.collections.nlp.modules.common.lm_utils import get_pretrained_lm_models_list
from nemo.collections.nlp.parts.nlp_overrides import HAVE_MEGATRON_CORE
from nemo.utils import logging

try:
from nemo.collections.nlp.modules.common.megatron.megatron_utils import get_megatron_tokenizer

HAVE_MEGATRON_CORE = True

except (ImportError, ModuleNotFoundError):
HAVE_MEGATRON_CORE = False


__all__ = ['get_tokenizer', 'get_tokenizer_list']


@@ -96,46 +76,61 @@ def get_tokenizer(
model better learn word compositionality and become robust to segmentation errors.
It has emperically been shown to improve inference time BLEU scores.
"""

if special_tokens is None:
special_tokens_dict = {}
else:
special_tokens_dict = special_tokens

if 'megatron' in tokenizer_name:
if not HAVE_MEGATRON_CORE:
try:
from nemo.collections.nlp.modules.common.megatron.megatron_utils import (
get_megatron_merges_file,
get_megatron_tokenizer,
get_megatron_vocab_file,
)
except (ImportError, ModuleNotFoundError):
raise ImportError(
"Megatron-core was not found. Please see the NeMo README for installation instructions: https://github.com/NVIDIA/NeMo#megatron-gpt."
)
if vocab_file is None:
vocab_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_vocab_file(
tokenizer_name
)
merges_file = nemo.collections.nlp.modules.common.megatron.megatron_utils.get_megatron_merges_file(
tokenizer_name
)
vocab_file = get_megatron_vocab_file(tokenizer_name)
merges_file = get_megatron_merges_file(tokenizer_name)
tokenizer_name = get_megatron_tokenizer(tokenizer_name)

if tokenizer_name == 'sentencepiece':
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

logging.info("tokenizer_model: " + str(tokenizer_model))
return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer(
return SentencePieceTokenizer(
model_path=tokenizer_model,
special_tokens=special_tokens,
legacy=True,
chat_template=chat_template,
)
elif tokenizer_name == 'tiktoken':
return nemo.collections.common.tokenizers.tiktoken_tokenizer.TiktokenTokenizer(vocab_file=vocab_file)
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer

return TiktokenTokenizer(vocab_file=vocab_file)
elif tokenizer_name == 'word':
from nemo.collections.common.tokenizers.word_tokenizer import WordTokenizer

return WordTokenizer(vocab_file=vocab_file, **special_tokens_dict)
elif tokenizer_name == 'char':
from nemo.collections.common.tokenizers.char_tokenizer import CharTokenizer

return CharTokenizer(vocab_file=vocab_file, **special_tokens_dict)
elif tokenizer_name == 'regex':
from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer

return RegExTokenizer().load_tokenizer(regex_file=tokenizer_model, vocab_file=vocab_file)

logging.info(
f"Getting HuggingFace AutoTokenizer with pretrained_model_name: {tokenizer_name}, vocab_file: {vocab_file}, merges_files: {merges_file}, "
f"special_tokens_dict: {special_tokens_dict}, and use_fast: {use_fast}"
)
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

return AutoTokenizer(
pretrained_model_name=tokenizer_name,
vocab_file=vocab_file,
@@ -183,6 +178,8 @@ def get_nmt_tokenizer(
raise ValueError("No Tokenizer path provided or file does not exist!")

if library == 'huggingface':
from nemo.collections.common.tokenizers.huggingface.auto_tokenizer import AutoTokenizer

logging.info(f'Getting HuggingFace AutoTokenizer with pretrained_model_name: {model_name}')
return AutoTokenizer(
pretrained_model_name=model_name,
@@ -193,26 +190,32 @@ def get_nmt_tokenizer(
trust_remote_code=trust_remote_code,
)
elif library == 'sentencepiece':
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

logging.info(f'Getting SentencePiece with model: {tokenizer_model}')
return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer(
return SentencePieceTokenizer(
model_path=tokenizer_model,
legacy=legacy,
chat_template=chat_template,
)
elif library == 'byte-level':
from nemo.collections.common.tokenizers.bytelevel_tokenizers import ByteLevelTokenizer

logging.info(f'Using byte-level tokenization')
return ByteLevelTokenizer(special_tokens_dict)
elif library == 'regex':
from nemo.collections.common.tokenizers.regex_tokenizer import RegExTokenizer

logging.info(f'Using regex tokenization')
return RegExTokenizer().load_tokenizer(regex_file=tokenizer_model, vocab_file=vocab_file)
elif library == 'megatron':

if model_name == 'GPTSentencePieceTokenizer':
from nemo.collections.common.tokenizers.sentencepiece_tokenizer import SentencePieceTokenizer

logging.info("tokenizer_model: ")
logging.info(tokenizer_model)
return nemo.collections.common.tokenizers.sentencepiece_tokenizer.SentencePieceTokenizer(
model_path=tokenizer_model, legacy=legacy
)
return SentencePieceTokenizer(model_path=tokenizer_model, legacy=legacy)

if model_name in megatron_tokenizer_model_map:
model_name = megatron_tokenizer_model_map[model_name]
@@ -223,8 +226,12 @@ def get_nmt_tokenizer(
tokenizer_name=model_name, vocab_file=vocab_file, merges_file=merges_file, chat_template=chat_template
)
elif library == 'tabular':
from nemo.collections.common.tokenizers.tabular_tokenizer import TabularTokenizer

return TabularTokenizer(vocab_file, delimiter=delimiter)
elif library == 'tiktoken':
from nemo.collections.common.tokenizers.tiktoken_tokenizer import TiktokenTokenizer

return TiktokenTokenizer(vocab_file=vocab_file)
else:
raise NotImplementedError(
13 changes: 9 additions & 4 deletions scripts/nlp_language_modeling/niv2/preprocess_niv2.py
@@ -18,8 +18,6 @@
from argparse import ArgumentParser
from multiprocessing import Pool

from sacremoses import MosesDetokenizer

from nemo.collections.common.tokenizers import AutoTokenizer


@@ -99,6 +97,8 @@ def write_dataset_to_file(file_name, output_file_name, detokenizer, tokenizer, i


def process_folder(data_folder, output_folder, splits_file, remove_newline):
from sacremoses import MosesDetokenizer

detokenizer = MosesDetokenizer('en')
tokenizer = AutoTokenizer("gpt2")
assert os.path.isdir(data_folder)
@@ -162,10 +162,15 @@ def process_folder(data_folder, output_folder, splits_file, remove_newline):
help="Path to output folder where JSONL files will be written.",
)
parser.add_argument(
"--splits_file_path", type=str, default="default", help="Path to the file that contains splits. ex: ",
"--splits_file_path",
type=str,
default="default",
help="Path to the file that contains splits. ex: ",
)
parser.add_argument(
"--remove_newline", action="store_true", help="Whether to remove newlines from the input and output.",
"--remove_newline",
action="store_true",
help="Whether to remove newlines from the input and output.",
)
args = parser.parse_args()
process_folder(args.niv2_dataset_path, args.jsonl_output_path, args.splits_file_path, args.remove_newline)
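
A quick way to observe the effect of these deferred imports is to check `sys.modules` before and after the first use of a tokenizer. This is an illustrative sketch, not part of the commit; it assumes NeMo and sacremoses are installed, and that nothing else in the import chain pulls in sacremoses eagerly:

```python
import sys

from nemo.collections.common.tokenizers.moses_tokenizers import MosesProcessor

# After this change, importing the module should not load sacremoses yet.
print("sacremoses" in sys.modules)  # expected: False

processor = MosesProcessor(lang_id="en")  # constructor performs the deferred import
print("sacremoses" in sys.modules)  # expected: True
print(processor.moses_tokenizer.tokenize("Lazy imports keep startup light."))
```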