Apply isort and black reformatting
Signed-off-by: galv <galv@users.noreply.github.com>
galv committed May 17, 2024
1 parent 1958ddf commit 3caa051
Showing 7 changed files with 55 additions and 37 deletions.
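Most of the hunks below are produced by a handful of mechanical rules. As a minimal, hypothetical sketch (not code from this repository), these are the three patterns that recur in this commit: isort ordering imports alphabetically by module, black removing the spaces around `**` when both operands are simple, and black's "magic trailing comma" forcing a call onto one line per argument.

```python
from dataclasses import dataclass
import time  # isort: alphabetical by module within the stdlib block (dataclasses < time)

d_model = 512
scale = d_model**-0.5  # black: no spaces around ** between simple operands


@dataclass
class Timing:
    started: float


# black: the trailing comma inside the parentheses ("magic trailing comma")
# keeps the call exploded, one argument per line.
t = Timing(
    started=time.time(),
)
print(scale, t)
```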
11 changes: 8 additions & 3 deletions examples/asr/transcribe_speech.py
@@ -16,9 +16,9 @@
import glob
import json
import os
import time
from dataclasses import dataclass, is_dataclass
from tempfile import NamedTemporaryFile
import time
from typing import List, Optional, Union

import pytorch_lightning as pl
@@ -408,7 +408,9 @@ def autocast(dtype=None, enabled=True):
for line in fh:
item = json.loads(line)
if "duration" not in item:
raise ValueError(f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.")
raise ValueError(
f"Requested calculate_rtfx=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field."
)
total_duration += item["duration"]
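For context, RTFx needs the total audio duration, which this loop sums from the manifest. A small hedged illustration of the JSON-lines manifest format this check expects (field values made up):

```python
import json

# One NeMo-style manifest line: a JSON object per line, 'duration' in seconds.
line = '{"audio_filepath": "/data/utt0001.wav", "duration": 3.2, "text": "hello world"}'
item = json.loads(line)
assert "duration" in item  # the ValueError above fires when this field is missing
total_duration = item["duration"]
```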

with autocast(dtype=amp_dtype, enabled=cfg.amp):
@@ -436,7 +438,10 @@ def autocast(dtype=None, enabled=True):

if cfg.calculate_rtfx:
start_time = time.time()
transcriptions = asr_model.transcribe(audio=filepaths, override_config=override_cfg,)
transcriptions = asr_model.transcribe(
audio=filepaths,
override_config=override_cfg,
)
if cfg.calculate_rtfx:
transcribe_time = time.time() - start_time

Check failure (Code scanning / CodeQL): Potentially uninitialized local variable (Error). Local variable 'start_time' may be used before it is initialized.
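The finding is accurate as written: `start_time` is assigned only under `cfg.calculate_rtfx`, and the later read is guarded by the same flag, which a static checker cannot prove. A minimal sketch of one conventional fix (an illustration, not the change made in this commit) binds the name on every path:

```python
import time
from typing import Callable, List, Optional


def timed_transcribe(transcribe: Callable[[], List[str]], measure_rtfx: bool, total_duration: float) -> List[str]:
    # start_time is bound on every path, so no "possibly uninitialized" read exists.
    start_time: Optional[float] = time.time() if measure_rtfx else None
    transcriptions = transcribe()
    if start_time is not None:
        elapsed = time.time() - start_time
        # RTFx = seconds of audio transcribed per second of wall-clock time.
        print(f"RTFx: {total_duration / elapsed:.1f}")
    return transcriptions
```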

2 changes: 1 addition & 1 deletion nemo/collections/asr/models/ctc_models.py
@@ -710,7 +710,7 @@ def _transcribe_output_processing(self, outputs, trcfg: TranscribeConfig) -> Gen
# cudaMallocHost()-allocated tensor to be floating
# around. Were that to be the case, then the pinned
# memory cache would always miss.
current_hypotheses[idx].y_sequence = logits_cpu[idx, :logits_len[idx]].clone()
current_hypotheses[idx].y_sequence = logits_cpu[idx, : logits_len[idx]].clone()
if current_hypotheses[idx].alignments is None:
current_hypotheses[idx].alignments = current_hypotheses[idx].y_sequence
del logits_cpu
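The truncated comment above refers to a pinned-memory idiom; here is a hypothetical sketch of the pattern (names and shapes assumed, not NeMo's actual implementation): logits are staged in a pinned (cudaMallocHost-backed) CPU buffer for the device-to-host copy, each utterance's slice is then `clone()`d into ordinary pageable memory, and the buffer is dropped so the allocator's pinned-memory cache can hand it to the next batch.

```python
import torch


def offload_logits(logits_gpu: torch.Tensor, logits_len: torch.Tensor) -> list:
    # Pinned staging buffer (pin_memory is only requested when CUDA is present).
    logits_cpu = torch.empty(
        logits_gpu.shape, dtype=logits_gpu.dtype, device="cpu", pin_memory=torch.cuda.is_available()
    )
    logits_cpu.copy_(logits_gpu)
    # clone() detaches each slice from the pinned buffer; keeping views alive
    # would keep the whole pinned allocation alive, and the cache would always miss.
    per_utt = [logits_cpu[i, : logits_len[i]].clone() for i in range(logits_cpu.size(0))]
    del logits_cpu  # the pinned buffer is now free for the next batch
    return per_utt
```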
7 changes: 3 additions & 4 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -94,17 +94,16 @@ def __init__(self, win_length, hop_length):
# output in appropriate precision. We have this empty tensor
# here just to detect which dtype tensor this module should
# output at the end of execution.
self.register_buffer("dtype_sentinel_tensor",
torch.tensor((), dtype=torch.float32),
persistent=False)
self.register_buffer("dtype_sentinel_tensor", torch.tensor((), dtype=torch.float32), persistent=False)

@typecheck()
@torch.no_grad()
def forward(self, input_signal, length):
if input_signal.dtype != torch.float32:
logging.warn(
f"AudioPreprocessor received an input signal of dtype {input_signal.dtype}, rather than torch.float32. In sweeps across multiple datasets, we have found that the preprocessor is not robust to low precision mathematics. As such, it runs in float32. Your input will be cast to float32, but this is not necessarily enough to recovery full accuracy. For example, simply casting input_signal from torch.float32 to torch.bfloat16, then back to torch.float32 before running AudioPreprocessor causes drops in absolute WER of up to 0.1%. torch.bfloat16 simply does not have enough mantissa bits to represent enough values in the range [-1.0,+1.0] correctly.",
mode=logging_mode.ONCE)
mode=logging_mode.ONCE,
)
processed_signal, processed_length = self.get_features(input_signal.to(torch.float32), length)
processed_signal = processed_signal.to(self.dtype_sentinel_tensor.dtype)
return processed_signal, processed_length
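Two details in this hunk reward a closer look, sketched below with assumed names (not NeMo's actual class). First, the zero-element `dtype_sentinel_tensor` buffer is cast along with the module by `.to(dtype)` or `.half()`, so its dtype records the precision the surrounding model runs in, letting `forward` compute in float32 and cast the result back. Second, the warning's bfloat16 claim is plain arithmetic: bfloat16 stores 7 explicit mantissa bits, so a waveform in [-1.0, +1.0] survives a round trip only to within roughly 1e-3 to 1e-2.

```python
import torch


class Float32FeatureExtractor(torch.nn.Module):
    """Sketch: always compute in float32, emit in the module's current dtype."""

    def __init__(self):
        super().__init__()
        # Zero-element buffer: costs nothing, but module.to(dtype) casts it,
        # so its dtype tracks what the caller expects back.
        self.register_buffer("dtype_sentinel", torch.tensor((), dtype=torch.float32), persistent=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        feats = torch.log1p(x.to(torch.float32).abs())  # stand-in for the real feature math
        return feats.to(self.dtype_sentinel.dtype)


m = Float32FeatureExtractor().to(torch.bfloat16)
wave = torch.rand(8) * 2 - 1  # float32 waveform in [-1, 1]
print(m(wave).dtype)  # torch.bfloat16: output follows the module's precision

# The precision loss the warning describes: float32 -> bfloat16 -> float32.
err = (wave - wave.to(torch.bfloat16).to(torch.float32)).abs().max()
print(err)  # typically a few 1e-3 for values of magnitude ~1
```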
20 changes: 12 additions & 8 deletions nemo/collections/asr/modules/squeezeformer_encoder.py
@@ -99,8 +99,7 @@ def input_example(self, max_batch=1, max_dim=256):

@property
def input_types(self):
"""Returns definitions of module input ports.
"""
"""Returns definitions of module input ports."""
return OrderedDict(
{
"audio_signal": NeuralType(('B', 'D', 'T'), SpectrogramType()),
@@ -110,8 +109,7 @@ def input_types(self):

@property
def output_types(self):
"""Returns definitions of module output ports.
"""
"""Returns definitions of module output ports."""
return OrderedDict(
{
"outputs": NeuralType(('B', 'D', 'T'), AcousticEncodedRepresentation()),
@@ -253,7 +251,11 @@ def __init__(
# Choose the same type of positional encoding as originally determined above
if self_attention_model == "rel_pos":
self.time_reduce_pos_enc = RelPositionalEncoding(
d_model=d_model, dropout_rate=0.0, max_len=pos_emb_max_len, xscale=None, dropout_rate_emb=0.0,
d_model=d_model,
dropout_rate=0.0,
max_len=pos_emb_max_len,
xscale=None,
dropout_rate_emb=0.0,
)
else:
self.time_reduce_pos_enc = PositionalEncoding(
@@ -275,8 +277,8 @@ def __init__(
self.interctc_capture_at_layers = None

def set_max_audio_length(self, max_audio_length):
""" Sets maximum input length.
Pre-calculates internal seq_range mask.
"""Sets maximum input length.
Pre-calculates internal seq_range mask.
"""
self.max_audio_length = max_audio_length
device = next(self.parameters()).device
@@ -435,7 +437,9 @@ def _update_adapter_cfg_input_dim(self, cfg: DictConfig):
cfg = adapter_utils.update_adapter_cfg_input_dim(self, cfg, module_dim=self.d_model)
return cfg

def get_accepted_adapter_types(self,) -> Set[type]:
def get_accepted_adapter_types(
self,
) -> Set[type]:
types = super().get_accepted_adapter_types()

if len(types) == 0:
8 changes: 4 additions & 4 deletions nemo/collections/asr/parts/submodules/conformer_modules.py
@@ -370,8 +370,8 @@ def forward(self, x, pad_mask=None, cache=None):
return x, cache

def reset_parameters_conv(self):
pw1_max = pw2_max = self.d_model ** -0.5
dw_max = self.kernel_size ** -0.5
pw1_max = pw2_max = self.d_model**-0.5
dw_max = self.kernel_size**-0.5

with torch.no_grad():
nn.init.uniform_(self.pointwise_conv1.weight, -pw1_max, pw1_max)
@@ -404,8 +404,8 @@ def forward(self, x):
return x

def reset_parameters_ff(self):
ffn1_max = self.d_model ** -0.5
ffn2_max = self.d_ff ** -0.5
ffn1_max = self.d_model**-0.5
ffn2_max = self.d_ff**-0.5
with torch.no_grad():
nn.init.uniform_(self.linear1.weight, -ffn1_max, ffn1_max)
nn.init.uniform_(self.linear1.bias, -ffn1_max, ffn1_max)
9 changes: 6 additions & 3 deletions nemo/collections/asr/parts/submodules/multi_head_attention.py
@@ -253,7 +253,7 @@ def forward(self, query, key, value, mask, pos_emb, cache=None):
class RelPositionMultiHeadAttentionLongformer(RelPositionMultiHeadAttention):
"""Multi-Head Attention layer of Transformer-XL with sliding window local+global attention from Longformer.
Partially adapted from allenai (https://github.com/allenai/longformer/blob/master/longformer/sliding_chunks.py)
and huggingface (https://github.com/huggingface/transformers/blob/main/src/transformers/models/longformer/modeling_longformer.py)
and huggingface (https://github.com/huggingface/transformers/blob/main/src/transformers/models/longformer/modeling_longformer.py)
Paper: https://arxiv.org/abs/1901.02860 (Transformer-XL),
https://arxiv.org/abs/2004.05150 (Longformer)
Args:
@@ -650,7 +650,8 @@ def _compute_out_global_to_all(
global_attn_scores = global_attn_scores.transpose(1, 2)

global_attn_scores = global_attn_scores.masked_fill(
is_index_masked.transpose(2, 3), torch.finfo(global_attn_scores.dtype).min,
is_index_masked.transpose(2, 3),
torch.finfo(global_attn_scores.dtype).min,
)

global_attn_scores = global_attn_scores.view(batch_size * self.h, max_num_global_attn_indices, seq_len)
@@ -747,7 +748,9 @@ def _get_invalid_locations_mask(self, w: int, device: str):
return mask.bool().to(device), ending_mask

def mask_invalid_locations(
self, input_tensor: torch.Tensor, w: int,
self,
input_tensor: torch.Tensor,
w: int,
):
"""
Mask locations invalid for the sliding window attention
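The hunk cuts off mid-docstring, so as a self-contained illustration (a hypothetical helper, not the class's actual method), here are the two idioms this file combines: a band mask marking positions outside the +/- w sliding window as invalid, and the `masked_fill(..., torch.finfo(dtype).min)` trick from `_compute_out_global_to_all`, which drives those scores to effectively zero probability after softmax.

```python
import torch


def sliding_window_attn_probs(scores: torch.Tensor, w: int) -> torch.Tensor:
    """scores: (batch, heads, seq, seq); w: one-sided window size."""
    seq_len = scores.size(-1)
    idx = torch.arange(seq_len, device=scores.device)
    invalid = (idx[None, :] - idx[:, None]).abs() > w  # True outside the +/- w band
    scores = scores.masked_fill(invalid, torch.finfo(scores.dtype).min)
    return scores.softmax(dim=-1)


probs = sliding_window_attn_probs(torch.randn(2, 8, 16, 16), w=3)
print(probs[0, 0, 0, :8])  # positions beyond the window get ~0 probability
```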
35 changes: 21 additions & 14 deletions nemo/collections/asr/parts/utils/transcribe_utils.py
@@ -15,14 +15,14 @@
import json
import os
import re
from tempfile import NamedTemporaryFile
from dataclasses import dataclass
from pathlib import Path
from tempfile import NamedTemporaryFile

Check notice (Code scanning / CodeQL): Unused import (Note). Import of 'NamedTemporaryFile' is not used.
from typing import List, Optional, Tuple, Union

import soundfile as sf

Check notice (Code scanning / CodeQL): Unused import (Note). Import of 'sf' is not used.
import torch
from omegaconf import DictConfig
import soundfile as sf
from tqdm.auto import tqdm

import nemo.collections.asr as nemo_asr
@@ -234,7 +234,7 @@ def get_buffered_pred_feat_multitaskAED(


def wrap_transcription(hyps: List[str]) -> List[rnnt_utils.Hypothesis]:
""" Wrap transcription to the expected format in func write_transcription """
"""Wrap transcription to the expected format in func write_transcription"""
wrapped_hyps = []
for hyp in hyps:
hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], text=hyp)
@@ -243,21 +243,23 @@ def wrap_transcription(hyps: List[str]) -> List[rnnt_utils.Hypothesis]:


def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel, str]:
""" Setup model from cfg and return model and model name for next step """
"""Setup model from cfg and return model and model name for next step"""
if cfg.model_path is not None and cfg.model_path != "None":
# restore model from .nemo file path
model_cfg = ASRModel.restore_from(restore_path=cfg.model_path, return_config=True)
classpath = model_cfg.target # original class path
imported_class = model_utils.import_class_by_path(classpath) # type: ASRModel
logging.info(f"Restoring model : {imported_class.__name__}")
asr_model = imported_class.restore_from(
restore_path=cfg.model_path, map_location=map_location,
restore_path=cfg.model_path,
map_location=map_location,
) # type: ASRModel
model_name = os.path.splitext(os.path.basename(cfg.model_path))[0]
else:
# restore model by name
asr_model = ASRModel.from_pretrained(
model_name=cfg.pretrained_name, map_location=map_location,
model_name=cfg.pretrained_name,
map_location=map_location,
) # type: ASRModel
model_name = cfg.pretrained_name

@@ -271,7 +273,7 @@ def setup_model(cfg: DictConfig, map_location: torch.device) -> Tuple[ASRModel,


def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
""" Prepare audio data and decide whether it's partial_audio condition. """
"""Prepare audio data and decide whether it's partial_audio condition."""
# this part may need refactoring alongside the refactor of transcribe
partial_audio = False

@@ -291,7 +293,9 @@ def prepare_audio_data(cfg: DictConfig) -> Tuple[List[str], bool]:
item = json.loads(line)
item["audio_filepath"] = get_full_path(item["audio_filepath"], cfg.dataset_manifest)
if item.get("duration") is None and cfg.presort_manifest:
raise ValueError(f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field.")
raise ValueError(
f"Requested presort_manifest=True, but line {line} in manifest {cfg.dataset_manifest} lacks a 'duration' field."
)
all_entries_have_offset_and_duration = True
for item in read_and_maybe_sort_manifest(cfg.dataset_manifest, try_sort=cfg.presort_manifest):
if not ("offset" in item and "duration" in item):
@@ -331,7 +335,7 @@ def restore_transcription_order(manifest_path: str, transcriptions: list) -> lis


def compute_output_filename(cfg: DictConfig, model_name: str) -> DictConfig:
""" Compute filename of output manifest and update cfg"""
"""Compute filename of output manifest and update cfg"""
if cfg.output_filename is None:
# create default output filename
if cfg.audio_dir is not None:
@@ -372,7 +376,7 @@ def write_transcription(
compute_langs: bool = False,
compute_timestamps: bool = False,
) -> Tuple[str, str]:
""" Write generated transcription to output file. """
"""Write generated transcription to output file."""
if cfg.append_pred:
logging.info(f'Transcripts will be written in "{cfg.output_filename}" file')
if cfg.pred_name_postfix is not None:
@@ -542,7 +546,11 @@ def transcribe_partial_audio(
lg = logits[idx][: logits_len[idx]]
hypotheses.append(lg)
else:
current_hypotheses, _ = decode_function(logits, logits_len, return_hypotheses=return_hypotheses,)
current_hypotheses, _ = decode_function(
logits,
logits_len,
return_hypotheses=return_hypotheses,
)

if return_hypotheses:
# dump log probs per file
@@ -576,18 +584,17 @@ def compute_metrics_per_sample(
punctuation_marks: List[str] = [".", ",", "?"],
output_manifest_path: str = None,
) -> dict:

'''
Computes metrics per sample for given manifest
Args:
manifest_path: str, Required - path to dataset JSON manifest file (in NeMo format)
reference_field: str, Optional - name of field in .json manifest with the reference text ("text" by default).
hypothesis_field: str, Optional - name of field in .json manifest with the hypothesis text ("pred_text" by default).
metrics: list[str], Optional - list of metrics to be computed (currently supported "wer", "cer", "punct_er")
punctuation_marks: list[str], Optional - list of punctuation marks for computing punctuation error rate ([".", ",", "?"] by default).
output_manifest_path: str, Optional - path where .json manifest with calculated metrics will be saved.
Returns:
samples: dict - Dict of samples with calculated metrics
'''
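A hypothetical invocation (paths made up), assuming a manifest whose lines already carry both the reference and hypothesis fields:

```python
samples = compute_metrics_per_sample(
    manifest_path="predictions.json",
    reference_field="text",
    hypothesis_field="pred_text",
    metrics=["wer", "cer"],
    output_manifest_path="predictions_with_metrics.json",
)
```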
