[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci

Signed-off-by: Elena Rastorgueva <erastorgueva@nvidia.com>
pre-commit-ci[bot] authored and erastorgueva-nv committed Dec 9, 2022
1 parent e74fcbf commit 6d3ff2f
Showing 2 changed files with 8 additions and 11 deletions.
13 changes: 4 additions & 9 deletions tools/nemo_forced_aligner/align.py
@@ -15,14 +15,9 @@
 from pathlib import Path
 
 import torch
+from nemo.collections.asr.models import ASRModel
+from utils import get_log_probs_y_T_U, get_manifest_lines, make_basetoken_ctm, make_word_ctm
 
-from utils import (
-    get_manifest_lines,
-    get_log_probs_y_T_U,
-    make_basetoken_ctm,
-    make_word_ctm,
-)
-from nemo.collections.asr.models import ASRModel
 
 V_NEG_NUM = -1e30
 
@@ -48,7 +43,7 @@ def viterbi_decoding(log_probs, y, T, U):
     padding_for_log_probs = V_NEG_NUM * torch.ones((B, T_max, 1))
     log_probs_padded = torch.cat((log_probs, padding_for_log_probs), dim=2)
     log_probs_reordered = torch.gather(input=log_probs_padded, dim=2, index=y.unsqueeze(1).repeat(1, T_max, 1))
-    log_probs_reordered = log_probs_reordered.cpu() # TODO: do alignment on GPU if available
+    log_probs_reordered = log_probs_reordered.cpu()  # TODO: do alignment on GPU if available
 
     v_matrix = V_NEG_NUM * torch.ones_like(log_probs_reordered)
     backpointers = -999 * torch.ones_like(v_matrix)
@@ -111,7 +106,7 @@ def align(
     model_downsample_factor,
     output_ctm_folder,
     grouping_for_ctm,
-    utt_id_extractor_func=lambda fp : Path(fp).resolve().stem,
+    utt_id_extractor_func=lambda fp: Path(fp).resolve().stem,
     audio_sr=16000, # TODO: get audio SR automatically
     device="cuda:0",
     batch_size=1,
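For reference, a minimal sketch (not part of this commit) of what the torch.gather call in the viterbi_decoding hunk above does: for every frame it picks out the log-probabilities of only the target tokens y, so that log_probs_reordered[b, t, u] equals log_probs_padded[b, t, y[b, u]]. Only the gather pattern and the V_NEG_NUM padding column follow the diff; the toy shapes below are made up.

import torch

V_NEG_NUM = -1e30
B, T_max, V, U_max = 2, 5, 7, 3  # toy sizes, not from the real model

log_probs = torch.randn(B, T_max, V)               # per-frame log-probs over the vocabulary
y = torch.randint(low=0, high=V, size=(B, U_max))  # target token ids for each utterance

# Extra column of very negative scores, so padded target ids can safely point at index V.
padding_for_log_probs = V_NEG_NUM * torch.ones((B, T_max, 1))
log_probs_padded = torch.cat((log_probs, padding_for_log_probs), dim=2)

# Repeat y across the time axis and gather along the vocabulary axis.
index = y.unsqueeze(1).repeat(1, T_max, 1)         # shape (B, T_max, U_max)
log_probs_reordered = torch.gather(input=log_probs_padded, dim=2, index=index)

# Sanity check against an explicit loop.
for b in range(B):
    for t in range(T_max):
        for u in range(U_max):
            assert log_probs_reordered[b, t, u] == log_probs_padded[b, t, y[b, u]]

Column u of the reordered tensor then feeds the Viterbi trellis (v_matrix / backpointers) initialised just after the gather in the hunk above.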
6 changes: 4 additions & 2 deletions tools/nemo_forced_aligner/utils.py
@@ -14,6 +14,7 @@
 
 import json
 import os
+
 import torch
 
 V_NEG_NUM = -1e30
@@ -109,6 +110,7 @@ def get_log_probs_y_T_U(data, model):
 
     return log_probs, y, T, U_dash
 
+
 def make_basetoken_ctm(
     data, alignments, model, model_downsample_factor, output_ctm_folder, utt_id_extractor_func, audio_sr,
 ):
@@ -223,7 +225,7 @@ def make_word_ctm(
 
         for word_i, word in enumerate(manifest_line["text"].split(" ")):
             word_info = {
-                "word": word, 
+                "word": word,
                 "u_start": u_counter,
                 "u_end": None,
                 "t_start": None,
@@ -239,7 +241,7 @@ def make_word_ctm(
             if word_i < len(manifest_line["text"].split(" ")) - 1:
                 # add the space after every word except the final word
                 word_info = {
-                    "word": "<space>", 
+                    "word": "<space>",
                     "u_start": u_counter,
                     "u_end": None,
                     "t_start": None,
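As a side note, here is a hedged sketch (not from this commit) of the word / "<space>" bookkeeping pattern visible in the make_word_ctm hunks above. The real utility fills in u_end and the time fields from the Viterbi alignment; they are left as None here, the fields beyond "t_start" are truncated in the diff and omitted, and counting one output unit per character is only an illustrative assumption.

def build_word_infos(text):
    """Toy version of the word/<space> entries seen in the diff above."""
    word_infos = []
    u_counter = 0
    words = text.split(" ")
    for word_i, word in enumerate(words):
        word_infos.append(
            {"word": word, "u_start": u_counter, "u_end": None, "t_start": None}
        )
        u_counter += len(word)  # assumption: one output unit per character

        if word_i < len(words) - 1:
            # add the space after every word except the final word
            word_infos.append(
                {"word": "<space>", "u_start": u_counter, "u_end": None, "t_start": None}
            )
            u_counter += 1

    return word_infos


print(build_word_infos("forced aligner demo"))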
