add ctc only model
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Jan 19, 2024
1 parent 1e11fe7 commit 227710e
Showing 5 changed files with 118 additions and 82 deletions.
18 changes: 16 additions & 2 deletions scripts/asr_context_biasing/compute_key_words_fscore.py
@@ -122,14 +122,28 @@ def main():
"--input_manifest", type=str, required=True, help="nemo manifest with recognition results in pred_text field",
)
parser.add_argument(
"--context_biasing_file", type=str, required=True, help="file of context biasing words/phrases with their spellings"
"--context_biasing_file", type=str, required=True,
help="""
file of context biasing words/phrases with their spellings \
(one word/phrase per line; spellings are separated from the word/phrase by an underscore):
WORD1_SPELLING1
WORD2_SPELLING1_SPELLING2
...
nvidia_nvidia
gpu_gpu_g p u
nvlink_nvlink_nv link
...
alternative spellings help to improve the recognition accuracy of abbreviations and complicated words,
which are often recognized as separate words by the ASR model (gpu -> g p u, tensorrt -> tensor rt, and so on).
"""
)

args = parser.parse_args()
# use list instead of dict to preserve key words order during printing word-level statistics
key_words_list = []
for line in open(args.context_biasing_file).readlines():
item = line.strip().split("-")[0].lower()
item = line.strip().split("_")[0].lower()
assert len(item) > 1, f"word/phrase {item} does not have any spelling"
if item not in key_words_list:
key_words_list.append(item)
compute_fscore(args.input_manifest, key_words_list)
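For reference, a minimal sketch of how a file in this underscore-separated format can be parsed into a word-to-spellings mapping (the file name and the dict are illustrative, not part of this commit):

    # hypothetical helper, assuming the format described in the help string above
    word2spellings = {}
    with open("context_biasing.txt") as f:  # file name is an assumption
        for line in f:
            fields = line.strip().split("_")
            word, spellings = fields[0].lower(), fields[1:]
            assert spellings, f"word/phrase {word} does not have any spelling"
            word2spellings[word] = spellings

    # e.g. the line "gpu_gpu_g p u" yields word2spellings["gpu"] == ["gpu", "g p u"]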
3 changes: 2 additions & 1 deletion scripts/asr_context_biasing/context_biasing_utils.py
@@ -25,6 +25,7 @@ def merge_alignment_with_ws_hyps(
cb_results: List[WSHyp],
decoder_type: str = "ctc",
intersection_threshold: float = 30.0,
blank_idx: int = 0,
) -> str:
"""
Merge context biasing predictions with ctc/rnnt word-level alignment.
@@ -45,7 +46,7 @@
alignment_tokens = []
prev_token = None
for idx, token in enumerate(candidate):
if token != asr_model.decoder.blank_idx:
if token != blank_idx:
if token == prev_token:
alignment_tokens[-1] = [idx, asr_model.tokenizer.ids_to_tokens([int(token)])[0]]
else:
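The loop above collapses a frame-level argmax CTC path into a token alignment: blanks are dropped, and a run of repeated non-blank tokens keeps only its last frame index. A self-contained sketch of that collapse rule (token IDs and blank_idx=0 are illustrative; the real code maps IDs to subword strings via asr_model.tokenizer):

    def collapse_ctc_path(path, blank_idx=0):
        """Collapse an argmax CTC path: drop blanks, merge repeated tokens."""
        alignment = []  # list of [frame_idx, token_id]
        prev_token = None
        for frame, token in enumerate(path):
            if token != blank_idx:
                if token == prev_token:
                    alignment[-1] = [frame, token]  # keep the last frame of a repeat run
                else:
                    alignment.append([frame, token])
            prev_token = token
        return alignment

    # e.g. collapse_ctc_path([0, 5, 5, 0, 7]) -> [[2, 5], [4, 7]]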
3 changes: 3 additions & 0 deletions scripts/asr_context_biasing/context_graph_ctc.py
@@ -28,6 +28,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# The script is obtained and modified from Icefall repo:
# https://github.com/k2-fsa/icefall/blob/master/icefall/context_graph.py

from collections import deque
from typing import Dict, List, Optional, Union

10 changes: 5 additions & 5 deletions scripts/asr_context_biasing/ctc_based_word_spotter.py
Expand Up @@ -152,7 +152,7 @@ def find_best_hyps(spotted_words: List[WSHyp], intersection_threshold: int = 10)
return best_hyp_list


def get_ctc_word_alignment(logprob: np.ndarray, asr_model, token_weight: float = 1.0) -> List[tuple]:
def get_ctc_word_alignment(logprob: np.ndarray, asr_model, token_weight: float = 1.0, blank_idx: int = 0) -> List[tuple]:
"""
Get word level alignment (with start and end frames) based on argmax ctc predictions.
The word score is a sum of non-blank token logprobs with additional token_weight.
Expand All @@ -174,7 +174,7 @@ def get_ctc_word_alignment(logprob: np.ndarray, asr_model, token_weight: float =
prev_idx = None
for i, idx in enumerate(alignment_ctc):
token_logprob = 0
if idx != asr_model.decoder.blank_idx:
if idx != blank_idx:
token = asr_model.tokenizer.ids_to_tokens([int(idx)])[0]
if idx == prev_idx:
prev_repited_token = token_alignment.pop()
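To illustrate the idea named in the docstring, here is a rough, simplified sketch of word-level alignment from argmax CTC predictions. It assumes sentencepiece-style '▁' word-boundary markers and omits the repeated-token merging shown above; id2token stands in for asr_model.tokenizer.ids_to_tokens:

    import numpy as np

    def argmax_word_alignment(logprob, id2token, blank_idx=0, token_weight=1.0):
        """Toy word alignment: (word, start_frame, end_frame, score) tuples.

        The word score is the sum of non-blank token logprobs, each bumped by token_weight.
        """
        best_path = np.argmax(logprob, axis=-1)
        words, cur = [], None
        for frame, idx in enumerate(best_path):
            if idx == blank_idx:
                continue
            token = id2token[int(idx)]
            score = logprob[frame, idx] + token_weight
            if token.startswith("▁") or cur is None:  # '▁' marks a new word
                if cur is not None:
                    words.append(tuple(cur))
                cur = [token.lstrip("▁"), frame, frame, score]
            else:  # continuation subword extends the current word
                cur[0] += token
                cur[2] = frame
                cur[3] += score
        if cur is not None:
            words.append(tuple(cur))
        return words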
@@ -268,6 +268,7 @@ def run_word_spotter(
logprobs: np.ndarray,
context_graph: ContextGraphCTC,
asr_model,
blank_idx: int = 0,
beam_threshold: float = 5.0,
cb_weight: float = 3.0,
ctc_ali_token_weight: float = 0.5,
Expand All @@ -284,6 +285,7 @@ def run_word_spotter(
Args:
logprobs: CTC logprobs
context_graph: Context-Biasing graph
asr_model: ASR model (ctc or hybrid-transducer-ctc)
blank_idx: blank index in ASR model
beam_threshold: threshold for beam pruning
cb_weight: context biasing weight
@@ -306,8 +308,6 @@
blank_threshold = np.log(blank_threshold)
non_blank_threshold = np.log(non_blank_threshold)

blank_idx = asr_model.decoder.blank_idx

for frame in range(logprobs.shape[0]):
# add an empty token (located in the graph root) at each new frame to start new word spotting
active_tokens.append(Token(start_state, start_frame=frame))
@@ -357,7 +357,7 @@
best_hyp_list = find_best_hyps(spotted_words)

# filter hyps according to word-level ctc predictions to avoid a high false accept rate
ctc_word_alignment = get_ctc_word_alignment(logprobs, asr_model, token_weight=ctc_ali_token_weight)
ctc_word_alignment = get_ctc_word_alignment(logprobs, asr_model, token_weight=ctc_ali_token_weight, blank_idx=blank_idx)
best_hyp_list = filter_wb_hyps(best_hyp_list, ctc_word_alignment)

return best_hyp_list
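With this change, the caller now supplies the blank index instead of run_word_spotter reading asr_model.decoder.blank_idx itself, which lets CTC-only models pass their own value. A hypothetical call site (asr_model, logprobs, and context_graph are assumed to exist; the printout is illustrative):

    blank_idx = asr_model.decoder.blank_idx  # the value formerly hard-wired inside run_word_spotter

    best_hyps = run_word_spotter(
        logprobs=logprobs,            # [T, V] CTC log-probabilities for one utterance
        context_graph=context_graph,  # prebuilt ContextGraphCTC
        asr_model=asr_model,
        blank_idx=blank_idx,          # now passed explicitly by the caller
        beam_threshold=5.0,
        cb_weight=3.0,
        ctc_ali_token_weight=0.5,
    )
    for hyp in best_hyps:
        print(hyp)  # WSHyp objects with the spotted words and their frame spans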