From 0b89d3284403df35fd30e9633462d6bae728020f Mon Sep 17 00:00:00 2001 From: moto <855818+mthrok@users.noreply.github.com> Date: Mon, 3 Apr 2023 10:46:06 -0400 Subject: [PATCH] Fix virtual function issue with CTC decoder Currently, creating CTCDecoder object by passing a language model to `lm` argument without assigning it to a variable elsewhere causes `RuntimeError: Tried to call pure virtual function "LM::start"`. According to discussions on PyBind11, ( https://github.com/pybind/pybind11/discussions/4013 and https://github.com/pybind/pybind11/pull/2839 ) this is due to Python object garbage-collected by the time it's used by code implemented in C++. It attempts to call methods defined in Python, which overrides the base pure virtual function, but the object which provides this override gets deleted by garbage collrector, as the original object is not reference counted. This commit fixes this by simply assiging the given `lm` object as an attribute of CTCDecoder class. Address https://github.com/pytorch/audio/issues/3218 --- .../models/decoder/ctc_decoder_test.py | 16 ++++++++++++++++ torchaudio/models/decoder/_ctc_decoder.py | 6 ++++++ 2 files changed, 22 insertions(+) diff --git a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py index f794f92ff9..87dc93ffd3 100644 --- a/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py +++ b/test/torchaudio_unittest/models/decoder/ctc_decoder_test.py @@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens): expected_tokens = ["|", "f", "|", "o", "a"] self.assertEqual(tokens, expected_tokens) + + def test_lm_lifecycle(self): + """Passing lm without assiging it to a vaiable won't cause runtime error + + https://github.com/pytorch/audio/issues/3218 + """ + from torchaudio.models.decoder import ctc_decoder + + from .ctc_decoder_utils import CustomZeroLM + + decoder = ctc_decoder( + lexicon=get_asset_path("decoder/lexicon.txt"), + tokens=get_asset_path("decoder/tokens.txt"), + lm=CustomZeroLM(), + ) + decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32)) diff --git a/torchaudio/models/decoder/_ctc_decoder.py b/torchaudio/models/decoder/_ctc_decoder.py index d9fa5165d8..33daa09ec9 100644 --- a/torchaudio/models/decoder/_ctc_decoder.py +++ b/torchaudio/models/decoder/_ctc_decoder.py @@ -269,6 +269,12 @@ def __init__( ) else: self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions) + # https://github.com/pytorch/audio/issues/3218 + # If lm is passed like rvalue reference, the lm object gets garbage collected, + # and later call to the lm fails. + # This ensures that lm object is not deleted as long as the decoder is alive. + # https://github.com/pybind/pybind11/discussions/4013 + self.lm = lm def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor: idxs = (g[0] for g in it.groupby(idxs))