diff --git a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc.py b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc.py index 59cb0cd09a58..9cd23627a746 100644 --- a/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc.py +++ b/scripts/asr_language_modeling/ngram_lm/eval_beamsearch_ngram_transducer_wb-ctc.py @@ -174,6 +174,8 @@ def merge_alignment_with_wb_hyps( alignment_tokens.append([idx, model.tokenizer.ids_to_tokens([token])[0]]) if not alignment_tokens: + for wb_hyp in wb_result: + print(f"wb_hyp: {wb_hyp.word}") return " ".join([wb_hyp.word for wb_hyp in wb_result]) @@ -523,9 +525,10 @@ def default_autocast(): with open(cfg.probs_cache_file, 'wb') as f_dump: pickle.dump(all_probs, f_dump) -################################ +################################_WB_PART_######################### wb_results = {} + if cfg.applay_context_biasing: # load context graph: context_transcripts = [] @@ -536,15 +539,7 @@ def default_autocast(): context_graph = ContextGraphCTC(blank_id=asr_model.decoder.blank_idx) context_graph.build(context_transcripts) - # # get CTC logits: - # logging.warning("Getting CTC logprobs for CTC based word boosting...") - # with autocast(): - # with torch.no_grad(): - # if isinstance(asr_model, EncDecHybridRNNTCTCModel): - # asr_model.cur_decoder = 'ctc' - # ctc_logits = asr_model.transcribe(audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True) - - # run WB search: + # run CTC based WB search: for idx, logits in tqdm(enumerate(ctc_logprobs), desc=f"CTC based word boosting...", ncols=120, total=len(ctc_logprobs)): # try: wb_result = recognize_wb( @@ -552,21 +547,19 @@ def default_autocast(): context_graph, asr_model, beam_threshold=5, # 5 - context_score=4, # 5 (4) + context_score=5, # 5 (4) keyword_thr=-5, # -5 - ctc_ali_token_weight=4.0 # 3.0 (4.0) + ctc_ali_token_weight=3 # 3.0 (4.0) ) # except: # logging.warning("-------------------------") # logging.warning(f"audio file is: {audio_file_paths[idx]}") wb_results[audio_file_paths[idx]] = wb_result - print(audio_file_paths[idx] + "\n") + # print(audio_file_paths[idx] + "\n") - # get RNNT results: - -################################ +################################_WB_PART_######################### # sort all_probs according to length: if cfg.sort_logits: @@ -602,144 +595,6 @@ def default_autocast(): ) logging.info(f"Greedy batch WER/CER = {candidate_wer:.2%}/{candidate_cer:.2%}") - # asr_model = asr_model.to('cpu') - - # # 'greedy_batch' decoding_strategy would skip the beam search decoding - # if cfg.decoding_strategy in ["beam", "tsd", "alsd", "maes"]: - # if cfg.beam_width is None or cfg.beam_alpha is None: - # raise ValueError("beam_width and beam_alpha are needed to perform beam search decoding.") - # params = { - # 'beam_width': cfg.beam_width, - # 'beam_alpha': cfg.beam_alpha, - # 'maes_prefix_alpha': cfg.maes_prefix_alpha, - # 'maes_expansion_gamma': cfg.maes_expansion_gamma, - # 'hat_ilm_weight': cfg.hat_ilm_weight, - # } - # hp_grid = ParameterGrid(params) - # hp_grid = list(hp_grid) - - # best_wer_beam_size, best_cer_beam_size = None, None - # best_wer_alpha, best_cer_alpha = None, None - # best_wer, best_cer = 1e6, 1e6 - - # logging.info( - # f"==============================Starting the {cfg.decoding_strategy} decoding===============================" - # ) - # logging.info(f"Grid search size: {len(hp_grid)}") - # logging.info(f"It may take some time...") - # logging.info(f"==============================================================================================") - - # # context biasing: - # if cfg.context_file: - # context_transcripts = [] - # for line in open(cfg.context_file).readlines(): - # word = line.strip().lower() - # context_transcripts.append(asr_model.tokenizer.text_to_ids(word)) - # # for word in cfg.context_str.split('_'): - # # context_transcripts.append(asr_model.tokenizer.text_to_ids(word)) - # cfg.decoding.cb_score = cfg.context_score - # cfg.decoding.cb_words = context_transcripts - - # cfg.decoding.softmax_temperature = cfg.softmax_temperature - - # # logging.warning(f"{context_transcripts}") - # # raise Exception - - # # # with bpe dropout: - # # kwl_set = set() - # # context_transcripts = [] - - # # sow_symbol = asr_model.tokenizer.tokens_to_ids(['▁'])[0] - - # # for line in open(cfg.context_file).readlines(): - # # word = line.strip().lower() - # # tokenization = asr_model.tokenizer.tokenizer.encode(word) # , out_type=str - # # kwl_set.add(str(tokenization)) - # # context_transcripts.append(tokenization) - - # # for _ in range(50): - # # tokenization = asr_model.tokenizer.tokenizer.encode(word, enable_sampling=True, alpha=0.1, nbest_size=-1) - # # if tokenization[0] != sow_symbol: - # # tokenization_str = str(tokenization) - # # if tokenization_str not in kwl_set: - # # kwl_set.add(tokenization_str) - # # context_transcripts.append(tokenization) - - # # cfg.decoding.cb_score = cfg.context_score - # # cfg.decoding.cb_words = context_transcripts - - - - # if cfg.preds_output_folder and not os.path.exists(cfg.preds_output_folder): - # os.mkdir(cfg.preds_output_folder) - # for hp in hp_grid: - # if cfg.preds_output_folder: - # results_file = f"preds_out_{cfg.decoding_strategy}_bw{hp['beam_width']}" - # if cfg.decoding_strategy == "maes": - # results_file = f"{results_file}_ma{hp['maes_prefix_alpha']}_mg{hp['maes_expansion_gamma']}" - # if cfg.kenlm_model_file: - # results_file = f"{results_file}_ba{hp['beam_alpha']}" - # if cfg.hat_subtract_ilm: - # results_file = f"{results_file}_hat_ilmw{hp['hat_ilm_weight']}" - # preds_output_file = os.path.join(cfg.preds_output_folder, f"{results_file}.tsv") - # preds_output_manifest = os.path.join(cfg.preds_output_folder, f"recognition_results.json") - # else: - # preds_output_file = None - - # cfg.decoding.beam_size = hp["beam_width"] - # cfg.decoding.ngram_lm_alpha = hp["beam_alpha"] - # cfg.decoding.maes_prefix_alpha = hp["maes_prefix_alpha"] - # cfg.decoding.maes_expansion_gamma = hp["maes_expansion_gamma"] - # cfg.decoding.hat_ilm_weight = hp["hat_ilm_weight"] - - # candidate_wer, candidate_cer = decoding_step( - # asr_model, - # cfg, - # all_probs=all_probs, - # target_transcripts=target_transcripts, - # audio_file_paths=audio_file_paths, - # durations=durations, - # preds_output_file=preds_output_file, - # preds_output_manifest=preds_output_manifest, - # beam_batch_size=cfg.beam_batch_size, - # progress_bar=True, - # ) - - # if candidate_cer < best_cer: - # best_cer_beam_size = hp["beam_width"] - # best_cer_alpha = hp["beam_alpha"] - # best_cer_ma = hp["maes_prefix_alpha"] - # best_cer_mg = hp["maes_expansion_gamma"] - # best_cer_hat_ilm_weight = hp["hat_ilm_weight"] - # best_cer = candidate_cer - - # if candidate_wer < best_wer: - # best_wer_beam_size = hp["beam_width"] - # best_wer_alpha = hp["beam_alpha"] - # best_wer_ma = hp["maes_prefix_alpha"] - # best_wer_ga = hp["maes_expansion_gamma"] - # best_wer_hat_ilm_weight = hp["hat_ilm_weight"] - # best_wer = candidate_wer - - # wer_hat_parameter = "" - # if cfg.hat_subtract_ilm: - # wer_hat_parameter = f"HAT ilm weight = {best_wer_hat_ilm_weight}, " - # logging.info( - # f'Best WER Candidate = {best_wer:.2%} :: Beam size = {best_wer_beam_size}, ' - # f'Beam alpha = {best_wer_alpha}, {wer_hat_parameter}' - # f'maes_prefix_alpha = {best_wer_ma}, maes_expansion_gamma = {best_wer_ga} ' - # ) - - # cer_hat_parameter = "" - # if cfg.hat_subtract_ilm: - # cer_hat_parameter = f"HAT ilm weight = {best_cer_hat_ilm_weight}" - # logging.info( - # f'Best CER Candidate = {best_cer:.2%} :: Beam size = {best_cer_beam_size}, ' - # f'Beam alpha = {best_cer_alpha}, {cer_hat_parameter} ' - # f'maes_prefix_alpha = {best_cer_ma}, maes_expansion_gamma = {best_cer_mg}' - # ) - # logging.info(f"=================================================================================") - if __name__ == '__main__': main() diff --git a/scripts/asr_language_modeling/ngram_lm/word_boosting_search.py b/scripts/asr_language_modeling/ngram_lm/word_boosting_search.py index f2a56403d9e2..7d8c3326b808 100644 --- a/scripts/asr_language_modeling/ngram_lm/word_boosting_search.py +++ b/scripts/asr_language_modeling/ngram_lm/word_boosting_search.py @@ -141,6 +141,7 @@ def filter_wb_hyps(best_hyp_list, word_alignment): li, ri = item[1], item[2] if li <= lh <= ri or li <= rh <= ri or lh <= li <= rh or lh <= ri <= rh: if hyp.score >= item[3]: + # if hyp.score >= item[3] and not item[0].startswith(hyp.word): best_hyp_list_new.append(hyp) current_frame = i break @@ -170,7 +171,17 @@ def filter_wb_hyps(best_hyp_list, word_alignment): # return best_hyp_list_new -def recognize_wb(logprobs, context_graph, asr_model, beam_threshold=None, context_score=0.0, keyword_thr=-3, ctc_ali_token_weight=2.0): +def recognize_wb( + logprobs, + context_graph, + asr_model, + beam_threshold=None, + context_score=0.0, + keyword_thr=-3, + ctc_ali_token_weight=2.0, + print_results=False + ): + start_state = context_graph.root active_tokens = [] next_tokens = [] @@ -231,16 +242,18 @@ def recognize_wb(logprobs, context_graph, asr_model, beam_threshold=None, contex # find best hyp for spotted keywords: best_hyp_list = find_best_hyp(spotted_words) - print(f"---spotted words:") - for hyp in best_hyp_list: - print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}") + if print_results: + print(f"---spotted words:") + for hyp in best_hyp_list: + print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}") # filter wb hyps according to greedy ctc predictions ctc_word_alignment = get_ctc_word_alignment(logprobs, asr_model, token_weight=ctc_ali_token_weight) best_hyp_list_new = filter_wb_hyps(best_hyp_list, ctc_word_alignment) - print("---final result is:") - for hyp in best_hyp_list_new: - print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}") + if print_results: + print("---final result is:") + for hyp in best_hyp_list_new: + print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}") return best_hyp_list_new \ No newline at end of file