Skip to content

Commit

Permalink
minor python style changes
Browse files Browse the repository at this point in the history
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
  • Loading branch information
Hainan Xu committed Nov 29, 2022
1 parent 84a9d61 commit 0725e13
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 29 deletions.
4 changes: 0 additions & 4 deletions nemo/collections/asr/metrics/rnnt_wer_bpe.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,10 +196,6 @@ class RNNTBPEDecoding(AbstractRNNTDecoding):

def __init__(self, decoding_cfg, decoder, joint, tokenizer: TokenizerSpec):
blank_id = tokenizer.tokenizer.vocab_size

big_blank_duration_list = decoding_cfg.big_blank_duration_list

big_blank_id_list = list(range(tokenizer.tokenizer.vocab_size + 1, tokenizer.tokenizer.vocab_size + len(big_blank_duration_list) + 1))
self.tokenizer = tokenizer

super(RNNTBPEDecoding, self).__init__(
Expand Down
5 changes: 2 additions & 3 deletions nemo/collections/asr/modules/rnnt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1187,7 +1187,8 @@ def __init__(
self.vocabulary = vocabulary

self._vocab_size = num_classes
self._num_classes = num_classes + 1 + num_big_blanks # add 2 for two blank symbols
self.num_extra_outputs = num_extra_outputs
self._num_classes = num_classes + 1 + num_extra_outputs

if experimental_fuse_loss_wer is not None:
# Override fuse_loss_wer from deprecated argument
Expand Down Expand Up @@ -1218,8 +1219,6 @@ def __init__(
self.pred_hidden = jointnet['pred_hidden']
self.joint_hidden = jointnet['joint_hidden']
self.activation = jointnet['activation']
self.num_extra_outputs = num_extra_outputs
self._num_classes = num_classes + 1 + num_extra_outputs

# Optional arguments
dropout = jointnet.get('dropout', 0.0)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@
# limitations under the License.

import math
from typing import List

import torch
from numba import cuda
Expand Down
64 changes: 43 additions & 21 deletions nemo/collections/asr/parts/submodules/rnnt_greedy_decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,10 +464,10 @@ def _greedy_decode(
return hypothesis


class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer):
"""A greedy transducer decoder.
class GreedyBatchedRNNTInfer(_GreedyRNNTInfer):
"""A batch level greedy transducer decoder.
Sequence level greedy decoding, performed auto-repressively.
Batch level greedy decoding, performed auto-repressively.
Args:
decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
Expand All @@ -479,11 +479,49 @@ class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer):
preserve_alignments: Bool flag which preserves the history of alignments generated during
greedy decoding (sample / batched). When set to true, the Hypothesis will contain
the non-null value for `alignments` in it. Here, `alignments` is a List of List of
Tuple(Tensor (of length V + 1 + num-big-blanks), Tensor(scalar, label after argmax)).
Tuple(Tensor (of length V + 1), Tensor(scalar, label after argmax)).
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more targets from a vocabulary.
U is the number of target tokens for the current timestep Ti.
preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores generated
during greedy decoding (sample / batched). When set to true, the Hypothesis will contain
the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of List of floats.
The length of the list corresponds to the Acoustic Length (T).
Each value in the list (Ti) is a torch.Tensor (U), representing 1 or more confidence scores.
U is the number of target tokens for the current timestep Ti.
confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
confidence scores.
name: The method name (str).
Supported values:
- 'max_prob' for using the maximum token probability as a confidence.
- 'entropy' for using normalized entropy of a log-likelihood vector.
entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
Supported values:
- 'gibbs' for the (standard) Gibbs entropy. If the temperature α is provided,
the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
Note that for this entropy, the temperature should comply the following inequality:
1/log(V) <= α <= -1/log(1-1/V) where V is the model vocabulary size.
- 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
where α is a parameter. When α == 1, it works like the Gibbs entropy.
More: https://en.wikipedia.org/wiki/Tsallis_entropy
- 'renui' for the Rényi entropy.
Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
where α is a parameter. When α == 1, it works like the Gibbs entropy.
More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy
temperature: Temperature scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
When the temperature equals one, scaling is not applied to 'max_prob',
and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))
entropy_norm: A mapping of the entropy value to the interval [0,1].
Supported values:
- 'lin' for using the linear mapping.
- 'exp' for using exponential mapping with linear shift.
"""

def __init__(
Expand Down Expand Up @@ -1670,22 +1708,6 @@ def _get_initial_states(self, batchsize):
return input_states


@dataclass
class GreedyRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = None


@dataclass
class GreedyBatchedRNNTInferConfig:
max_symbols_per_step: Optional[int] = 10
preserve_alignments: bool = False
preserve_frame_confidence: bool = False
confidence_method_cfg: Optional[ConfidenceMethodConfig] = None


class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer):
"""A greedy transducer decoder.
Expand All @@ -1695,7 +1717,7 @@ class GreedyMultiblankRNNTInfer(_GreedyRNNTInfer):
decoder_model: rnnt_utils.AbstractRNNTDecoder implementation.
joint_model: rnnt_utils.AbstractRNNTJoint implementation.
blank_index: int index of the blank token. Can be 0 or len(vocabulary).
big_blank_durations: a list containing durations for big blank the model supports.
big_blank_durations: a list containing durations for big blank the model supports.
max_symbols_per_step: Optional int. The maximum number of symbols that can be added
to a sequence in a single time step; if set to None then there is
no limit.
Expand Down

0 comments on commit 0725e13

Please sign in to comment.