Commit 7b5f447: some fixes
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Oct 27, 2023
1 parent 3c1ab2d commit 7b5f447
Showing 2 changed files with 29 additions and 161 deletions.
@@ -174,6 +174,8 @@ def merge_alignment_with_wb_hyps(
alignment_tokens.append([idx, model.tokenizer.ids_to_tokens([token])[0]])

if not alignment_tokens:
for wb_hyp in wb_result:
print(f"wb_hyp: {wb_hyp.word}")
return " ".join([wb_hyp.word for wb_hyp in wb_result])


@@ -523,9 +525,10 @@ def default_autocast():
with open(cfg.probs_cache_file, 'wb') as f_dump:
pickle.dump(all_probs, f_dump)

################################
################################_WB_PART_#########################

wb_results = {}

if cfg.applay_context_biasing:
# load context graph:
context_transcripts = []
@@ -536,37 +539,27 @@ def default_autocast():
context_graph = ContextGraphCTC(blank_id=asr_model.decoder.blank_idx)
context_graph.build(context_transcripts)

# # get CTC logits:
# logging.warning("Getting CTC logprobs for CTC based word boosting...")
# with autocast():
# with torch.no_grad():
# if isinstance(asr_model, EncDecHybridRNNTCTCModel):
# asr_model.cur_decoder = 'ctc'
# ctc_logits = asr_model.transcribe(audio_file_paths, batch_size=cfg.acoustic_batch_size, logprobs=True)

# run WB search:
# run CTC based WB search:
for idx, logits in tqdm(enumerate(ctc_logprobs), desc=f"CTC based word boosting...", ncols=120, total=len(ctc_logprobs)):
# try:
wb_result = recognize_wb(
logits.numpy(),
context_graph,
asr_model,
beam_threshold=5, # 5
context_score=4, # 5 (4)
context_score=5, # 5 (4)
keyword_thr=-5, # -5
ctc_ali_token_weight=4.0 # 3.0 (4.0)
ctc_ali_token_weight=3 # 3.0 (4.0)
)
# except:
# logging.warning("-------------------------")
# logging.warning(f"audio file is: {audio_file_paths[idx]}")
wb_results[audio_file_paths[idx]] = wb_result
print(audio_file_paths[idx] + "\n")
# print(audio_file_paths[idx] + "\n")


# get RNNT results:


################################
################################_WB_PART_#########################

# sort all_probs according to length:
if cfg.sort_logits:
@@ -602,144 +595,6 @@ def default_autocast():
)
logging.info(f"Greedy batch WER/CER = {candidate_wer:.2%}/{candidate_cer:.2%}")

# asr_model = asr_model.to('cpu')

# # 'greedy_batch' decoding_strategy would skip the beam search decoding
# if cfg.decoding_strategy in ["beam", "tsd", "alsd", "maes"]:
# if cfg.beam_width is None or cfg.beam_alpha is None:
# raise ValueError("beam_width and beam_alpha are needed to perform beam search decoding.")
# params = {
# 'beam_width': cfg.beam_width,
# 'beam_alpha': cfg.beam_alpha,
# 'maes_prefix_alpha': cfg.maes_prefix_alpha,
# 'maes_expansion_gamma': cfg.maes_expansion_gamma,
# 'hat_ilm_weight': cfg.hat_ilm_weight,
# }
# hp_grid = ParameterGrid(params)
# hp_grid = list(hp_grid)

# best_wer_beam_size, best_cer_beam_size = None, None
# best_wer_alpha, best_cer_alpha = None, None
# best_wer, best_cer = 1e6, 1e6

# logging.info(
# f"==============================Starting the {cfg.decoding_strategy} decoding==============================="
# )
# logging.info(f"Grid search size: {len(hp_grid)}")
# logging.info(f"It may take some time...")
# logging.info(f"==============================================================================================")

# # context biasing:
# if cfg.context_file:
# context_transcripts = []
# for line in open(cfg.context_file).readlines():
# word = line.strip().lower()
# context_transcripts.append(asr_model.tokenizer.text_to_ids(word))
# # for word in cfg.context_str.split('_'):
# # context_transcripts.append(asr_model.tokenizer.text_to_ids(word))
# cfg.decoding.cb_score = cfg.context_score
# cfg.decoding.cb_words = context_transcripts

# cfg.decoding.softmax_temperature = cfg.softmax_temperature

# # logging.warning(f"{context_transcripts}")
# # raise Exception

# # # with bpe dropout:
# # kwl_set = set()
# # context_transcripts = []

# # sow_symbol = asr_model.tokenizer.tokens_to_ids(['▁'])[0]

# # for line in open(cfg.context_file).readlines():
# # word = line.strip().lower()
# # tokenization = asr_model.tokenizer.tokenizer.encode(word) # , out_type=str
# # kwl_set.add(str(tokenization))
# # context_transcripts.append(tokenization)

# # for _ in range(50):
# # tokenization = asr_model.tokenizer.tokenizer.encode(word, enable_sampling=True, alpha=0.1, nbest_size=-1)
# # if tokenization[0] != sow_symbol:
# # tokenization_str = str(tokenization)
# # if tokenization_str not in kwl_set:
# # kwl_set.add(tokenization_str)
# # context_transcripts.append(tokenization)

# # cfg.decoding.cb_score = cfg.context_score
# # cfg.decoding.cb_words = context_transcripts



# if cfg.preds_output_folder and not os.path.exists(cfg.preds_output_folder):
# os.mkdir(cfg.preds_output_folder)
# for hp in hp_grid:
# if cfg.preds_output_folder:
# results_file = f"preds_out_{cfg.decoding_strategy}_bw{hp['beam_width']}"
# if cfg.decoding_strategy == "maes":
# results_file = f"{results_file}_ma{hp['maes_prefix_alpha']}_mg{hp['maes_expansion_gamma']}"
# if cfg.kenlm_model_file:
# results_file = f"{results_file}_ba{hp['beam_alpha']}"
# if cfg.hat_subtract_ilm:
# results_file = f"{results_file}_hat_ilmw{hp['hat_ilm_weight']}"
# preds_output_file = os.path.join(cfg.preds_output_folder, f"{results_file}.tsv")
# preds_output_manifest = os.path.join(cfg.preds_output_folder, f"recognition_results.json")
# else:
# preds_output_file = None

# cfg.decoding.beam_size = hp["beam_width"]
# cfg.decoding.ngram_lm_alpha = hp["beam_alpha"]
# cfg.decoding.maes_prefix_alpha = hp["maes_prefix_alpha"]
# cfg.decoding.maes_expansion_gamma = hp["maes_expansion_gamma"]
# cfg.decoding.hat_ilm_weight = hp["hat_ilm_weight"]

# candidate_wer, candidate_cer = decoding_step(
# asr_model,
# cfg,
# all_probs=all_probs,
# target_transcripts=target_transcripts,
# audio_file_paths=audio_file_paths,
# durations=durations,
# preds_output_file=preds_output_file,
# preds_output_manifest=preds_output_manifest,
# beam_batch_size=cfg.beam_batch_size,
# progress_bar=True,
# )

# if candidate_cer < best_cer:
# best_cer_beam_size = hp["beam_width"]
# best_cer_alpha = hp["beam_alpha"]
# best_cer_ma = hp["maes_prefix_alpha"]
# best_cer_mg = hp["maes_expansion_gamma"]
# best_cer_hat_ilm_weight = hp["hat_ilm_weight"]
# best_cer = candidate_cer

# if candidate_wer < best_wer:
# best_wer_beam_size = hp["beam_width"]
# best_wer_alpha = hp["beam_alpha"]
# best_wer_ma = hp["maes_prefix_alpha"]
# best_wer_ga = hp["maes_expansion_gamma"]
# best_wer_hat_ilm_weight = hp["hat_ilm_weight"]
# best_wer = candidate_wer

# wer_hat_parameter = ""
# if cfg.hat_subtract_ilm:
# wer_hat_parameter = f"HAT ilm weight = {best_wer_hat_ilm_weight}, "
# logging.info(
# f'Best WER Candidate = {best_wer:.2%} :: Beam size = {best_wer_beam_size}, '
# f'Beam alpha = {best_wer_alpha}, {wer_hat_parameter}'
# f'maes_prefix_alpha = {best_wer_ma}, maes_expansion_gamma = {best_wer_ga} '
# )

# cer_hat_parameter = ""
# if cfg.hat_subtract_ilm:
# cer_hat_parameter = f"HAT ilm weight = {best_cer_hat_ilm_weight}"
# logging.info(
# f'Best CER Candidate = {best_cer:.2%} :: Beam size = {best_cer_beam_size}, '
# f'Beam alpha = {best_cer_alpha}, {cer_hat_parameter} '
# f'maes_prefix_alpha = {best_cer_ma}, maes_expansion_gamma = {best_cer_mg}'
# )
# logging.info(f"=================================================================================")


if __name__ == '__main__':
main()
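Read together, the WB part added above follows roughly the control flow sketched below. This is a non-standalone sketch, not the full script: `asr_model`, `cfg`, `ctc_logprobs`, and `audio_file_paths` are assumed to be prepared earlier in the script, the exact entries fed to `context_graph.build()` are hidden in the collapsed part of the diff (tokenized context words are assumed here), and the hyper-parameter values simply mirror this commit.

```python
from tqdm import tqdm

# Build the biasing graph from the context words (entry format assumed).
context_transcripts = []
for line in open(cfg.context_file):
    word = line.strip().lower()
    context_transcripts.append(asr_model.tokenizer.text_to_ids(word))

context_graph = ContextGraphCTC(blank_id=asr_model.decoder.blank_idx)
context_graph.build(context_transcripts)

# Run the CTC based word-boosting search per utterance.
wb_results = {}
for idx, logits in tqdm(enumerate(ctc_logprobs), total=len(ctc_logprobs)):
    wb_results[audio_file_paths[idx]] = recognize_wb(
        logits.numpy(),
        context_graph,
        asr_model,
        beam_threshold=5,
        context_score=5,
        keyword_thr=-5,
        ctc_ali_token_weight=3,
    )
```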
27 changes: 20 additions & 7 deletions scripts/asr_language_modeling/ngram_lm/word_boosting_search.py
@@ -141,6 +141,7 @@ def filter_wb_hyps(best_hyp_list, word_alignment):
li, ri = item[1], item[2]
if li <= lh <= ri or li <= rh <= ri or lh <= li <= rh or lh <= ri <= rh:
if hyp.score >= item[3]:
# if hyp.score >= item[3] and not item[0].startswith(hyp.word):
best_hyp_list_new.append(hyp)
current_frame = i
break
@@ -170,7 +171,17 @@ def filter_wb_hyps(best_hyp_list, word_alignment):
# return best_hyp_list_new


def recognize_wb(logprobs, context_graph, asr_model, beam_threshold=None, context_score=0.0, keyword_thr=-3, ctc_ali_token_weight=2.0):
def recognize_wb(
logprobs,
context_graph,
asr_model,
beam_threshold=None,
context_score=0.0,
keyword_thr=-3,
ctc_ali_token_weight=2.0,
print_results=False
):

start_state = context_graph.root
active_tokens = []
next_tokens = []
@@ -231,16 +242,18 @@ def recognize_wb(logprobs, context_graph, asr_model, beam_threshold=None, contex

# find best hyp for spotted keywords:
best_hyp_list = find_best_hyp(spotted_words)
print(f"---spotted words:")
for hyp in best_hyp_list:
print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}")
if print_results:
print(f"---spotted words:")
for hyp in best_hyp_list:
print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}")

# filter wb hyps according to greedy ctc predictions
ctc_word_alignment = get_ctc_word_alignment(logprobs, asr_model, token_weight=ctc_ali_token_weight)
best_hyp_list_new = filter_wb_hyps(best_hyp_list, ctc_word_alignment)
print("---final result is:")
for hyp in best_hyp_list_new:
print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}")
if print_results:
print("---final result is:")
for hyp in best_hyp_list_new:
print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:-.2f}")


return best_hyp_list_new
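A hedged usage sketch of the updated signature: the new `print_results` flag gates the debug printing that used to be unconditional, and the returned hypotheses expose `word`, `start_frame`, `end_frame`, and `score` as used in the prints above. Objects and values here are illustrative, not a standalone program.

```python
best_hyps = recognize_wb(
    logprobs,             # per-utterance CTC log-probabilities as a numpy array
    context_graph,        # ContextGraphCTC built from the biasing words
    asr_model,
    beam_threshold=5,
    context_score=5.0,
    keyword_thr=-5,
    ctc_ali_token_weight=3.0,
    print_results=True,   # print spotted words and the filtered final result
)
for hyp in best_hyps:
    print(f"{hyp.word}: [{hyp.start_frame};{hyp.end_frame}], score:{hyp.score:.2f}")
```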
