rnnt and char utils #6971

Merged: 5 commits, Jul 14, 2023
18 changes: 10 additions & 8 deletions scripts/asr_language_modeling/ngram_lm/kenlm_utils.py
@@ -79,19 +79,21 @@ def setup_tokenizer(nemo_model_file):
         )
         model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu'))
 
-    if type(model.tokenizer).__name__ == 'AggregateTokenizer':
-        is_aggregate_tokenizer = True
-    else:
-        is_aggregate_tokenizer = False
-
+    is_aggregate_tokenizer = False
+    tokenizer_nemo = None
     encoding_level = SUPPORTED_MODELS.get(type(model).__name__, None)
     if not encoding_level:
         logging.warning(
             f"Model type '{type(model).__name__}' may not be supported. Would try to train a char-level LM."
         )
         encoding_level = 'char'
 
-    tokenizer_nemo = model.tokenizer
+    if encoding_level == 'subword':
+        if type(model.tokenizer).__name__ == 'AggregateTokenizer':
+            is_aggregate_tokenizer = True
+
+        tokenizer_nemo = model.tokenizer
 
     del model
 
     return tokenizer_nemo, encoding_level, is_aggregate_tokenizer
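
For context, after this change a char-level model never touches `model.tokenizer`: the aggregate-tokenizer check and the tokenizer handle are assigned only on the subword path. A hedged sketch of the resulting return contract; the model names are illustrative and assumed to resolve to a char-level and a subword entry in SUPPORTED_MODELS:

```python
# Hedged sketch of setup_tokenizer's return values after this PR.
# Assumes it is run from scripts/asr_language_modeling/ngram_lm so kenlm_utils imports.
from kenlm_utils import setup_tokenizer

tok, level, is_agg = setup_tokenizer("stt_en_quartznet15x5")               # assumed char-level CTC model
# -> level == 'char', tok is None, is_agg is False

tok, level, is_agg = setup_tokenizer("stt_en_conformer_transducer_large")  # assumed subword (BPE) model
# -> level == 'subword', tok is the model's tokenizer, is_agg is True only for an AggregateTokenizer
```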
@@ -117,10 +119,10 @@ def iter_files(source_path, dest_path, tokenizer, encoding_level, is_aggregate_tokenizer
     if isinstance(dest_path, str):
         with open(dest_path, 'w', encoding='utf-8') as f:
             for line in dataset:
-                f.write(line + "\n")
+                f.write(line[0] + "\n")
     else:  # write to stdin of KenLM
         for line in dataset:
-            dest_path.write((line + '\n').encode())
+            dest_path.write((line[0] + '\n').encode())
 
 
 def read_train_file(
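
For context, these helpers are driven the same way farcompile does further down in this PR; a hedged usage sketch for building a tokenized training corpus, with an illustrative model name and file paths:

```python
# Hedged sketch: tokenize a text corpus with the helpers touched in this file.
# Assumes it is run from scripts/asr_language_modeling/ngram_lm so kenlm_utils imports.
import kenlm_utils

tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(
    "stt_en_conformer_transducer_large"    # illustrative pretrained model name
)
kenlm_utils.iter_files(
    source_path=["train_corpus.txt"],      # plain text, one utterance per line
    dest_path="train_corpus.tokens.txt",   # a str dest_path takes the open()/f.write branch above
    tokenizer=tokenizer,
    encoding_level=encoding_level,
    is_aggregate_tokenizer=is_aggregate_tokenizer,
    verbose=1,
)
```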
83 changes: 39 additions & 44 deletions scripts/asr_language_modeling/ngram_lm/ngram_merge.py
@@ -51,6 +51,7 @@
 import torch
 
 import nemo.collections.asr as nemo_asr
+from nemo.collections.asr.modules.rnnt import RNNTDecoder
 from nemo.collections.asr.parts.submodules.ctc_beam_decoding import DEFAULT_TOKEN_OFFSET
 from nemo.utils import logging
 
@@ -207,9 +208,7 @@ def make_arpa(self, ngram_mod: str, ngram_arpa: str, force: bool):
         ]
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)
 
-    def test_perplexity(
-        self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str, force: bool
-    ) -> str:
+    def test_perplexity(self, mod_c: str, symbols: str, test_txt: str, nemo_model_file: str, tmp_path: str) -> str:
         """
         Tests the perplexity of a given ngram model on a test file.
 
@@ -229,12 +228,12 @@ def test_perplexity(
         'Perplexity: 123.45'
         """
 
-        test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file, force)
+        test_far = farcompile(symbols, test_txt, tmp_path, nemo_model_file)
         res_p = self.perplexity(mod_c, test_far)
         return res_p
 
 
-def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str, force: bool,) -> str:
+def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str) -> str:
     """
     Compiles a text file into a FAR file using the given symbol table or tokenizer.
 
@@ -253,43 +252,35 @@ def farcompile(symbols: str, text_file: str, tmp_path: str, nemo_model_file: str
     """
     test_far = os.path.join(tmp_path, os.path.split(text_file)[1] + ".far")
 
-    if os.path.isfile(test_far) and not force:
-        logging.info("File " + test_far + " exists. Skipping.")
-        return None
-    else:
-        sh_args = [
-            "farcompilestrings",
-            "--generate_keys=10",
-            "--fst_type=compact",
-            "--symbols=" + symbols,
-            "--keep_symbols",
-            ">",
-            test_far,
-        ]
-
-        tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)
-
-        ps = subprocess.Popen(
-            " ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,
-        )
-
-        kenlm_utils.iter_files(
-            source_path=[text_file],
-            dest_path=ps.stdin,
-            tokenizer=tokenizer,
-            encoding_level=encoding_level,
-            is_aggregate_tokenizer=is_aggregate_tokenizer,
-            verbose=1,
-        )
-        stdout, stderr = ps.communicate()
+    sh_args = [
+        "farcompilestrings",
+        "--generate_keys=10",
+        "--fst_type=compact",
+        "--symbols=" + symbols,
+        "--keep_symbols",
+        ">",
+        test_far,
+    ]
+
+    tokenizer, encoding_level, is_aggregate_tokenizer = kenlm_utils.setup_tokenizer(nemo_model_file)
+
+    ps = subprocess.Popen(" ".join(sh_args), shell=True, stdin=subprocess.PIPE, stdout=sys.stdout, stderr=sys.stderr,)
+
+    kenlm_utils.iter_files(
+        source_path=[text_file],
+        dest_path=ps.stdin,
+        tokenizer=tokenizer,
+        encoding_level=encoding_level,
+        is_aggregate_tokenizer=is_aggregate_tokenizer,
+        verbose=1,
+    )
+    stdout, stderr = ps.communicate()
 
-        exit_code = ps.returncode
+    exit_code = ps.returncode
 
-        command = " ".join(sh_args)
-        assert (
-            exit_code == 0
-        ), f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
-        return test_far
+    command = " ".join(sh_args)
+    assert exit_code == 0, f"Exit_code must be 0.\n bash command: {command} \n stdout: {stdout} \n stderr: {stderr}"
+    return test_far
 
 
 def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
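
For reference, a hedged usage sketch of the simplified farcompile() above, assuming the module context of ngram_merge.py; paths and the model name are illustrative:

```python
# Hedged sketch: compile a test text file into a FAR archive for perplexity scoring.
test_far = farcompile(
    symbols="stt_en_conformer_transducer_large.syms",   # produced by make_symbol_list() below
    text_file="dev_other.txt",
    tmp_path="/tmp/ngram_merge",
    nemo_model_file="stt_en_conformer_transducer_large",
)
# test_far is then scored with the class's perplexity() helper, as in test_perplexity() above.
```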
@@ -310,7 +301,7 @@ def make_kenlm(kenlm_bin_path: str, ngram_arpa: str, force: bool):
         logging.info("File " + ngram_kenlm + " exists. Skipping.")
         return None
     else:
-        sh_args = [kenlm_bin_path, "trie", "-i", ngram_arpa, ngram_kenlm]
+        sh_args = [os.path.join(kenlm_bin_path, "build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm]
         return subprocess.run(sh_args, capture_output=False, text=True, stdout=sys.stdout, stderr=sys.stderr,)
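
The previous call treated kenlm_bin_path as if it were the binary itself; the fix joins the directory with the build_binary executable. A hedged sketch of the resulting invocation, with illustrative paths:

```python
# Hedged sketch of the corrected KenLM call (directory layout and file names are assumptions).
import os
import subprocess

kenlm_bin_path = "/workspace/kenlm/build/bin"   # directory that contains the KenLM binaries
ngram_arpa = "/tmp/merged_3gram.arpa"           # ARPA file produced by make_arpa()
ngram_kenlm = ngram_arpa + ".kenlm"             # illustrative output name

subprocess.run(
    [os.path.join(kenlm_bin_path, "build_binary"), "trie", "-i", ngram_arpa, ngram_kenlm],
    check=True,  # raise if build_binary exits non-zero
)
```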


@@ -336,12 +327,15 @@ def make_symbol_list(nemo_model_file, symbols, force):
     else:
         if nemo_model_file.endswith('.nemo'):
             asr_model = nemo_asr.models.ASRModel.restore_from(nemo_model_file, map_location=torch.device('cpu'))
-            vocab_size = len(asr_model.decoder.vocabulary)
         else:
             logging.warning(
                 "nemo_model_file does not end with .nemo, therefore trying to load a pretrained model with this name."
             )
             asr_model = nemo_asr.models.ASRModel.from_pretrained(nemo_model_file, map_location=torch.device('cpu'))
 
+        if isinstance(asr_model.decoder, RNNTDecoder):
+            vocab_size = asr_model.decoder.blank_idx
+        else:
+            vocab_size = len(asr_model.decoder.vocabulary)
 
         vocab = [chr(idx + DEFAULT_TOKEN_OFFSET) for idx in range(vocab_size)]

Review thread on the `isinstance(asr_model.decoder, RNNTDecoder)` line:

Collaborator: @titu1994 is it the best way to have the vocab size?

Collaborator: As long as it's not a TDT model or a Multiblank RNNT model, I think. Otherwise you'd have to add another value that counts the number of blank/duration ids. But this code is fine for RNNT alone until we support beam search with TDT.
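
Separately, for context: the symbol list maps every token id to a single unicode character offset by DEFAULT_TOKEN_OFFSET, so that KenLM and OpenFST can treat subword ids as words. A minimal sketch of that mapping; the offset value is an assumption here, use the constant imported above in practice:

```python
# Minimal sketch of the id-to-character encoding behind the symbol list.
DEFAULT_TOKEN_OFFSET = 100      # assumption; the real value comes from ctc_beam_decoding

token_ids = [5, 17, 42]         # e.g. output of tokenizer.text_to_ids(...)
encoded = "".join(chr(i + DEFAULT_TOKEN_OFFSET) for i in token_ids)
decoded = [ord(c) - DEFAULT_TOKEN_OFFSET for c in encoded]
assert decoded == token_ids     # the mapping is reversible, so the LM sees one "char" per token
```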
@@ -389,8 +383,9 @@ def main(
     if not symbols:
         symbols = os.path.join(out_path, os.path.split(nemo_model_file)[1] + ".syms")
         make_symbol_list(nemo_model_file, symbols, force)
-    test_p = nm.test_perplexity(mod_c, symbols, test_file, nemo_model_file, out_path, force)
-    logging.info("Perplexity summary:" + test_p)
+    for test_f in test_file.split(","):
+        test_p = nm.test_perplexity(mod_c, symbols, test_f, nemo_model_file, out_path)
+        logging.info("Perplexity summary " + test_f + " : " + test_p)
 
     logging.info("Making ARPA and Kenlm model " + arpa_c)
     out = nm.make_arpa(mod_c, arpa_c, force)
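
With this change, test_file may carry a comma-separated list and each file gets its own perplexity report. A hedged sketch of the call pattern; file names are illustrative and the surrounding names (nm, mod_c, symbols, out_path) mirror the hunk above:

```python
# Hedged sketch: per-file perplexity reporting with a comma-separated test_file value.
test_file = "dev_clean.txt,dev_other.txt"
for test_f in test_file.split(","):
    test_p = nm.test_perplexity(mod_c, symbols, test_f, nemo_model_file, out_path)
    logging.info("Perplexity summary " + test_f + " : " + test_p)
```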