diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py
index 67559eccf6e27..70d63c0f8c6fd 100644
--- a/nemo/collections/asr/parts/submodules/ctc_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -213,20 +213,20 @@ def __init__(self, decoding_cfg, blank_id: int):
         self.batch_dim_index = self.cfg.get('batch_dim_index', 0)
         self.word_seperator = self.cfg.get('word_seperator', ' ')
 
-        possible_strategies = ['greedy', 'beam', 'pyctcdecode', 'flashlight']
+        possible_strategies = ['greedy', 'greedy_batched', 'beam', 'pyctcdecode', 'flashlight']
         if self.cfg.strategy not in possible_strategies:
             raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")
 
         # Update preserve alignments
         if self.preserve_alignments is None:
-            if self.cfg.strategy in ['greedy']:
+            if self.cfg.strategy in ['greedy', 'greedy_batched']:
                 self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
             else:
                 self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)
 
         # Update compute timestamps
         if self.compute_timestamps is None:
-            if self.cfg.strategy in ['greedy']:
+            if self.cfg.strategy in ['greedy', 'greedy_batched']:
                 self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False)
             elif self.cfg.strategy in ['beam']:
                 self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False)
@@ -234,10 +234,10 @@ def __init__(self, decoding_cfg, blank_id: int):
         # initialize confidence-related fields
         self._init_confidence(self.cfg.get('confidence_cfg', None))
 
-        # Confidence estimation is not implemented for strategies other than `greedy`
+        # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batched`
         if (
             not self.preserve_frame_confidence
-            and self.cfg.strategy != 'greedy'
+            and self.cfg.strategy not in ('greedy', 'greedy_batched')
             and self.cfg.beam.get('preserve_frame_confidence', False)
         ):
             raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`")
@@ -247,6 +247,10 @@ def __init__(self, decoding_cfg, blank_id: int):
         self.compute_timestamps |= self.preserve_frame_confidence
 
         if self.cfg.strategy == 'greedy':
+            logging.warning(
+                "CTC decoding strategy 'greedy' is slower than 'greedy_batched', which implements exactly the same interface. Consider changing your strategy to 'greedy_batched' for a free performance improvement.",
+                mode=logging_mode.ONCE,
+            )
 
             self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
                 blank_id=self.blank_id,
@@ -256,6 +260,15 @@ def __init__(self, decoding_cfg, blank_id: int):
                 confidence_method_cfg=self.confidence_method_cfg,
             )
 
+        elif self.cfg.strategy == "greedy_batched":
+            self.decoding = ctc_greedy_decoding.GreedyBatchedCTCInfer(
+                blank_id=self.blank_id,
+                preserve_alignments=self.preserve_alignments,
+                compute_timestamps=self.compute_timestamps,
+                preserve_frame_confidence=self.preserve_frame_confidence,
+                confidence_method_cfg=self.confidence_method_cfg,
+            )
+
         elif self.cfg.strategy == 'beam':
 
             self.decoding = ctc_beam_decoding.BeamCTCInfer(
@@ -1287,7 +1300,7 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:
 
 @dataclass
 class CTCDecodingConfig:
-    strategy: str = "greedy"
+    strategy: str = "greedy_batched"
 
     # preserve decoding alignments
     preserve_alignments: Optional[bool] = None
diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index ab4b4c40e8602..1ef26cd7adf37 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -110,7 +110,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
     def input_types(self):
         """Returns definitions of module input ports.
         """
-        # Input can be of dimention -
+        # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
 
         return {
@@ -266,6 +266,227 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)
 
 
+class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
+    """A vectorized greedy CTC decoder.
+
+    This is almost always faster than GreedyCTCInfer and supports the
+    same interface. See NeMo issue #8891 on GitHub for what is wrong
+    with GreedyCTCInfer: it loops over each element in the batch,
+    running kernels at batch size one, so CPU overheads end up
+    dominating. This implementation instead uses masking to do the
+    same operation on the whole batch at once.
+
+    Args:
+        blank_id: int index of the blank token. Can be 0 or len(vocabulary).
+        preserve_alignments: Bool flag which preserves the history of logprobs generated during
+            decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensor.
+        compute_timestamps: A bool flag, which determines whether to compute the character/subword or
+            word-based timestamps, mapping the output log-probabilities to discrete intervals of time.
+            The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
+        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
+            generated during decoding. When set to true, the Hypothesis will contain
+            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
+        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
+            confidence scores.
+
+            name: The method name (str).
+                Supported values:
+                    - 'max_prob' for using the maximum token probability as a confidence.
+                    - 'entropy' for using a normalized entropy of a log-likelihood vector.
+
+            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
+                Supported values:
+                    - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
+                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
+                        Note that for this entropy, the alpha should satisfy the following inequality:
+                        (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
+                        where V is the model vocabulary size.
+                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
+                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
+                    - 'renyi' for the Rényi entropy.
+                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy
+
+            alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
+                When the alpha equals one, scaling is not applied to 'max_prob',
+                and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))
+
+            entropy_norm: A mapping of the entropy value to the interval [0,1].
+                Supported values:
+                    - 'lin' for using the linear mapping.
+                    - 'exp' for using exponential mapping with linear shift.
+
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input ports.
+        """
+        # Input can be of dimension -
+        # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
+
+        return {
+            "decoder_output": NeuralType(None, LogprobsType()),
+            "decoder_lengths": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output ports.
+        """
+        return {"predictions": [NeuralType(elements_type=HypothesisType())]}
+
+    def __init__(
+        self,
+        blank_id: int,
+        preserve_alignments: bool = False,
+        compute_timestamps: bool = False,
+        preserve_frame_confidence: bool = False,
+        confidence_method_cfg: Optional[DictConfig] = None,
+    ):
+        super().__init__()
+
+        self.blank_id = blank_id
+        self.preserve_alignments = preserve_alignments
+        # we need timestamps to extract non-blank per-frame confidence
+        self.compute_timestamps = compute_timestamps | preserve_frame_confidence
+        self.preserve_frame_confidence = preserve_frame_confidence
+
+        # set confidence calculation method
+        self._init_confidence_method(confidence_method_cfg)
+
+    @typecheck()
+    def forward(
+        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
+    ):
+        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output labels are selected greedily and independently at each timestep.
+
+        Args:
+            decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label).
+            decoder_lengths: list of int representing the length of each
+                output sequence.
+
+        Returns:
+            A packed list containing one Hypothesis per batch element.
+ """ + if decoder_output.ndim == 2: + hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths) + else: + hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths) + packed_result = pack_hypotheses(hypotheses, decoder_lengths) + return (packed_result,) + + @torch.no_grad() + def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor): + # x: [B, T, D] + # out_len: [B] + + batch_size = x.shape[0] + max_time = x.shape[1] + + predictions = x + # In CTC greedy decoding, each output maximum likelihood token + # is calculated independent of the other tokens. + predictions_logprobs, predictions_labels = predictions.max(dim=-1) + + # Since predictions_logprobs is a padded matrix in the time + # dimension, we consider invalid timesteps to be "blank". + time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time) + non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1)) + # Sum the non-blank labels to compute the score of the + # transcription. This follows from Eq. (3) of "Connectionist + # Temporal Classification: Labelling Unsegmented Sequence Data + # with Recurrent Neural Networks". + scores = torch.where(non_blank_ids_mask, predictions_logprobs, 0.0).sum(axis=1) + + scores = scores.cpu() + predictions_labels = predictions_labels.cpu() + out_len = out_len.cpu() + + if self.preserve_alignments or self.preserve_frame_confidence: + predictions = predictions.cpu() + + hypotheses = [] + + # This mimics the for loop in GreedyCTCInfer::forward. + for i in range(batch_size): + hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None) + hypothesis.score = scores[i] + + prediction_labels_no_padding = predictions_labels[i, : out_len[i]].tolist() + + assert predictions_labels.dtype == torch.int64 + hypothesis.y_sequence = prediction_labels_no_padding + + if self.preserve_alignments: + hypothesis.alignments = ( + predictions[i, : out_len[i], :].clone(), + predictions_labels[i, : out_len[i]].clone(), + ) + if self.compute_timestamps: + # TOOD: Could do this in a vectorized manner... Would + # prefer to have nonzero_static, though, for sanity. + # Or do a prefix sum on out_len + hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist() + if self.preserve_frame_confidence: + hypothesis.frame_confidence = self._get_confidence(predictions[i, : out_len[i], :]) + + hypotheses.append(hypothesis) + + return hypotheses + + @torch.no_grad() + def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor): + """ + This does greedy decoding in the case where you have already found the + most likely token at each timestep. 
+ """ + # x: [B, T] + # out_len: [B] + + batch_size = x.shape[0] + max_time = x.shape[1] + + predictions_labels = x + time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time) + non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1)) + predictions_labels = predictions_labels.cpu() + out_len = out_len.cpu() + + hypotheses = [] + + for i in range(batch_size): + hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None) + hypothesis.y_sequence = predictions_labels[i, : out_len[i]].tolist() + hypothesis.score = -1.0 + + if self.preserve_alignments: + raise ValueError( + "Requested for alignments, but predictions provided were labels, not log probabilities." + ) + if self.compute_timestamps: + # TOOD: Could do this in a vectorized manner... Would + # prefer to have nonzero_static, though, for sanity. + # Or do a prefix sum on out_len + hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist() + if self.preserve_frame_confidence: + raise ValueError( + "Requested for per-frame confidence, but predictions provided were labels, not log probabilities." + ) + + hypotheses.append(hypothesis) + + return hypotheses + + def __call__(self, *args, **kwargs): + return self.forward(*args, **kwargs) + + @dataclass class GreedyCTCInferConfig: preserve_alignments: bool = False diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py index a3a5689062bff..8eceb822fd386 100644 --- a/tests/collections/asr/decoding/test_ctc_decoding.py +++ b/tests/collections/asr/decoding/test_ctc_decoding.py @@ -26,6 +26,7 @@ CTCDecoding, CTCDecodingConfig, ) +from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis @@ -191,3 +192,82 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme # timestamps check if timestamps: check_subword_timestamps(hyp, decoding) + + @pytest.mark.unit + @pytest.mark.parametrize('alignments', [False, True]) + @pytest.mark.parametrize('timestamps', [False, True]) + @pytest.mark.parametrize('preserve_frame_confidence', [False, True]) + def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence): + cfg = CTCBPEDecodingConfig( + strategy='greedy', + preserve_alignments=alignments, + compute_timestamps=timestamps, + confidence_cfg=ConfidenceConfig(preserve_frame_confidence=preserve_frame_confidence), + ) + unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) + + cfg.strategy = 'greedy_batched' + batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer) + + torch.manual_seed(1) + B, T = 4, 20 + V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1 + input_signal = torch.randn(size=(B, T, V)) + # Set the blank index to a very high probability to make sure + # that we always handle at least a few blanks. 
+        input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
+        input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
+        length = torch.randint(low=1, high=T, size=[B])
+
+        with torch.inference_mode():
+            hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+            batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+        assert len(hyps) == len(batched_hyps) == B
+        for hyp, batched_hyp in zip(hyps, batched_hyps):
+            assert torch.abs(hyp.score - batched_hyp.score) <= 1e-5
+            assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
+            if timestamps:
+                assert hyp.timestep == batched_hyp.timestep
+            if alignments:
+                assert torch.all(hyp.alignments[0] == batched_hyp.alignments[0])
+                assert torch.all(hyp.alignments[1] == batched_hyp.alignments[1])
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize('timestamps', [False, True])
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
+        cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
+        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+        cfg.strategy = 'greedy_batched'
+        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        torch.manual_seed(1)
+        B, T = 4, 20
+        V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
+        input_labels = torch.randint(V, size=(B, T))
+        # Set some indices to blank to make sure that we always handle
+        # at least a few blanks.
+        input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        length = torch.randint(low=1, high=T, size=[B])
+
+        with torch.inference_mode():
+            hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+            batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+        assert len(hyps) == len(batched_hyps) == B
+        for hyp, batched_hyp in zip(hyps, batched_hyps):
+            assert abs(hyp.score - batched_hyp.score) <= 1e-5
+            assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
+            if timestamps:
+                assert hyp.timestep == batched_hyp.timestep
diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
index 744936263a032..2005c0e8d41c6 100644
--- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
+++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
@@ -269,7 +269,7 @@ def test_vocab_change(self, test_data_dir, asr_model):
     def test_decoding_change(self, asr_model):
         assert asr_model.decoding is not None
         assert isinstance(asr_model.decoding, CTCBPEDecoding)
-        assert asr_model.decoding.cfg.strategy == "greedy"
+        assert asr_model.decoding.cfg.strategy == "greedy_batched"
 
         assert asr_model.decoding.preserve_alignments is False
         assert asr_model.decoding.compute_timestamps is False
diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py
index 98d563a4a688b..d2587913b879b 100644
--- a/tests/collections/asr/test_asr_ctcencdec_model.py
+++ b/tests/collections/asr/test_asr_ctcencdec_model.py
@@ -150,7 +150,7 @@ def test_vocab_change(self, asr_model):
     def test_decoding_change(self, asr_model):
         assert asr_model.decoding is not None
         assert isinstance(asr_model.decoding, CTCDecoding)
-        assert asr_model.decoding.cfg.strategy == "greedy"
+        assert asr_model.decoding.cfg.strategy == "greedy_batched"
 
         assert asr_model.decoding.preserve_alignments is False
         assert asr_model.decoding.compute_timestamps is False
diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
index 55e780c022d8b..994d832ec6e54 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
@@ -297,7 +297,7 @@ def test_decoding_change(self, hybrid_asr_model):
 
         assert hybrid_asr_model.ctc_decoding is not None
         assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding)
-        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy"
+        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched"
 
         assert hybrid_asr_model.ctc_decoding.preserve_alignments is False
         assert hybrid_asr_model.ctc_decoding.compute_timestamps is False
diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
index 018c9bcc4aa23..923263787def5 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
@@ -231,7 +231,7 @@ def test_decoding_change(self, hybrid_asr_model):
 
         assert hybrid_asr_model.ctc_decoding is not None
         assert isinstance(hybrid_asr_model.ctc_decoding, CTCDecoding)
-        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy"
+        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched"
 
         assert hybrid_asr_model.ctc_decoding.preserve_alignments is False
         assert hybrid_asr_model.ctc_decoding.compute_timestamps is False
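
---

Usage note (editor's sketch, not part of the patch): this change makes 'greedy_batched' the default strategy for newly constructed CTCDecodingConfig objects, while models restored from older checkpoints keep whatever strategy is stored in their config. The snippet below is a minimal sketch of how one would opt in explicitly on an existing model, assuming a working NeMo install; the checkpoint name "stt_en_conformer_ctc_small" and the audio path are illustrative placeholders.

# Minimal sketch: switch a restored CTC model to the batched greedy decoder.
# Assumptions: NeMo is installed; the checkpoint name and audio path below
# are placeholders to replace with your own.
import nemo.collections.asr as nemo_asr
from omegaconf import open_dict

asr_model = nemo_asr.models.EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_small")

# Take the model's current decoding config and flip only the strategy field.
decoding_cfg = asr_model.cfg.decoding
with open_dict(decoding_cfg):
    decoding_cfg.strategy = "greedy_batched"
asr_model.change_decoding_strategy(decoding_cfg)

# Output should match the 'greedy' strategy exactly; only the speed differs.
transcripts = asr_model.transcribe(["/path/to/audio.wav"])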