Skip to content

Commit

Permalink
Add lhotse fixes for rnnt model training and WER hanging issue with fuse batching (#10821)
Browse files Browse the repository at this point in the history

* Add lhotse fixes for rnnt model training and WER hanging issue with f… (#10787)

* Add lhotse fixes for rnnt model training and WER hanging issue with fuse batching

Signed-off-by: Nithin Rao Koluguri <nithinraok>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
Co-authored-by: Nithin Rao Koluguri <nithinraok>
Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>

* Apply isort and black reformatting

Signed-off-by: artbataev <artbataev@users.noreply.github.com>

---------

Signed-off-by: Nithin Rao Koluguri <nithinraok>
Signed-off-by: nithinraok <nithinraok@users.noreply.github.com>
Signed-off-by: artbataev <artbataev@users.noreply.github.com>
Co-authored-by: nithinraok <nithinraok@users.noreply.github.com>
Co-authored-by: artbataev <artbataev@users.noreply.github.com>
  • Loading branch information
3 people authored and Yashaswi Karnati committed Oct 20, 2024
1 parent 153e067 commit ade44dc
Show file tree
Hide file tree
Showing 4 changed files with 15 additions and 11 deletions.
15 changes: 6 additions & 9 deletions nemo/collections/asr/data/audio_to_text_lhotse.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,15 +51,12 @@ def __init__(self, tokenizer):
def __getitem__(self, cuts) -> Tuple[torch.Tensor, ...]:
audio, audio_lens, cuts = self.load_audio(cuts)
  tokens = [
- torch.as_tensor(
- sum(
- (
- # Supervisions may come pre-tokenized from the dataloader.
- s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language)
- for s in c.supervisions
- ),
- start=[],
- )
+ torch.cat(
+ [
+ torch.as_tensor(s.tokens if hasattr(s, "tokens") else self.tokenizer(s.text, s.language))
+ for s in c.supervisions
+ ],
+ dim=0,
  )
  for c in cuts
  ]
Expand Down
3 changes: 2 additions & 1 deletion nemo/collections/asr/metrics/wer.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,8 +254,9 @@ def __init__(
  fold_consecutive=True,
  batch_dim_index=0,
  dist_sync_on_step=False,
+ sync_on_compute=True,
  ):
- super().__init__(dist_sync_on_step=dist_sync_on_step)
+ super().__init__(dist_sync_on_step=dist_sync_on_step, sync_on_compute=sync_on_compute)

self.decoding = decoding
self.use_cer = use_cer
Expand Down
7 changes: 7 additions & 0 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1457,6 +1457,10 @@ def forward(
  sub_transcripts = sub_transcripts.detach()

  # Update WER on each process without syncing
+ if self.training:
+ original_sync = self.wer._to_sync
+ self.wer._to_sync = False
+
  self.wer.update(
  predictions=sub_enc,
  predictions_lengths=sub_enc_lens,
Expand All @@ -1467,6 +1471,9 @@ def forward(
  wer, wer_num, wer_denom = self.wer.compute()
  self.wer.reset()

+ if self.training:
+ self.wer._to_sync = original_sync
+
  wers.append(wer)
  wer_nums.append(wer_num)
  wer_denoms.append(wer_denom)
Expand Down
1 change: 0 additions & 1 deletion nemo/collections/common/data/lhotse/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,7 +317,6 @@ def get_lhotse_dataloader_from_config(
  duration_bins=determine_bucket_duration_bins(config),
  num_cuts_for_bins_estimate=config.num_cuts_for_bins_estimate,
  buffer_size=config.bucket_buffer_size,
- concurrent=config.concurrent_bucketing,
  rank=0 if is_tarred else global_rank,
  world_size=1 if is_tarred else world_size,
  )
Expand Down

0 comments on commit ade44dc

Please sign in to comment.