Skip to content

Commit

Permalink
Fix virtual function issue with CTC decoder
Browse files Browse the repository at this point in the history
Currently, creating CTCDecoder object by passing a language model to
`lm` argument without assigning it to a variable elsewhere causes
`RuntimeError: Tried to call pure virtual function "LM::start"`.

According to discussions on PyBind11, (
pybind/pybind11#4013 and
pybind/pybind11#2839
) this is due to Python object garbage-collected by the time
it's used by code implemented in C++. It attempts to call
methods defined in Python, which overrides the base pure virtual
function, but the object which provides this override gets
deleted by garbage collrector, as the original object is not
reference counted.

This commit fixes this by simply assiging the given `lm` object
as an attribute of CTCDecoder class.

Address #3218
  • Loading branch information
mthrok committed Apr 3, 2023
1 parent c22cd16 commit 0b89d32
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/torchaudio_unittest/models/decoder/ctc_decoder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,3 +169,19 @@ def test_index_to_tokens(self, tokens):

expected_tokens = ["|", "f", "|", "o", "a"]
self.assertEqual(tokens, expected_tokens)

def test_lm_lifecycle(self):
"""Passing lm without assiging it to a vaiable won't cause runtime error
https://github.com/pytorch/audio/issues/3218
"""
from torchaudio.models.decoder import ctc_decoder

from .ctc_decoder_utils import CustomZeroLM

decoder = ctc_decoder(
lexicon=get_asset_path("decoder/lexicon.txt"),
tokens=get_asset_path("decoder/tokens.txt"),
lm=CustomZeroLM(),
)
decoder(torch.zeros((1, 3, NUM_TOKENS), dtype=torch.float32))
6 changes: 6 additions & 0 deletions torchaudio/models/decoder/_ctc_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,12 @@ def __init__(
)
else:
self.decoder = _LexiconFreeDecoder(decoder_options, lm, silence, self.blank, transitions)
# https://github.com/pytorch/audio/issues/3218
# If lm is passed like rvalue reference, the lm object gets garbage collected,
# and later call to the lm fails.
# This ensures that lm object is not deleted as long as the decoder is alive.
# https://github.com/pybind/pybind11/discussions/4013
self.lm = lm

def _get_tokens(self, idxs: torch.IntTensor) -> torch.LongTensor:
idxs = (g[0] for g in it.groupby(idxs))
Expand Down

0 comments on commit 0b89d32

Please sign in to comment.