From d8840f1a50558ae113bb06c27243f4a33c41e25d Mon Sep 17 00:00:00 2001
From: Nikolay Karpov
Date: Fri, 30 Jun 2023 06:23:13 -0700
Subject: [PATCH 1/3] rnnt_ngram_merge

Signed-off-by: Nikolay Karpov
---
 .../ngram_lm/ngram_merge.py | 79 ++++++++++---------
 1 file changed, 40 insertions(+), 39 deletions(-)

diff --git a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
index abffc6372518..75e2221c5540 100644
--- a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
+++ b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
@@ -52,6 +52,7 @@

 import nemo.collections.asr as nemo_asr
 from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
+from nemo.collections.asr.modules.rnnt import RNNTDecoder
 from nemo.utils import logging

@@ -208,7 +209,7 @@ def make_arpa(self, ngram_mod: str, ngram_arpa: str, force: bool):
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

     def test_perplexity(
-        self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str, force: bool
+        self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str
     ) -> str:
         """
         Tests the perplexity of a given ngram model on a test file.
@@ -229,12 +230,12 @@
             'Perplexity: 123.45'
         """

-        test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file, force)
+        test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file)
         res_p = self.perplexity(mod_c, test_far)
         return res_p


-def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str, force: bool,) -> str:
+def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str) -> str:
     """
     Compiles a text file into a FAR file using the given symbol table or tokenizer.
@@ -253,43 +254,39 @@ def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str
     """
     test_far = os.path.join(tmp_path, os.path.split(text_file)[1] + ".far")
-    if os.path.isfile(test_far) and not force:
-        logging.info("File " + test_far + " exists. Skipping.")
-        return None
-    else:
-        sh_args = [
-            "farcompilestrings",
-            "--generate_keys=10",
-            "--fst_type=compact",
-            "--symbols=" + symbols,
-            "--keep_symbols",
-            ">",
-            test_far,
-        ]
+    sh_args = [
+        "farcompilestrings",
+        "--generate_keys=10",
+        "--fst_type=compact",
+        "--symbols=" + symbols,
+        "--keep_symbols",
+        ">",
+        test_far,
+    ]

-        tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)
+    tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)

-        ps = subprocess.Popen(
-            " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,
-        )
+    ps = subprocess.Popen(
+        " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,
+    )

-        kenlm_utils.iter_files(
-            source_path=[text_file],
-            dest_path=ps.stdin,
-            tokenizer=tokenizer,
-            encoding_level=encoding_level,
-            is_aggregate_tokenizer=is_aggregate_tokenizer,
-            verbose=1,
-        )
-        stdout, stderr = ps.communicate()
+    kenlm_utils.iter_files(
+        source_path=[text_file],
+        dest_path=ps.stdin,
+        tokenizer=tokenizer,
+        encoding_level=encoding_level,
+        is_aggregate_tokenizer=is_aggregate_tokenizer,
+        verbose=1,
+    )
+    stdout, stderr = ps.communicate()

-        exit_code = ps.returncode
+    exit_code = ps.returncode

-        command = " ".join(sh_args)
-        assert (
-            exit_code == 0
-        ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
-        return test_far
+    command = " ".join(sh_args)
+    assert (
+        exit_code == 0
+    ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
+    return test_far


 def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
@@ -310,7 +307,7 @@ def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
         logging.info("File " + ngram_kenlm + " exists. Skipping.")
         return None
     else:
-        sh_args = [kenlm_bin_path, "trie", "-i", ngram_arpa, ngram_kenlm]
+        sh_args = [os.path.join(kenlm_bin_path,"build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm]
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

@@ -336,12 +333,15 @@ def make_symbol_list(nemo_model_file, symbols, force):
     else:
         if nemo_model_file.endswith('.nemo'):
             asr_model = nemo_asr.models.ASRModel.restore_from(nemo_model_file, map_location=torch.device('cpu'))
-            vocab_size = len(asr_model.decoder.vocabulary)
         else:
             logging.warning(
                 "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name."
             )
             asr_model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu'))
+
+        if isinstance(asr_model.decoder, RNNTDecoder):
+            vocab_size = asr_model.decoder.blank_idx
+        else:
             vocab_size = len(asr_model.decoder.vocabulary)

         vocab = [chr(idx + DEFAULT_TOKEN_OFFSET) for idx in range(vocab_size)]
@@ -389,8 +389,9 @@ def main(
     if not symbols:
         symbols = os.path.join(out_path, os.path.split(nemo_model_file)[1] + ".syms")
         make_symbol_list(nemo_model_file, symbols, force)
-    test_p = nm.test_perplexity(mod_c, symbols, test_file, nemo_model_file, out_path, force)
-    logging.info("Perplexity summary:" + test_p)
+    for test_f in test_file.split(","):
+        test_p = nm.test_perplexity(mod_c, symbols, test_f, nemo_model_file, out_path)
+        logging.info("Perplexity summary " + test_f + " : " + test_p)

     logging.info("Making ARPA and Kenlm model " + arpa_c)
     out = nm.make_arpa(mod_c, arpa_c, force)
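The key functional change in patch 1 is how make_symbol_list sizes the vocabulary. In NeMo's RNNTDecoder the blank symbol conventionally sits at the last index, so decoder.blank_idx equals the number of real output tokens and can stand in for the vocabulary size when writing the OpenFST symbol table. A minimal sketch of that selection logic follows; the Fake* classes and vocab_size_of helper are invented for illustration (the real script type-checks a loaded model against RNNTDecoder), and DEFAULT_TOKEN_OFFSET = 100 is assumed to match the imported constant:

    # Stand-in decoder classes invented for this sketch; the real code checks
    # isinstance(asr_model.decoder, RNNTDecoder) on a loaded NeMo checkpoint.
    DEFAULT_TOKEN_OFFSET = 100  # assumed value of the constant the script imports

    class FakeRNNTDecoder:
        vocabulary = ["a", "b", "c"]
        blank_idx = 3  # blank comes last, so this equals the real-token count

    class FakeCTCDecoder:
        vocabulary = ["a", "b", "c"]

    def vocab_size_of(decoder):
        # Mirrors the patched branch: RNNT decoders report their size via
        # blank_idx, everything else via len(vocabulary).
        if hasattr(decoder, "blank_idx"):  # stands in for the isinstance check
            return decoder.blank_idx
        return len(decoder.vocabulary)

    for dec in (FakeRNNTDecoder(), FakeCTCDecoder()):
        size = vocab_size_of(dec)
        # Symbols are written as printable chars offset by DEFAULT_TOKEN_OFFSET.
        vocab = [chr(idx + DEFAULT_TOKEN_OFFSET) for idx in range(size)]
        print(type(dec).__name__, "->", size, vocab)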
Skipping.") - return None - else: - sh_args = [ - "farcompilestrings", - "--generate_keys=10", - "--fst_type=compact", - "--symbols=" + symbols, - "--keep_symbols", - ">", - test_far, - ] + sh_args = [ + "farcompilestrings", + "--generate_keys=10", + "--fst_type=compact", + "--symbols=" + symbols, + "--keep_symbols", + ">", + test_far, + ] - tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file) + tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file) - ps = subprocess.Popen( - " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr, - ) + ps = subprocess.Popen( + " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr, + ) - kenlm_utils.iter_files( - source_path=[text_file], - dest_path=ps.stdin, - tokenizer=tokenizer, - encoding_level=encoding_level, - is_aggregate_tokenizer=is_aggregate_tokenizer, - verbose=1, - ) - stdout, stderr = ps.communicate() + kenlm_utils.iter_files( + source_path=[text_file], + dest_path=ps.stdin, + tokenizer=tokenizer, + encoding_level=encoding_level, + is_aggregate_tokenizer=is_aggregate_tokenizer, + verbose=1, + ) + stdout, stderr = ps.communicate() - exit_code = ps.returncode + exit_code = ps.returncode - command = " ".join(sh_args) - assert ( - exit_code == 0 - ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}" - return test_far + command = " ".join(sh_args) + assert ( + exit_code == 0 + ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}" + return test_far def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool): @@ -310,7 +307,7 @@ def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool): logging.info("File " + ngram_kenlm + " exists. Skipping.") return None else: - sh_args = [kenlm_bin_path, "trie", "-i", ngram_arpa, ngram_kenlm] + sh_args = [os.path.join(kenlm_bin_path,"build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm] return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,) @@ -336,12 +333,15 @@ def make_symbol_list(nemo_model_file, symbols, force): else: if nemo_model_file.endswith('.nemo'): asr_model = nemo_asr.models.ASRModel.restore_from(nemo_model_file, map_location=torch.device('cpu')) - vocab_size = len(asr_model.decoder.vocabulary) else: logging.warning( "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name." 
From ed59f550bfdea100a6196b52f97b4c93724807fa Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 4 Jul 2023 11:10:12 +0000
Subject: [PATCH 3/3] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 .../ngram_lm/ngram_merge.py | 18 ++++++------------
 1 file changed, 6 insertions(+), 12 deletions(-)

diff --git a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
index 75e2221c5540..b6606286ae5b 100644
--- a/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
+++ b/scripts/asr_language_modeling/ngram_lm/ngram_merge.py
@@ -51,8 +51,8 @@
 import torch

 import nemo.collections.asr as nemo_asr
-from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
 from nemo.collections.asr.modules.rnnt import RNNTDecoder
+from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
 from nemo.utils import logging

@@ -208,9 +208,7 @@ def make_arpa(self, ngram_mod: str, ngram_arpa: str, force: bool):
         ]
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

-    def test_perplexity(
-        self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str
-    ) -> str:
+    def test_perplexity(self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str) -> str:
         """
         Tests the perplexity of a given ngram model on a test file.
@@ -266,9 +264,7 @@ def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str

     tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)

-    ps = subprocess.Popen(
-        " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,
-    )
+    ps = subprocess.Popen(" ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,)

     kenlm_utils.iter_files(
         source_path=[text_file],
@@ -283,9 +279,7 @@ def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str
     exit_code = ps.returncode

     command = " ".join(sh_args)
-    assert (
-        exit_code == 0
-    ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
+    assert exit_code == 0, f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
     return test_far
@@ -307,7 +301,7 @@ def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
         logging.info("File " + ngram_kenlm + " exists. Skipping.")
         return None
     else:
-        sh_args = [os.path.join(kenlm_bin_path,"build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm]
+        sh_args = [os.path.join(kenlm_bin_path, "build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm]
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)

@@ -338,7 +332,7 @@ def make_symbol_list(nemo_model_file, symbols, force):
             "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name."
         )
         asr_model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu'))
-
+
         if isinstance(asr_model.decoder, RNNTDecoder):
             vocab_size = asr_model.decoder.blank_idx
         else:
             vocab_size = len(asr_model.decoder.vocabulary)
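Taken together, the series changes two call conventions worth noting: kenlm_bin_path is now treated as the directory holding the KenLM binaries (build_binary is joined onto it), and test_file may be a comma-separated list scored one file at a time. A hypothetical driver sketch; all paths are invented for illustration:

    import os
    import subprocess

    # Invented paths for illustration only.
    kenlm_bin_path = "/workspace/kenlm/build/bin"  # a directory, not a binary
    ngram_arpa = "/results/merged_lm.tmp.arpa"
    ngram_kenlm = "/results/merged_lm.kenlm"
    test_file = "dev_clean.txt,dev_other.txt"      # comma-separated test sets

    # make_kenlm now resolves build_binary inside kenlm_bin_path:
    sh_args = [os.path.join(kenlm_bin_path, "build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm]
    if os.path.isdir(kenlm_bin_path):  # only invoke KenLM if it is installed
        subprocess.run(sh_args, check=True)

    # main() now reports perplexity once per test file:
    for test_f in test_file.split(","):
        print("perplexity would be computed for", test_f)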