add token level lm boosting
andrusenkoau committed Jul 20, 2023
1 parent 68ae3fc commit 63a93be
Showing 7 changed files with 147 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/asr/speech_to_text_eval.py
@@ -75,6 +75,7 @@
class EvaluationConfig(transcribe_speech.TranscriptionConfig):
dataset_manifest: str = MISSING
output_filename: Optional[str] = "evaluation_transcripts.json"
decoder_type: Optional[str] = None

use_cer: bool = False
tolerance: Optional[float] = None
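For context: the new decoder_type field lets the evaluation config choose which head of a hybrid CTC-RNNT model to decode with. A minimal sketch of how such an option is typically consumed, assuming a hybrid NeMo model (the pretrained model name is illustrative, not part of this commit):

    # Sketch: pick the decoding head of a hybrid CTC-RNNT model.
    # The model name is illustrative; hybrid models accept a decoder_type
    # argument on change_decoding_strategy().
    import nemo.collections.asr as nemo_asr

    asr_model = nemo_asr.models.ASRModel.from_pretrained("stt_en_fastconformer_hybrid_large_pc")
    asr_model.change_decoding_strategy(decoding_cfg=None, decoder_type="rnnt")  # or "ctc"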
3 changes: 2 additions & 1 deletion nemo/collections/asr/metrics/rnnt_wer.py
@@ -193,7 +193,7 @@ class AbstractRNNTDecoding(ConfidenceMixin):
blank_id: The id of the RNNT blank token.
"""

def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
def __init__(self, decoding_cfg, decoder, joint, blank_id: int, tokenizer=None):
super(AbstractRNNTDecoding, self).__init__()

# Convert dataclass to config object
@@ -366,6 +366,7 @@ def __init__(self, decoding_cfg, decoder, joint, blank_id: int):
ngram_lm_alpha=self.cfg.beam.get('ngram_lm_alpha', 0.0),
hat_subtract_ilm=self.cfg.beam.get('hat_subtract_ilm', False),
hat_ilm_weight=self.cfg.beam.get('hat_ilm_weight', 0.0),
tokenizer=tokenizer,
)

else:
2 changes: 1 addition & 1 deletion nemo/collections/asr/metrics/rnnt_wer_bpe.py
@@ -200,7 +200,7 @@ def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
self.tokenizer = tokenizer

super(RNNTBPEDecoding, self).__init__(
decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id + joint.num_extra_outputs
decoding_cfg=decoding_cfg, decoder=decoder, joint=joint, blank_id=blank_id + joint.num_extra_outputs, tokenizer=tokenizer
)

if isinstance(self.decoding, rnnt_beam_decoding.BeamRNNTInfer):
20 changes: 15 additions & 5 deletions nemo/collections/asr/parts/submodules/rnnt_beam_decoding.py
@@ -46,6 +46,8 @@
from nemo.core.neural_types import AcousticEncodedRepresentation, HypothesisType, LengthsType, NeuralType
from nemo.utils import logging

from nemo.collections.common.tokenizers.tokenizer_spec import TokenizerSpec

try:
import kenlm

@@ -237,9 +239,15 @@ def __init__(
ngram_lm_alpha: float = 0.0,
hat_subtract_ilm: bool = False,
hat_ilm_weight: float = 0.0,
tokenizer: TokenizerSpec = None,
):
self.decoder = decoder_model
self.joint = joint_model
self.tokenizer = tokenizer

self.blank = decoder_model.blank_idx
self.vocab_size = decoder_model.vocab_size
@@ -1461,12 +1469,14 @@ def compute_ngram_score(self, current_lm_state: "kenlm.State", label: int) -> Tuple[float, "kenlm.State"]:
Score computation for kenlm ngram language model.
"""

if self.token_offset:
label = chr(label + self.token_offset)
else:
label = str(label)
label = self.tokenizer.ids_to_tokens([label])

next_state = kenlm.State()
lm_score = self.ngram_lm.BaseScore(current_lm_state, label, next_state)
lm_score = self.ngram_lm.BaseScore(current_lm_state, label[0], next_state)
lm_score *= 1.0 / np.log10(np.e)

return lm_score, next_state
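The substantive change above: compute_ngram_score now feeds KenLM the literal subword token string (via tokenizer.ids_to_tokens) instead of the chr(id + offset) surrogate, so the n-gram model must be trained on token text. A self-contained sketch of the scoring step, assuming a KenLM binary trained on space-separated subword tokens (the path and sample token are illustrative):

    # Sketch of token-level KenLM scoring; path and sample token are illustrative.
    import kenlm
    import numpy as np

    lm = kenlm.Model("token_level_lm.bin")  # assumed: trained on subword-token text

    def score_token(current_state, token):
        # BaseScore returns a log10 probability and fills in the successor state.
        next_state = kenlm.State()
        log10_prob = lm.BaseScore(current_state, token, next_state)
        # Convert log10 to natural log, exactly as the decoder does above.
        return log10_prob / np.log10(np.e), next_state

    state = kenlm.State()
    lm.BeginSentenceWrite(state)
    lm_score, state = score_token(state, "▁hello")  # a sample SentencePiece token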
99 changes: 99 additions & 0 deletions scripts/asr_language_modeling/ngram_lm/compute_key_words_fscore.py
@@ -0,0 +1,99 @@
#!/usr/bin/env python

import argparse
import json
import os

from kaldialign import align


def load_data(manifest):
    data = []
    with open(manifest, 'r') as f:
        for line in f:
            item = json.loads(line)
            data.append(item)
    return data


def print_alignment(audio_filepath, ali, key_words):
    ref, hyp = [], []
    for pair in ali:
        if pair[0] in key_words:
            ref.append(pair[0].upper())
            hyp.append(pair[1].upper())
        else:
            ref.append(pair[0])
            hyp.append(pair[1])
    print(" ")
    print(f"ID: {os.path.basename(audio_filepath)}")
    print(f"REF: {' '.join(ref)}")
    print(f"HYP: {' '.join(hyp)}")


def compute_fscore(recognition_results_manifest, key_words_list):
    data = load_data(recognition_results_manifest)
    key_words_set = set(key_words_list)
    key_words_stat = {}
    for word in key_words_set:
        key_words_stat[word] = [0, 0]

    gt, fp, tp = 0, 0, 0
    eps = '***'

    for item in data:
        audio_filepath = item['audio_filepath']
        ref = item['text'].split()
        hyp = item['pred_text'].split()
        ali = align(ref, hyp, eps)
        recognized_words = []
        for pair in ali:
            if pair[0] in key_words_set:
                gt += 1
                key_words_stat[pair[0]][-1] += 1
                if pair[0] == pair[1]:
                    tp += 1
                    recognized_words.append(pair[0])
                    key_words_stat[pair[0]][0] += 1
            if pair[1] in key_words_set:
                if pair[0] != pair[1]:
                    fp += 1
        if recognized_words:
            print_alignment(audio_filepath, ali, recognized_words)

    precision = tp / (tp + fp + 1e-8)
    recall = tp / (gt + 1e-8)
    fscore = 2 * (precision * recall) / (precision + recall + 1e-8)

    print("\n" + "***" * 15)
    print("Per-word statistics (word: correct/total):\n")
    max_len = max(len(x) for x in key_words_stat)
    for word in key_words_stat:
        print(f"{word:<{max_len}}: {key_words_stat[word][0]}/{key_words_stat[word][-1]}")
    print("***" * 15)

    print(" ")
    print("***" * 10)
    print(f"Precision: {precision:.4f} ({tp}/{tp + fp})")
    print(f"Recall: {recall:.4f} ({tp}/{gt})")
    print(f"Fscore: {fscore:.4f}")
    print("***" * 10)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--input_manifest", type=str, required=True, help="manifest with recognition results",
    )
    parser.add_argument(
        "--key_words_list", type=str, required=True, help="space-separated list of key words for F-score calculation",
    )

    args = parser.parse_args()
    key_words_list = args.key_words_list.split()
    compute_fscore(args.input_manifest, key_words_list)


if __name__ == '__main__':
    main()
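A toy illustration of the kaldialign-based counting the script performs (the word lists are made up):

    # Toy example of alignment-based keyword counting.
    from kaldialign import align

    eps = '***'
    ref = "please call stella tomorrow".split()
    hyp = "please call stella tomorow".split()
    key_words = {"stella"}

    pairs = align(ref, hyp, eps)  # [(ref_word, hyp_word), ...], eps marks gaps
    tp = sum(1 for r, h in pairs if r in key_words and r == h)
    gt = sum(1 for r, h in pairs if r in key_words)
    print(f"keyword recall: {tp}/{gt}")  # 1/1 -- "stella" was recognized

The script itself is run with --input_manifest pointing at the recognition manifest produced by the evaluation script and --key_words_list passed as a single space-separated string.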
@@ -100,6 +100,9 @@ class EvalBeamSearchNGramConfig:
device: str = "cuda" # The device to load the model onto to calculate log probabilities
use_amp: bool = False # Whether to use AMP if available to calculate log probabilities
num_workers: int = 1 # Number of workers for DataLoader

# for hybrid model
decoder_type: Optional[str] = None # [ctc, rnnt] Decoder type for hybrid ctc-rnnt model

# The decoding scheme to be used for evaluation
decoding_strategy: str = "greedy_batch" # ["greedy_batch", "beam", "tsd", "alsd", "maes"]
@@ -126,6 +129,8 @@ def decoding_step(
cfg: EvalBeamSearchNGramConfig,
all_probs: List[torch.Tensor],
target_transcripts: List[str],
audio_file_paths: List[str],
durations: List[float],
preds_output_file: str = None,
beam_batch_size: int = 128,
progress_bar: bool = True,
@@ -205,10 +210,20 @@
cer_dist_first += cer_dist

score = candidate.score
if preds_output_file:
out_file.write('{}\t{}\n'.format(pred_text, score))

wer_dist_best += wer_dist_min
cer_dist_best += cer_dist_min

# write manifest with prediction results
if preds_output_file:
item = {'audio_filepath': audio_file_paths[sample_idx + beams_idx],
'duration': durations[sample_idx + beams_idx],
'text': target_transcripts[sample_idx + beams_idx],
'pred_text': pred_text,
'wer': wer_dist}  # per-utterance word-error count
out_file.write(json.dumps(item) + "\n")

sample_idx += len(probs_batch)

if cfg.decoding_strategy == "greedy_batch":
@@ -272,6 +287,7 @@ def main(cfg: EvalBeamSearchNGramConfig):
cfg.maes_prefix_alpha, cfg.maes_expansion_gamma, cfg.hat_ilm_weight = [0], [0], [0]

target_transcripts = []
durations = []
manifest_dir = Path(cfg.input_manifest).parent
with open(cfg.input_manifest, 'r', encoding='utf_8') as manifest_file:
audio_file_paths = []
@@ -281,6 +297,7 @@
if not audio_file.is_file() and not audio_file.is_absolute():
audio_file = manifest_dir / audio_file
target_transcripts.append(data['text'])
durations.append(data['duration'])
audio_file_paths.append(str(audio_file.absolute()))

if cfg.probs_cache_file and os.path.exists(cfg.probs_cache_file):
@@ -348,13 +365,17 @@ def default_autocast():

if cfg.decoding_strategy == "greedy_batch":
asr_model = asr_model.to('cpu')
preds_output_file = os.path.join(cfg.preds_output_folder, "recognition_results.json")
candidate_wer, candidate_cer = decoding_step(
asr_model,
cfg,
all_probs=all_probs,
target_transcripts=target_transcripts,
audio_file_paths=audio_file_paths,
durations=durations,
beam_batch_size=cfg.beam_batch_size,
progress_bar=True,
preds_output_file=preds_output_file,
)
logging.info(f"Greedy batch WER/CER = {candidate_wer:.2%}/{candidate_cer:.2%}")

@@ -396,7 +417,7 @@ def default_autocast():
results_file = f"{results_file}_ba{hp['beam_alpha']}"
if cfg.hat_subtract_ilm:
results_file = f"{results_file}_hat_ilmw{hp['hat_ilm_weight']}"
preds_output_file = os.path.join(cfg.preds_output_folder, f"{results_file}.tsv")
preds_output_file = os.path.join(cfg.preds_output_folder, "recognition_results.json")
else:
preds_output_file = None

Expand All @@ -411,6 +432,8 @@ def default_autocast():
cfg,
all_probs=all_probs,
target_transcripts=target_transcripts,
audio_file_paths=audio_file_paths,
durations=durations,
preds_output_file=preds_output_file,
beam_batch_size=cfg.beam_batch_size,
progress_bar=True,
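With this change the evaluation script emits one JSON object per utterance, matching the NeMo manifest convention that compute_key_words_fscore.py above consumes. A sketch of reading the output back (field values are whatever the run produced):

    # Sketch: iterate over the recognition_results.json manifest written above.
    import json

    with open("recognition_results.json") as f:
        for line in f:
            item = json.loads(line)
            # keys written per utterance: audio_filepath, duration, text, pred_text, wer
            print(item["audio_filepath"], item["pred_text"])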
5 changes: 3 additions & 2 deletions scripts/asr_language_modeling/ngram_lm/kenlm_utils.py
@@ -174,8 +174,9 @@ def read_train_file(
def tokenize_str(texts, tokenizer):
tokenized_text = []
for text in texts:
tok_text = tokenizer.text_to_ids(*text)
tok_text = [chr(token + DEFAULT_TOKEN_OFFSET) for token in tok_text]
tok_text = tokenizer.text_to_tokens(*text)
tokenized_text.append(tok_text)
return tokenized_text

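This mirrors the decoder-side change: the KenLM training corpus is now built from the tokenizer's literal subword strings rather than chr-offset characters, so training and beam-search scoring share one vocabulary. A sketch contrasting the two encodings (ToyTokenizer is a stand-in for a real SentencePiece-based TokenizerSpec):

    # Old chr-offset encoding vs. new token-text encoding of one corpus line.
    DEFAULT_TOKEN_OFFSET = 100  # offset value used by NeMo's kenlm_utils

    class ToyTokenizer:
        vocab = ['▁hello', '▁world']
        def text_to_tokens(self, text):
            return ['▁' + w for w in text.split()]
        def text_to_ids(self, text):
            return [self.vocab.index(t) for t in self.text_to_tokens(text)]

    tokenizer = ToyTokenizer()
    ids = tokenizer.text_to_ids("hello world")                       # [0, 1]
    old_line = " ".join(chr(i + DEFAULT_TOKEN_OFFSET) for i in ids)  # opaque: "d e"
    new_line = " ".join(tokenizer.text_to_tokens("hello world"))     # readable: "▁hello ▁world"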
