Skip to content

Commit

Permalink
Fix tn duplex (#7808)
Browse files: browse the repository at this point in the history
* fix duplex TN inference

Signed-off-by: Evelina <ebakhturina@nvidia.com>

* fix typo

Signed-off-by: Evelina <ebakhturina@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix TN docs

Signed-off-by: Evelina <ebakhturina@nvidia.com>

---------

Signed-off-by: Evelina <ebakhturina@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ekmb and pre-commit-ci[bot] authored Nov 3, 2023
1 parent 05ecfe4 commit b112ce9
Show file tree
Hide file tree
Showing 7 changed files with 24 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ Language Support Matrix
+------------------+----------+----------+----------+--------------------+----------------------+
| Arabic | ar | x | x | | |
+------------------+----------+----------+----------+--------------------+----------------------+
| Russian | ru | x | x | x | |
| Russian | ru | | x | x | |
+------------------+----------+----------+----------+--------------------+----------------------+
| Swedish | sv | x | x | | |
+------------------+----------+----------+----------+--------------------+----------------------+
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,6 @@
from typing import List

from helpers import DECODER_MODEL, TAGGER_MODEL, instantiate_model_and_trainer
from nn_wfst.en.electronic.normalize import ElectronicNormalizer
from nn_wfst.en.whitelist.normalize import WhitelistNormalizer
from omegaconf import DictConfig, OmegaConf

from nemo.collections.nlp.data.text_normalization import constants
Expand All @@ -61,11 +59,16 @@

try:
from nemo_text_processing.text_normalization.data_loader_utils import post_process_punct
from nn_wfst.en.electronic.normalize import ElectronicNormalizer
from nn_wfst.en.whitelist.normalize import WhitelistNormalizer

NEMO_TEXT_PROCESSING_AVAILABLE = True
except (ImportError, ModuleNotFoundError):
raise ModuleNotFoundError(
"The package `nemo_text_processing` was not installed in this environment. Please refer to"
NEMO_TEXT_PROCESSING_AVAILABLE = False
logging.warning(
" `nemo_text_processing` is not installed in this environment. Please refer to"
" https://github.com/NVIDIA/NeMo-text-processing and install this package before using "
"this script"
" this script: `pip install nemo_text_processing`"
)


Expand All @@ -82,7 +85,7 @@ def main(cfg: DictConfig) -> None:
tagger_model.max_sequence_len = 512
tn_model = DuplexTextNormalizationModel(tagger_model, decoder_model, lang)

if lang == constants.ENGLISH:
if lang == constants.ENGLISH and NEMO_TEXT_PROCESSING_AVAILABLE:
normalizer_electronic = ElectronicNormalizer(input_case="cased", lang=lang, deterministic=True)
normalizer_whitelist = WhitelistNormalizer(input_case="cased", lang=lang, deterministic=True)

Expand Down Expand Up @@ -139,7 +142,7 @@ def _get_predictions(lines: List[str], mode: str, batch_size: int, text_file: st
if test_input == "STOP":
done = True
if not done:
if lang == constants.ENGLISH:
if lang == constants.ENGLISH and NEMO_TEXT_PROCESSING_AVAILABLE:
new_input = normalizer_electronic.normalize(test_input, verbose=False)
test_input = post_process_punct(input=test_input, normalized_text=new_input)
new_input = normalizer_whitelist.normalize(test_input, verbose=False)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def __init__(
deterministic: bool = True,
cache_dir: str = None,
overwrite_cache: bool = False,
max_number_of_permutations_per_split: int = 729,
):

from nn_wfst.en.electronic.tokenize_and_classify import ClassifyFst
Expand All @@ -58,3 +59,4 @@ def __init__(
self.parser = TokenParser()
self.lang = lang
self.processor = MosesProcessor(lang_id=lang)
self.max_number_of_permutations_per_split = max_number_of_permutations_per_split
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def __init__(
punctuation = PunctuationFst(deterministic=deterministic)
punct_graph = punctuation.fst
word_graph = WordFst(deterministic=deterministic, punctuation=punctuation).fst
electonic_graph = ElectronicFst(deterministic=deterministic).fst
electonic_graph = ElectronicFst(cardinal=None, deterministic=deterministic).fst

classify = pynutil.add_weight(electonic_graph, 1.1) | pynutil.add_weight(word_graph, 100)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ def __init__(
cache_dir: str = None,
overwrite_cache: bool = False,
whitelist: str = None,
max_number_of_permutations_per_split: int = 729,
):

from nn_wfst.en.whitelist.tokenize_and_classify import ClassifyFst
Expand All @@ -64,3 +65,4 @@ def __init__(
self.parser = TokenParser()
self.lang = lang
self.processor = MosesProcessor(lang_id=lang)
self.max_number_of_permutations_per_split = max_number_of_permutations_per_split
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,9 @@ def _infer(self, sents: List[str], inst_directions: List[str], processed=False):
cur_output_str = post_process_punct(input=original_sents[ix], normalized_text=cur_output_str)
else:
logging.warning(
"`pynini` not installed, please install via nemo_text_processing/pynini_install.sh"
" `nemo_text_processing` is not installed in this environment. Please refer to"
" https://github.com/NVIDIA/NeMo-text-processing and install this package before using "
" this script: `pip install nemo_text_processing`"
)
final_outputs.append(cur_output_str)
except IndexError:
Expand Down
5 changes: 5 additions & 0 deletions nemo/collections/nlp/models/nlp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -439,6 +439,11 @@ def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True):
and "bert_model.embeddings.position_ids" in state_dict
):
del state_dict["bert_model.embeddings.position_ids"]
else:
# fix for albert and other models
pos_id_keys = [x for x in state_dict.keys() if "position_ids" in x]
for key in pos_id_keys:
del state_dict[key]
results = super(NLPModel, self).load_state_dict(state_dict, strict=strict)
return results

Expand Down

0 comments on commit b112ce9

Please sign in to comment.