From 2d2219cc91e4f2abbe2f567b4babb6b83a9f06d4 Mon Sep 17 00:00:00 2001
From: Daniel Galvez
Date: Thu, 9 May 2024 11:00:08 -0700
Subject: [PATCH] Fix #8891 by supporting GPU-side batched CTC Greedy Decoding
 (#9100)

* Support batched inference of greedy CTC decoding.

Fixes #8891

Basically, doing max() on the CPU one sample at a time is very slow. It is
better to do it for the whole batch on the GPU before copying the results
over to the CPU. This new algorithm has the same interface as the old one
and can be accessed by setting strategy to "greedy_batched" rather than
"greedy".

Warn when using the 'greedy' rather than the 'greedy_batched' strategy.

Signed-off-by: Daniel Galvez
---
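Note (not part of the patch itself): below is a minimal sketch of how the new strategy is
selected from user code. It assumes the existing NeMo ASR API (`from_pretrained`,
`change_decoding_strategy`, `transcribe`); the checkpoint name and audio path are placeholder
examples, not something this patch introduces.

    from omegaconf import OmegaConf

    from nemo.collections.asr.models import EncDecCTCModelBPE
    from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecodingConfig

    # Example checkpoint; any CTC (or hybrid RNNT-CTC) model works.
    asr_model = EncDecCTCModelBPE.from_pretrained("stt_en_conformer_ctc_small")

    # 'greedy_batched' is the new default; 'greedy' still works but now logs a warning.
    decoding_cfg = OmegaConf.structured(CTCDecodingConfig(strategy="greedy_batched"))
    asr_model.change_decoding_strategy(decoding_cfg)

    # The argmax over the vocabulary is now computed on the GPU for the whole batch at once.
    transcripts = asr_model.transcribe(["audio.wav"])  # "audio.wav" is a placeholder path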
 .../asr/parts/submodules/ctc_decoding.py      |  25 +-
 .../parts/submodules/ctc_greedy_decoding.py   | 223 +++++++++++++++++-
 .../asr/decoding/test_ctc_decoding.py         |  80 +++++++
 .../asr/test_asr_ctc_encoder_model_bpe.py     |   2 +-
 .../asr/test_asr_ctcencdec_model.py           |   2 +-
 .../asr/test_asr_hybrid_rnnt_ctc_model_bpe.py |   2 +-
 .../test_asr_hybrid_rnnt_ctc_model_char.py    |   2 +-
 7 files changed, 325 insertions(+), 11 deletions(-)

diff --git a/nemo/collections/asr/parts/submodules/ctc_decoding.py b/nemo/collections/asr/parts/submodules/ctc_decoding.py
index 67559eccf6e27..70d63c0f8c6fd 100644
--- a/nemo/collections/asr/parts/submodules/ctc_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_decoding.py
@@ -213,20 +213,20 @@ def __init__(self, decoding_cfg, blank_id: int):
         self.batch_dim_index = self.cfg.get('batch_dim_index', 0)
         self.word_seperator = self.cfg.get('word_seperator', ' ')

-        possible_strategies = ['greedy', 'beam', 'pyctcdecode', 'flashlight']
+        possible_strategies = ['greedy', 'greedy_batched', 'beam', 'pyctcdecode', 'flashlight']
         if self.cfg.strategy not in possible_strategies:
             raise ValueError(f"Decoding strategy must be one of {possible_strategies}. Given {self.cfg.strategy}")

         # Update preserve alignments
         if self.preserve_alignments is None:
-            if self.cfg.strategy in ['greedy']:
+            if self.cfg.strategy in ['greedy', 'greedy_batched']:
                 self.preserve_alignments = self.cfg.greedy.get('preserve_alignments', False)
             else:
                 self.preserve_alignments = self.cfg.beam.get('preserve_alignments', False)

         # Update compute timestamps
         if self.compute_timestamps is None:
-            if self.cfg.strategy in ['greedy']:
+            if self.cfg.strategy in ['greedy', 'greedy_batched']:
                 self.compute_timestamps = self.cfg.greedy.get('compute_timestamps', False)
             elif self.cfg.strategy in ['beam']:
                 self.compute_timestamps = self.cfg.beam.get('compute_timestamps', False)
@@ -234,10 +234,10 @@ def __init__(self, decoding_cfg, blank_id: int):
         # initialize confidence-related fields
         self._init_confidence(self.cfg.get('confidence_cfg', None))

-        # Confidence estimation is not implemented for strategies other than `greedy`
+        # Confidence estimation is not implemented for strategies other than `greedy` and `greedy_batched`
         if (
             not self.preserve_frame_confidence
-            and self.cfg.strategy != 'greedy'
+            and self.cfg.strategy not in ('greedy', 'greedy_batched')
             and self.cfg.beam.get('preserve_frame_confidence', False)
         ):
             raise NotImplementedError(f"Confidence calculation is not supported for strategy `{self.cfg.strategy}`")
@@ -247,6 +247,10 @@ def __init__(self, decoding_cfg, blank_id: int):
             self.compute_timestamps |= self.preserve_frame_confidence

         if self.cfg.strategy == 'greedy':
+            logging.warning(
+                "CTC decoding strategy 'greedy' is slower than 'greedy_batched', which implements exactly the same interface. Consider changing your strategy to 'greedy_batched' for a free performance improvement.",
+                mode=logging_mode.ONCE,
+            )

             self.decoding = ctc_greedy_decoding.GreedyCTCInfer(
                 blank_id=self.blank_id,
@@ -256,6 +260,15 @@ def __init__(self, decoding_cfg, blank_id: int):
                 confidence_method_cfg=self.confidence_method_cfg,
             )

+        elif self.cfg.strategy == "greedy_batched":
+            self.decoding = ctc_greedy_decoding.GreedyBatchedCTCInfer(
+                blank_id=self.blank_id,
+                preserve_alignments=self.preserve_alignments,
+                compute_timestamps=self.compute_timestamps,
+                preserve_frame_confidence=self.preserve_frame_confidence,
+                confidence_method_cfg=self.confidence_method_cfg,
+            )
+
         elif self.cfg.strategy == 'beam':

             self.decoding = ctc_beam_decoding.BeamCTCInfer(
@@ -1287,7 +1300,7 @@ def decode_ids_to_tokens(self, tokens: List[int]) -> List[str]:

 @dataclass
 class CTCDecodingConfig:
-    strategy: str = "greedy"
+    strategy: str = "greedy_batched"

     # preserve decoding alignments
     preserve_alignments: Optional[bool] = None
diff --git a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
index ab4b4c40e8602..1ef26cd7adf37 100644
--- a/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
+++ b/nemo/collections/asr/parts/submodules/ctc_greedy_decoding.py
@@ -110,7 +110,7 @@ class GreedyCTCInfer(Typing, ConfidenceMethodMixin):
     def input_types(self):
         """Returns definitions of module input ports.
         """
-        # Input can be of dimention -
+        # Input can be of dimension -
         # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]

         return {
@@ -266,6 +266,227 @@ def __call__(self, *args, **kwargs):
         return self.forward(*args, **kwargs)


+class GreedyBatchedCTCInfer(Typing, ConfidenceMethodMixin):
+    """A vectorized greedy CTC decoder.
+
+    This is basically always faster than GreedyCTCInfer, and supports
+    the same interface. See issue #8891 on GitHub for what is wrong
+    with GreedyCTCInfer: it loops over each element in the batch,
+    running kernels at batch size one, so CPU overheads end up
+    dominating. This implementation uses masking to do the same
+    operation in a batched manner.
+
+    Args:
+        blank_id: int index of the blank token. Can be 0 or len(vocabulary).
+        preserve_alignments: Bool flag which preserves the history of logprobs generated during
+            decoding (sample / batched). When set to true, the Hypothesis will contain
+            the non-null value for `logprobs` in it. Here, `logprobs` is a torch.Tensor.
+        compute_timestamps: A bool flag, which determines whether to compute the character/subword, or
+            word based timestamp mapping the output log-probabilities to discrete intervals of timestamps.
+            The timestamps will be available in the returned Hypothesis.timestep as a dictionary.
+        preserve_frame_confidence: Bool flag which preserves the history of per-frame confidence scores
+            generated during decoding. When set to true, the Hypothesis will contain
+            the non-null value for `frame_confidence` in it. Here, `frame_confidence` is a List of floats.
+        confidence_method_cfg: A dict-like object which contains the method name and settings to compute per-frame
+            confidence scores.
+
+            name: The method name (str).
+                Supported values:
+                    - 'max_prob' for using the maximum token probability as a confidence.
+                    - 'entropy' for using a normalized entropy of a log-likelihood vector.
+
+            entropy_type: Which type of entropy to use (str). Used if confidence_method_cfg.name is set to `entropy`.
+                Supported values:
+                    - 'gibbs' for the (standard) Gibbs entropy. If the alpha (α) is provided,
+                        the formula is the following: H_α = -sum_i((p^α_i)*log(p^α_i)).
+                        Note that for this entropy, the alpha should comply with the following inequality:
+                        (log(V)+2-sqrt(log^2(V)+4))/(2*log(V)) <= α <= (1+log(V-1))/log(V-1)
+                        where V is the model vocabulary size.
+                    - 'tsallis' for the Tsallis entropy with the Boltzmann constant one.
+                        Tsallis entropy formula is the following: H_α = 1/(α-1)*(1-sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/Tsallis_entropy
+                    - 'renyi' for the Rényi entropy.
+                        Rényi entropy formula is the following: H_α = 1/(1-α)*log_2(sum_i(p^α_i)),
+                        where α is a parameter. When α == 1, it works like the Gibbs entropy.
+                        More: https://en.wikipedia.org/wiki/R%C3%A9nyi_entropy
+
+            alpha: Power scale for logsoftmax (α for entropies). Here we restrict it to be > 0.
+                When the alpha equals one, scaling is not applied to 'max_prob',
+                and any entropy type behaves like the Shannon entropy: H = -sum_i(p_i*log(p_i))
+
+            entropy_norm: A mapping of the entropy value to the interval [0,1].
+                Supported values:
+                    - 'lin' for using the linear mapping.
+                    - 'exp' for using exponential mapping with linear shift.
+
+    """
+
+    @property
+    def input_types(self):
+        """Returns definitions of module input ports.
+        """
+        # Input can be of dimension -
+        # ('B', 'T', 'D') [Log probs] or ('B', 'T') [Labels]
+
+        return {
+            "decoder_output": NeuralType(None, LogprobsType()),
+            "decoder_lengths": NeuralType(tuple('B'), LengthsType()),
+        }
+
+    @property
+    def output_types(self):
+        """Returns definitions of module output ports.
+        """
+        return {"predictions": [NeuralType(elements_type=HypothesisType())]}
+
+    def __init__(
+        self,
+        blank_id: int,
+        preserve_alignments: bool = False,
+        compute_timestamps: bool = False,
+        preserve_frame_confidence: bool = False,
+        confidence_method_cfg: Optional[DictConfig] = None,
+    ):
+        super().__init__()
+
+        self.blank_id = blank_id
+        self.preserve_alignments = preserve_alignments
+        # we need timestamps to extract non-blank per-frame confidence
+        self.compute_timestamps = compute_timestamps | preserve_frame_confidence
+        self.preserve_frame_confidence = preserve_frame_confidence
+
+        # set confidence calculation method
+        self._init_confidence_method(confidence_method_cfg)
+
+    @typecheck()
+    def forward(
+        self, decoder_output: torch.Tensor, decoder_lengths: torch.Tensor,
+    ):
+        """Returns a list of hypotheses given an input batch of the encoder hidden embedding.
+        Output tokens are decoded greedily, one frame at a time.
+
+        Args:
+            decoder_output: A tensor of size (batch, timesteps, features) or (batch, timesteps) (each timestep is a label).
+            decoder_lengths: list of int representing the length of each output sequence.
+
+        Returns:
+            packed list containing batch number of sentences (Hypotheses).
+        """
+        if decoder_output.ndim == 2:
+            hypotheses = self._greedy_decode_labels_batched(decoder_output, decoder_lengths)
+        else:
+            hypotheses = self._greedy_decode_logprobs_batched(decoder_output, decoder_lengths)
+        packed_result = pack_hypotheses(hypotheses, decoder_lengths)
+        return (packed_result,)
+
+    @torch.no_grad()
+    def _greedy_decode_logprobs_batched(self, x: torch.Tensor, out_len: torch.Tensor):
+        # x: [B, T, D]
+        # out_len: [B]
+
+        batch_size = x.shape[0]
+        max_time = x.shape[1]
+
+        predictions = x
+        # In CTC greedy decoding, each output maximum likelihood token
+        # is calculated independently of the other tokens.
+        predictions_logprobs, predictions_labels = predictions.max(dim=-1)
+
+        # Since predictions_logprobs is a padded matrix in the time
+        # dimension, we consider invalid timesteps to be "blank".
+        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
+        non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
+        # Sum the non-blank labels to compute the score of the
+        # transcription. This follows from Eq. (3) of "Connectionist
+        # Temporal Classification: Labelling Unsegmented Sequence Data
+        # with Recurrent Neural Networks".
+        scores = torch.where(non_blank_ids_mask, predictions_logprobs, 0.0).sum(axis=1)
+
+        scores = scores.cpu()
+        predictions_labels = predictions_labels.cpu()
+        out_len = out_len.cpu()
+
+        if self.preserve_alignments or self.preserve_frame_confidence:
+            predictions = predictions.cpu()
+
+        hypotheses = []
+
+        # This mimics the for loop in GreedyCTCInfer::forward.
+        for i in range(batch_size):
+            hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
+            hypothesis.score = scores[i]
+
+            prediction_labels_no_padding = predictions_labels[i, : out_len[i]].tolist()
+
+            assert predictions_labels.dtype == torch.int64
+            hypothesis.y_sequence = prediction_labels_no_padding
+
+            if self.preserve_alignments:
+                hypothesis.alignments = (
+                    predictions[i, : out_len[i], :].clone(),
+                    predictions_labels[i, : out_len[i]].clone(),
+                )
+            if self.compute_timestamps:
+                # TODO: Could do this in a vectorized manner... Would
+                # prefer to have nonzero_static, though, for sanity.
+                # Or do a prefix sum on out_len
+                hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
+            if self.preserve_frame_confidence:
+                hypothesis.frame_confidence = self._get_confidence(predictions[i, : out_len[i], :])
+
+            hypotheses.append(hypothesis)
+
+        return hypotheses
+
+    @torch.no_grad()
+    def _greedy_decode_labels_batched(self, x: torch.Tensor, out_len: torch.Tensor):
+        """
+        This does greedy decoding in the case where you have already found the
+        most likely token at each timestep.
+        """
+        # x: [B, T]
+        # out_len: [B]
+
+        batch_size = x.shape[0]
+        max_time = x.shape[1]
+
+        predictions_labels = x
+        time_steps = torch.arange(max_time, device=x.device).unsqueeze(0).expand(batch_size, max_time)
+        non_blank_ids_mask = torch.logical_and(predictions_labels != self.blank_id, time_steps < out_len.unsqueeze(1))
+        predictions_labels = predictions_labels.cpu()
+        out_len = out_len.cpu()
+
+        hypotheses = []
+
+        for i in range(batch_size):
+            hypothesis = rnnt_utils.Hypothesis(score=0.0, y_sequence=[], dec_state=None, timestep=[], last_token=None)
+            hypothesis.y_sequence = predictions_labels[i, : out_len[i]].tolist()
+            hypothesis.score = -1.0
+
+            if self.preserve_alignments:
+                raise ValueError(
+                    "Alignments were requested, but the predictions provided were labels, not log probabilities."
+                )
+            if self.compute_timestamps:
+                # TODO: Could do this in a vectorized manner... Would
+                # prefer to have nonzero_static, though, for sanity.
+                # Or do a prefix sum on out_len
+                hypothesis.timestep = torch.nonzero(non_blank_ids_mask[i], as_tuple=False)[:, 0].cpu().tolist()
+            if self.preserve_frame_confidence:
+                raise ValueError(
+                    "Per-frame confidence was requested, but the predictions provided were labels, not log probabilities."
+                )
+
+            hypotheses.append(hypothesis)
+
+        return hypotheses
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+
 @dataclass
 class GreedyCTCInferConfig:
     preserve_alignments: bool = False
diff --git a/tests/collections/asr/decoding/test_ctc_decoding.py b/tests/collections/asr/decoding/test_ctc_decoding.py
index a3a5689062bff..8eceb822fd386 100644
--- a/tests/collections/asr/decoding/test_ctc_decoding.py
+++ b/tests/collections/asr/decoding/test_ctc_decoding.py
@@ -26,6 +26,7 @@
     CTCDecoding,
     CTCDecodingConfig,
 )
+from nemo.collections.asr.parts.utils.asr_confidence_utils import ConfidenceConfig
 from nemo.collections.asr.parts.utils.rnnt_utils import Hypothesis


@@ -191,3 +192,82 @@ def test_subword_decoding_greedy_forward_hypotheses(self, tmp_tokenizer, alignme
         # timestamps check
         if timestamps:
             check_subword_timestamps(hyp, decoding)
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize('alignments', [False, True])
+    @pytest.mark.parametrize('timestamps', [False, True])
+    @pytest.mark.parametrize('preserve_frame_confidence', [False, True])
+    def test_batched_decoding_logprobs(self, tmp_tokenizer, alignments, timestamps, preserve_frame_confidence):
+        cfg = CTCBPEDecodingConfig(
+            strategy='greedy',
+            preserve_alignments=alignments,
+            compute_timestamps=timestamps,
+            confidence_cfg=ConfidenceConfig(preserve_frame_confidence=preserve_frame_confidence),
+        )
+        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        cfg.strategy = 'greedy_batched'
+        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        torch.manual_seed(1)
+        B, T = 4, 20
+        V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
+        input_signal = torch.randn(size=(B, T, V))
+        # Set the blank index to a very high probability to make sure
+        # that we always handle at least a few blanks.
+        input_signal[:, 0, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
+        input_signal[:, 1, unbatched_decoding.tokenizer.tokenizer.vocab_size] = 1000
+        length = torch.randint(low=1, high=T, size=[B])
+
+        with torch.inference_mode():
+            hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+            batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
+                input_signal, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+        assert len(hyps) == len(batched_hyps) == B
+        for hyp, batched_hyp in zip(hyps, batched_hyps):
+            assert torch.abs(hyp.score - batched_hyp.score) <= 1e-5
+            assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
+            if timestamps:
+                assert hyp.timestep == batched_hyp.timestep
+            if alignments:
+                assert torch.all(hyp.alignments[0] == batched_hyp.alignments[0])
+                assert torch.all(hyp.alignments[1] == batched_hyp.alignments[1])
+
+    @pytest.mark.unit
+    @pytest.mark.parametrize('timestamps', [False, True])
+    def test_batched_decoding_labels(self, tmp_tokenizer, timestamps):
+        cfg = CTCBPEDecodingConfig(strategy='greedy', compute_timestamps=timestamps)
+        unbatched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+        cfg.strategy = 'greedy_batched'
+        batched_decoding = CTCBPEDecoding(decoding_cfg=cfg, tokenizer=tmp_tokenizer)
+
+        torch.manual_seed(1)
+        B, T = 4, 20
+        V = unbatched_decoding.tokenizer.tokenizer.vocab_size + 1
+        input_labels = torch.randint(V, size=(B, T))
+        # Set some indices to blank to make sure that we always handle
+        # at least a few blanks.
+        input_labels[:, 0] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        input_labels[:, 1] = unbatched_decoding.tokenizer.tokenizer.vocab_size
+        length = torch.randint(low=1, high=T, size=[B])
+
+        with torch.inference_mode():
+            hyps, _ = unbatched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+            batched_hyps, _ = batched_decoding.ctc_decoder_predictions_tensor(
+                input_labels, length, fold_consecutive=True, return_hypotheses=True
+            )
+
+        assert len(hyps) == len(batched_hyps) == B
+        for hyp, batched_hyp in zip(hyps, batched_hyps):
+            assert abs(hyp.score - batched_hyp.score) <= 1e-5
+            assert torch.all(hyp.y_sequence == batched_hyp.y_sequence)
+            if timestamps:
+                assert hyp.timestep == batched_hyp.timestep
diff --git a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
index 744936263a032..2005c0e8d41c6 100644
--- a/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
+++ b/tests/collections/asr/test_asr_ctc_encoder_model_bpe.py
@@ -269,7 +269,7 @@ def test_vocab_change(self, test_data_dir, asr_model):
     def test_decoding_change(self, asr_model):
         assert asr_model.decoding is not None
         assert isinstance(asr_model.decoding, CTCBPEDecoding)
-        assert asr_model.decoding.cfg.strategy == "greedy"
+        assert asr_model.decoding.cfg.strategy == "greedy_batched"
         assert asr_model.decoding.preserve_alignments is False
         assert asr_model.decoding.compute_timestamps is False

diff --git a/tests/collections/asr/test_asr_ctcencdec_model.py b/tests/collections/asr/test_asr_ctcencdec_model.py
index 98d563a4a688b..d2587913b879b 100644
--- a/tests/collections/asr/test_asr_ctcencdec_model.py
+++ b/tests/collections/asr/test_asr_ctcencdec_model.py
@@ -150,7 +150,7 @@ def test_vocab_change(self, asr_model):
     def test_decoding_change(self, asr_model):
         assert asr_model.decoding is not None
         assert isinstance(asr_model.decoding, CTCDecoding)
-        assert asr_model.decoding.cfg.strategy == "greedy"
+        assert asr_model.decoding.cfg.strategy == "greedy_batched"
         assert asr_model.decoding.preserve_alignments is False
         assert asr_model.decoding.compute_timestamps is False

diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
index 55e780c022d8b..994d832ec6e54 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_bpe.py
@@ -297,7 +297,7 @@ def test_decoding_change(self, hybrid_asr_model):

         assert hybrid_asr_model.ctc_decoding is not None
         assert isinstance(hybrid_asr_model.ctc_decoding, CTCBPEDecoding)
-        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy"
+        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched"
         assert hybrid_asr_model.ctc_decoding.preserve_alignments is False
         assert hybrid_asr_model.ctc_decoding.compute_timestamps is False

diff --git a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
index 018c9bcc4aa23..923263787def5 100644
--- a/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
+++ b/tests/collections/asr/test_asr_hybrid_rnnt_ctc_model_char.py
@@ -231,7 +231,7 @@ def test_decoding_change(self, hybrid_asr_model):

         assert hybrid_asr_model.ctc_decoding is not None
         assert isinstance(hybrid_asr_model.ctc_decoding, CTCDecoding)
-        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy"
+        assert hybrid_asr_model.ctc_decoding.cfg.strategy == "greedy_batched"
         assert hybrid_asr_model.ctc_decoding.preserve_alignments is False
         assert hybrid_asr_model.ctc_decoding.compute_timestamps is False