From 0725e136905d10111c52d96af30d6c3f5147b15c Mon Sep 17 00:00:00 2001 From: Hainan Xu Date: Mon, 28 Nov 2022 17:51:47 -0500 Subject: [PATCH] minor python style changes Signed-off-by: Hainan Xu --- nemo/collections/asr/metrics/rnnt_wer_bpe.py | 4 -- nemo/collections/asr/modules/rnnt.py | 5 +- .../utils/cuda_utils/gpu_rnnt_kernel.py | 1 - .../parts/submodules/rnnt_greedy_decoding.py | 64 +++++++++++++------ 4 files changed, 45 insertions(+), 29 deletions(-) diff --git a/nemo/collections/asr/metrics/rnnt_wer_bpe.py b/nemo/collections/asr/metrics/rnnt_wer_bpe.py index 48998adb67eb..70bde94ade61 100644 --- a/nemo/collections/asr/metrics/rnnt_wer_bpe.py +++ b/nemo/collections/asr/metrics/rnnt_wer_bpe.py @@ -196,10 +196,6 @@ class RNNTBPEDecoding(AbstractRNNTDecoding): def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec): blank_id = tokenizer.tokenizer.vocab_size - - big_blank_duration_list = decoding_cfg.big_blank_duration_list - - big_blank_id_list = list(range(tokenizer.tokenizer.vocab_size + 1, tokenizer.tokenizer.vocab_size + len(big_blank_duration_list) + 1)) self.tokenizer = tokenizer super(RNNTBPEDecoding, self).__init__( diff --git a/nemo/collections/asr/modules/rnnt.py b/nemo/collections/asr/modules/rnnt.py index 3a363cdf333c..b6c694f4e97e 100644 --- a/nemo/collections/asr/modules/rnnt.py +++ b/nemo/collections/asr/modules/rnnt.py @@ -1187,7 +1187,8 @@ def __init__( self.vocabulary = vocabulary self._vocab_size = num_classes - self._num_classes = num_classes + 1 + num_big_blanks # add 2 for two blank symbols + self.num_extra_outputs = num_extra_outputs + self._num_classes = num_classes + 1 + num_extra_outputs if experimental_fuse_loss_wer is not None: # Override fuse_loss_wer from deprecated argument @@ -1218,8 +1219,6 @@ def __init__( self.pred_hidden = jointnet['pred_hidden'] self.joint_hidden = jointnet['joint_hidden'] self.activation = jointnet['activation'] - self.num_extra_outputs = num_extra_outputs - self._num_classes = num_classes + 1 + num_extra_outputs # Optional arguments dropout = jointnet.get('dropout', 0.0) diff --git a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py index a2a42cd0336f..42a7a9a022a4 100644 --- a/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py +++ b/nemo/collections/asr/parts/numba/rnnt_loss/utils/cuda_utils/gpu_rnnt_kernel.py @@ -27,7 +27,6 @@ # limitations under the License. import math -from typing import List import torch from numba import cuda diff --git a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py index 3860920c803f..f1fda33e1287 100644 --- a/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py +++ b/nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py @@ -464,10 +464,10 @@ def _greedy_decode( return hypothesis -class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer): - """A greedy transducer decoder. +class GreedyBatchedRNNTInfer(_GreedyRNNTInfer): + """A batch level greedy transducer decoder. - Sequence level greedy decoding, performed auto-repressively. + Batch level greedy decoding, performed auto-repressively. Args: decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. @@ -479,11 +479,49 @@ class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer): preserve_alignments: Bool flag which preserves the history of alignments generated during greedy decoding (sample / batched). 
When set to true, the Hypothesis will contain the non-null value for `alignments` in it. Here, `alignments` is a List of List of - Tuple(Tensor (of length V + 1 + num-big-blanks), Tensor(scalar, label after argmax)). + Tuple(Tensor (of length V + 1), Tensor(scalar, label after argmax)). The length of the list corresponds to the Acoustic Length (T). Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary. U is the number of target tokens for the current timestep Ti. + preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated + during greedy decoding (sample / batched). When set to true, the Hypothesis will contain + the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List of floats. + + The length of the list corresponds to the Acoustic Length (T). + Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores. + U is the number of target tokens for the current timestep Ti. + confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame + confidence scores. + + name: The method name (str). + Supported values: + - 'max_prob' for using the maximum token probability as a confidence. + - 'entropy' for using normalized entropy of a log-likelihood vector. + + entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`. + Supported values: + - 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided, + the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)). + Note that for this entropy, the temperature should comply the following inequality: + 1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size. + - 'tsallis' for the Tsallis entropy with the Boltzmann constant one. + Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/Tsallis_entropy + - 'renui' for the Rényi entropy. + Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)), + where α is a parameter. When α == 1, it works like the Gibbs entropy. + More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy + + temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0. + When the temperature equals one, scaling is not applied to 'max_prob', + and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i)) + + entropy_norm: A mapping of the entropy value to the interval [0,1]. + Supported values: + - 'lin' for using the linear mapping. + - 'exp' for using exponential mapping with linear shift. """ def __init__( @@ -1670,22 +1708,6 @@ def _get_initial_states(self, batchsize): return input_states -@dataclass -class GreedyRNNTInferConfig: - max_symbols_per_step: Optional[int] = 10 - preserve_alignments: bool = False - preserve_frame_confidence: bool = False - confidence_method_cfg: Optional[ConfidenceMethodConfig] = None - - -@dataclass -class GreedyBatchedRNNTInferConfig: - max_symbols_per_step: Optional[int] = 10 - preserve_alignments: bool = False - preserve_frame_confidence: bool = False - confidence_method_cfg: Optional[ConfidenceMethodConfig] = None - - class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer): """A greedy transducer decoder. @@ -1695,7 +1717,7 @@ class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer): decoder_model: rnnt_utils.AbstractRNNTDecoder implementation. 
joint_model: rnnt_utils.AbstractRNNTJoint implementation.
        blank_index: int index of the blank token. Can be 0 or len(vocabulary).
-        big_blank_durations: a list containing durations for big blank the model supports. 
+        big_blank_durations: a list containing the durations of the big blanks the model supports.
        max_symbols_per_step: Optional int. The maximum number of symbols that can be added
            to a sequence in a single time step; if set to None then there is no limit.
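
Note on the big-blank decoder documented above: with num_extra_outputs = len(big_blank_durations), the joint network produces num_classes + 1 + num_extra_outputs logits, the standard blank sits at blank_index, and the big blanks occupy the indices immediately after it (as in the removed big_blank_id_list computation). The sketch below is only an illustration of the intended greedy behavior, not the NeMo implementation; the joint and predict callables and their signatures are placeholders. The key point is that emitting a big blank of duration d advances the time index by d frames instead of 1.

import torch

def greedy_multiblank_decode(encoder_out, joint, predict, blank_id, big_blank_durations, max_symbols=10):
    """encoder_out: [T, D] acoustic frames; returns the list of emitted token ids."""
    hyp = []
    state = None
    t, T = 0, encoder_out.size(0)
    while t < T:
        advance = 1                                # a standard blank advances one frame
        symbols_added = 0
        while symbols_added < max_symbols:
            logits = joint(encoder_out[t], state)  # assumed shape: [V + 1 + len(big_blank_durations)]
            k = int(torch.argmax(logits))
            if k == blank_id:                      # standard blank: move on to the next frame
                break
            if k > blank_id:                       # big blank: jump over `duration` frames at once
                advance = big_blank_durations[k - blank_id - 1]
                break
            hyp.append(k)                          # regular token: update the prediction network
            state = predict(k, state)
            symbols_added += 1
        t += advance
    return hyp

For example, with big_blank_durations=[2, 4, 8] the decoder may skip two, four, or eight frames whenever the corresponding big blank wins the argmax, which is what makes multi-blank greedy decoding faster than the standard one-frame-at-a-time loop.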
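
The confidence docstring added in rnnt_greedy_decoding.py describes the per-frame measures only in prose (max_prob, and the Gibbs/Tsallis/Rényi entropies with temperature α and a 'lin' or 'exp' normalization). The following is a rough, self-contained sketch of those formulas, not NeMo's implementation; in particular, normalizing by the entropy of the uniform distribution and the exact shape of the exponential mapping are assumptions of this sketch.

import math
import torch

def frame_confidence(log_probs, name="entropy", entropy_type="tsallis", alpha=0.33, entropy_norm="exp"):
    """log_probs: [V] log-softmax vector for one frame; returns a confidence value in [0, 1]."""
    if name == "max_prob":
        return log_probs.max().exp().item()
    v = log_probs.numel()
    p_alpha = (alpha * log_probs).exp()            # p_i ** alpha, computed in log space
    if entropy_type == "gibbs":                    # H_a = -sum_i p_i^a * log(p_i^a)
        h = -(p_alpha * alpha * log_probs).sum()
        h_max = alpha * v ** (1.0 - alpha) * math.log(v)
    elif entropy_type == "tsallis":                # H_a = 1/(a-1) * (1 - sum_i p_i^a)
        h = (1.0 - p_alpha.sum()) / (alpha - 1.0)
        h_max = (1.0 - v ** (1.0 - alpha)) / (alpha - 1.0)
    else:                                          # 'renui' (Rényi): H_a = 1/(1-a) * log2(sum_i p_i^a)
        h = torch.log2(p_alpha.sum()) / (1.0 - alpha)
        h_max = math.log2(v)
    u = (h / h_max).clamp(0.0, 1.0)                # entropy at the uniform distribution as the scale
    if entropy_norm == "lin":
        return float(1.0 - u)                      # low entropy -> high confidence
    return float((math.e * torch.exp(-u) - 1.0) / (math.e - 1.0))  # an exp mapping with linear shift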