From b112ce9e4362725eaa25305e3d25fa84027d32ef Mon Sep 17 00:00:00 2001 From: Evelina <10428420+ekmb@users.noreply.github.com> Date: Thu, 2 Nov 2023 19:16:00 -0700 Subject: [PATCH] Fix tn duplex (#7808) * fix duplex tn infer Signed-off-by: Evelina * fix typo Signed-off-by: Evelina * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix TN docs Signed-off-by: Evelina --------- Signed-off-by: Evelina Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../wfst/wfst_text_normalization.rst | 2 +- .../duplex_text_normalization_infer.py | 17 ++++++++++------- .../nn_wfst/en/electronic/normalize.py | 2 ++ .../en/electronic/tokenize_and_classify.py | 2 +- .../nn_wfst/en/whitelist/normalize.py | 2 ++ .../duplex_text_normalization/duplex_tn.py | 4 +++- nemo/collections/nlp/models/nlp_model.py | 5 +++++ 7 files changed, 24 insertions(+), 10 deletions(-) diff --git a/docs/source/nlp/text_normalization/wfst/wfst_text_normalization.rst b/docs/source/nlp/text_normalization/wfst/wfst_text_normalization.rst index 874c90567823..b664eed23a1a 100644 --- a/docs/source/nlp/text_normalization/wfst/wfst_text_normalization.rst +++ b/docs/source/nlp/text_normalization/wfst/wfst_text_normalization.rst @@ -166,7 +166,7 @@ Language Support Matrix +------------------+----------+----------+----------+--------------------+----------------------+ | Arabic | ar | x | x | | | +------------------+----------+----------+----------+--------------------+----------------------+ -| Russian | ru | x | x | x | | +| Russian | ru | | x | x | | +------------------+----------+----------+----------+--------------------+----------------------+ | Swedish | sv | x | x | | | +------------------+----------+----------+----------+--------------------+----------------------+ diff --git a/examples/nlp/duplex_text_normalization/duplex_text_normalization_infer.py b/examples/nlp/duplex_text_normalization/duplex_text_normalization_infer.py index 6bcc69de7db9..3bb782ee7293 100644 --- a/examples/nlp/duplex_text_normalization/duplex_text_normalization_infer.py +++ b/examples/nlp/duplex_text_normalization/duplex_text_normalization_infer.py @@ -50,8 +50,6 @@ from typing import List from helpers import DECODER_MODEL, TAGGER_MODEL, instantiate_model_and_trainer -from nn_wfst.en.electronic.normalize import ElectronicNormalizer -from nn_wfst.en.whitelist.normalize import WhitelistNormalizer from omegaconf import DictConfig, OmegaConf from nemo.collections.nlp.data.text_normalization import constants @@ -61,11 +59,16 @@ try: from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct + from nn_wfst.en.electronic.normalize import ElectronicNormalizer + from nn_wfst.en.whitelist.normalize import WhitelistNormalizer + + NEMO_TEXT_PROCESSING_AVAILABLE = True except (ImportError, ModuleNotFoundError): - raise ModuleNotFoundError( - "The package `nemo_text_processing` was not installed in this environment. Please refer to" + NEMO_TEXT_PROCESSING_AVAILABLE = False + logging.warning( + " `nemo_text_processing` is not installed in this environment. Please refer to" " https://github.com/NVIDIA/NeMo-text-processing and install this package before using " - "this script" + " this script: `pip install nemo_text_processing`" ) @@ -82,7 +85,7 @@ def main(cfg: DictConfig) -> None: tagger_model.max_sequence_len = 512 tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang) - if lang == constants.ENGLISH: + if lang == constants.ENGLISH and NEMO_TEXT_PROCESSING_AVAILABLE: normalizer_electronic = ElectronicNormalizer(input_case="cased", lang=lang, deterministic=True) normalizer_whitelist = WhitelistNormalizer(input_case="cased", lang=lang, deterministic=True) @@ -139,7 +142,7 @@ def _get_predictions(lines: List[str], mode: str, batch_size: int, text_file: st if test_input == "STOP": done = True if not done: - if lang == constants.ENGLISH: + if lang == constants.ENGLISH and NEMO_TEXT_PROCESSING_AVAILABLE: new_input = normalizer_electronic.normalize(test_input, verbose=False) test_input = post_process_punct(input=test_input, normalized_text=new_input) new_input = normalizer_whitelist.normalize(test_input, verbose=False) diff --git a/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/normalize.py b/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/normalize.py index a1f8caa7d959..94eee524a328 100644 --- a/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/normalize.py +++ b/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/normalize.py @@ -45,6 +45,7 @@ def __init__( deterministic: bool = True, cache_dir: str = None, overwrite_cache: bool = False, + max_number_of_permutations_per_split: int = 729, ): from nn_wfst.en.electronic.tokenize_and_classify import ClassifyFst @@ -58,3 +59,4 @@ def __init__( self.parser = TokenParser() self.lang = lang self.processor = MosesProcessor(lang_id=lang) + self.max_number_of_permutations_per_split = max_number_of_permutations_per_split diff --git a/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/tokenize_and_classify.py b/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/tokenize_and_classify.py index 9e0c284d84b0..33694d19c72c 100644 --- a/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/tokenize_and_classify.py +++ b/examples/nlp/duplex_text_normalization/nn_wfst/en/electronic/tokenize_and_classify.py @@ -70,7 +70,7 @@ def __init__( punctuation = PunctuationFst(deterministic=deterministic) punct_graph = punctuation.fst word_graph = WordFst(deterministic=deterministic, punctuation=punctuation).fst - electonic_graph = ElectronicFst(deterministic=deterministic).fst + electonic_graph = ElectronicFst(cardinal=None, deterministic=deterministic).fst classify = pynutil.add_weight(electonic_graph, 1.1) | pynutil.add_weight(word_graph, 100) diff --git a/examples/nlp/duplex_text_normalization/nn_wfst/en/whitelist/normalize.py b/examples/nlp/duplex_text_normalization/nn_wfst/en/whitelist/normalize.py index cfb4bef5d1c3..6b9c0adba69b 100644 --- a/examples/nlp/duplex_text_normalization/nn_wfst/en/whitelist/normalize.py +++ b/examples/nlp/duplex_text_normalization/nn_wfst/en/whitelist/normalize.py @@ -47,6 +47,7 @@ def __init__( cache_dir: str = None, overwrite_cache: bool = False, whitelist: str = None, + max_number_of_permutations_per_split: int = 729, ): from nn_wfst.en.whitelist.tokenize_and_classify import ClassifyFst @@ -64,3 +65,4 @@ def __init__( self.parser = TokenParser() self.lang = lang self.processor = MosesProcessor(lang_id=lang) + self.max_number_of_permutations_per_split = max_number_of_permutations_per_split diff --git a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py index b83ad8eca2e7..31e42bde9258 100644 --- a/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py +++ b/nemo/collections/nlp/models/duplex_text_normalization/duplex_tn.py @@ -287,7 +287,9 @@ def _infer(self, sents: List[str], inst_directions: List[str], processed=False): cur_output_str = post_process_punct(input=original_sents[ix], normalized_text=cur_output_str) else: logging.warning( - "`pynini` not installed, please install via nemo_text_processing/pynini_install.sh" + " `nemo_text_processing` is not installed in this environment. Please refer to" + " https://github.com/NVIDIA/NeMo-text-processing and install this package before using " + " this script: `pip install nemo_text_processing`" ) final_outputs.append(cur_output_str) except IndexError: diff --git a/nemo/collections/nlp/models/nlp_model.py b/nemo/collections/nlp/models/nlp_model.py index ac3a8c998ba7..04bbb2ca17fe 100644 --- a/nemo/collections/nlp/models/nlp_model.py +++ b/nemo/collections/nlp/models/nlp_model.py @@ -439,6 +439,11 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): and "bert_model.embeddings.position_ids" in state_dict ): del state_dict["bert_model.embeddings.position_ids"] + else: + # fix for albert and other models + pos_id_keys = [x for x in state_dict.keys() if "position_ids" in x] + for key in pos_id_keys: + del state_dict[key] results = super(NLPModel, self).load_state_dict(state_dict, strict=strict) return results