
Commit

initial commit
Signed-off-by: Travis Bartley <tbartley@nvidia.com>
tbartley94 committed Jul 25, 2024
1 parent 770b223 commit 576f202
Showing 14 changed files with 437 additions and 62 deletions.
83 changes: 75 additions & 8 deletions nemo/collections/asr/data/audio_to_text_lhotse_prompted.py
@@ -20,7 +20,7 @@
from lhotse.dataset.collation import collate_vectors

from nemo.collections.asr.data.audio_to_text_lhotse import TokenizerWrapper
-from nemo.collections.common.prompts.canary import CanaryPromptFormatter
+from nemo.collections.common.prompts import CanaryPromptFormatter, ParakeetPromptFormatter
from nemo.collections.common.tokenizers import CanaryTokenizer, TokenizerSpec
from nemo.collections.common.tokenizers.canary_tokenizer import CANARY_SPECIAL_TOKENIZER

@@ -64,13 +64,10 @@ def __getitem__(self, cuts: CutSet) -> tuple[torch.Tensor, torch.Tensor, torch.T
prompts_with_answers_lens = torch.tensor([t.size(0) for t in prompts_with_answers], dtype=torch.long)
prompts_with_answers = collate_vectors(prompts_with_answers, padding_value=self.padding_value)

-        if self.inference:
-            prompts = [torch.as_tensor(t) for t in prompts]
-            prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
-            prompts = collate_vectors(prompts, padding_value=self.padding_value)
-        else:
-            prompts = None
-            prompts_lens = None
+        prompts = [torch.as_tensor(t) for t in prompts]
+        prompts_lens = torch.tensor([t.size(0) for t in prompts], dtype=torch.long)
+        prompts = collate_vectors(prompts, padding_value=self.padding_value)


return audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens

@@ -184,5 +181,75 @@ def canary(
return prompts_with_answers, prompts


@registered_prompt_format_fn
def parakeet(
cuts: CutSet, tokenizer: TokenizerWrapper, inference: bool = False
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
"""
Prepend and append control tokens to the token sequence as per the Parakeet prompt format.
We use the following special tokens:
* <|nopnc|>
* <|pnc|>
* <|LANG|> - for each supported language.
* <|nospeech|>
The prompt format syntax is as follows:
[ <|nospeech|> | <|LANG|> <|LANG|> [ <|pnc|> | <|nopnc|> ] TEXT ]
Where expression ``[ a | b ]`` denotes expression ``a`` or expression ``b``, and can be nested.
Note that ``<|LANG|>`` appears twice: the first occurrence is for the "source" language
(i.e., spoken language in the recording) and the second occurrence is for the "target" language
(i.e., the language in which we are going to output the text).
"""

assert isinstance(
tokenizer._tokenizer, CanaryTokenizer
), "To use the 'parakeet' prompt format, you must use the CanaryTokenizer."
formatter = ParakeetPromptFormatter(tokenizer._tokenizer)

prompts_with_answers, prompts = [], []
for cut in cuts:
if isinstance(cut, MixedCut):
cut = cut._first_non_padding_cut
if not isinstance(cut, MonoCut):
raise TypeError(
f"Expected input audio to have a single channel (required MonoCut/MixedCut, but we received: {cut=})"
)

# first, validate the utterance
expected_slots = set(formatter.get_slots("user"))
missing_keys = expected_slots - set(cut.custom)
if missing_keys:
            raise RuntimeError(
                f"We found cut with ID {cut.id} that is missing the following keys: {missing_keys}. "
                f"Please ensure that every utterance in the input manifests contains these keys."
            )

encoded = formatter.encode_dialog(
turns=[
dict(
role="user",
slots={
**{slot: cut.custom[slot] for slot in expected_slots},
formatter.PROMPT_LANGUAGE_SLOT: CANARY_SPECIAL_TOKENIZER,
},
),
dict(
role="assistant",
slots={
"text": ' '.join(s.text for s in cut.supervisions),
formatter.PROMPT_LANGUAGE_SLOT: cut.supervisions[0].language,
},
),
]
)
prompts_with_answers.append(encoded["input_ids"])
prompts.append(encoded["context_ids"])

return prompts_with_answers, prompts


class ProbablyIncorrectLanguageKeyError(RuntimeError):
pass
12 changes: 3 additions & 9 deletions nemo/collections/asr/models/aed_multitask_models.py
@@ -31,7 +31,7 @@
)
from nemo.collections.asr.metrics import BLEU, WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
-from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin
+from nemo.collections.asr.parts.mixins import ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin, PromptingMixin
from nemo.collections.asr.parts.mixins.transcription import (
GenericTranscriptionType,
InternalTranscribeConfig,
@@ -115,7 +115,7 @@ def __post_init__(self):
self.prompt = parse_multitask_prompt(self.prompt)


-class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin):
+class EncDecMultiTaskModel(ASRModel, ExportableEncDecModel, ASRBPEMixin, ASRModuleMixin, ASRTranscriptionMixin, PromptingMixin):
"""Base class for AED multi-task models"""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -125,18 +125,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
cfg = model_utils.maybe_update_config_version(cfg)
_config_check(cfg)

-        self.prompt_format = cfg.prompt_format
        self.sample_rate = cfg.sample_rate
        self._setup_tokenizer(cfg.tokenizer)
+        self._maybe_setup_prompting(cfg)
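        # _maybe_setup_prompting(cfg) is presumably provided by the new PromptingMixin (one of the
        # 14 changed files, not shown in this excerpt) and takes over the explicit
        # PromptFormatter-based setup removed below.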

super().__init__(cfg=cfg, trainer=trainer)

-        prompt_cls = PromptFormatter.resolve(self.prompt_format)
-        self.prompt = prompt_cls(
-            tokenizer=self.tokenizer,
-            defaults=OmegaConf.to_container(pd) if (pd := cfg.get("prompt_defaults")) is not None else None,
-        )

# Setup audio preprocessor
self.preprocessor = EncDecMultiTaskModel.from_config_dict(self.cfg.preprocessor)
# Setup audio encoder
20 changes: 16 additions & 4 deletions nemo/collections/asr/models/ctc_bpe_models.py
@@ -23,10 +23,11 @@
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToBPEDALIDataset
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
+from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset, get_prompt_format_fn
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
-from nemo.collections.asr.parts.mixins import ASRBPEMixin
+from nemo.collections.asr.parts.mixins import ASRBPEMixin, PromptingMixin
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.asr.parts.utils.asr_batching import get_semi_sorted_batch_sampler
from nemo.collections.common.data.lhotse import get_lhotse_dataloader_from_config
@@ -36,7 +37,7 @@
__all__ = ['EncDecCTCModelBPE']


-class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin):
+class EncDecCTCModelBPE(EncDecCTCModel, ASRBPEMixin, PromptingMixin):
"""Encoder decoder CTC-based models with Byte Pair Encoding."""

def __init__(self, cfg: DictConfig, trainer=None):
@@ -93,8 +94,19 @@ def __init__(self, cfg: DictConfig, trainer=None):
log_prediction=self._cfg.get("log_prediction", False),
)

-    def _setup_dataloader_from_config(self, config: Optional[Dict]):
+    def _setup_dataloader_from_config(self, config: Optional[Dict], inference: bool = False):
        if config.get("use_lhotse"):
+            if self.prompt_format:
+                return get_lhotse_dataloader_from_config(
+                    config,
+                    global_rank=self.global_rank,
+                    world_size=self.world_size,
+                    dataset=PromptedAudioToTextLhotseDataset(
+                        tokenizer=self.tokenizer,
+                        prompt_format_fn=get_prompt_format_fn(self.prompt_format),
+                        inference=inference,
+                    ),
+                )
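        # When prompt_format is set, PromptedAudioToTextLhotseDataset yields
        # (audio, audio_lens, prompts_with_answers, prompts_with_answers_lens, prompts, prompts_lens);
        # the extra prompt tensors are consumed in the CTC model's training_step / validation_pass /
        # _transcribe_forward. Otherwise the standard (non-prompted) Lhotse dataset below is used.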
return get_lhotse_dataloader_from_config(
config,
global_rank=self.global_rank,
@@ -196,7 +208,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

-        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
+        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
return temporary_datalayer

def change_vocabulary(
64 changes: 52 additions & 12 deletions nemo/collections/asr/models/ctc_models.py
@@ -29,10 +29,11 @@
from nemo.collections.asr.data.audio_to_text import _AudioTextDataset
from nemo.collections.asr.data.audio_to_text_dali import AudioToCharDALIDataset, DALIOutputs
from nemo.collections.asr.data.audio_to_text_lhotse import LhotseSpeechToTextBpeDataset
+from nemo.collections.asr.data.audio_to_text_lhotse_prompted import PromptedAudioToTextLhotseDataset, get_prompt_format_fn
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.metrics.wer import WER
from nemo.collections.asr.models.asr_model import ASRModel, ExportableEncDecModel
-from nemo.collections.asr.parts.mixins import ASRModuleMixin, ASRTranscriptionMixin, InterCTCMixin, TranscribeConfig
+from nemo.collections.asr.parts.mixins import ASRModuleMixin, ASRTranscriptionMixin, InterCTCMixin, TranscribeConfig, PromptingMixin
from nemo.collections.asr.parts.mixins.transcription import GenericTranscriptionType, TranscriptionReturnType
from nemo.collections.asr.parts.preprocessing.segment import ChannelSelectorType
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
@@ -47,7 +48,7 @@
__all__ = ['EncDecCTCModel']


-class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin, InterCTCMixin, ASRTranscriptionMixin):
+class EncDecCTCModel(ASRModel, ExportableEncDecModel, ASRModuleMixin, InterCTCMixin, ASRTranscriptionMixin, PromptingMixin):
"""Base class for encoder decoder CTC-based models."""

def __init__(self, cfg: DictConfig, trainer: Trainer = None):
@@ -57,6 +58,9 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
if trainer is not None:
self.world_size = trainer.world_size

# Setup lhotse prompting, if needed
self._maybe_setup_prompting(cfg)

super().__init__(cfg=cfg, trainer=trainer)
self.preprocessor = EncDecCTCModel.from_config_dict(self._cfg.preprocessor)
self.encoder = EncDecCTCModel.from_config_dict(self._cfg.encoder)
@@ -118,6 +122,8 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
# Adapter modules setup (from ASRAdapterModelMixin)
self.setup_adapters()

self._maybe_setup_prompt_adapters(cfg)

def transcribe(
self,
audio: Union[str, List[str], torch.Tensor, np.ndarray, DataLoader],
@@ -274,12 +280,23 @@ def change_decoding_strategy(self, decoding_cfg: DictConfig):

logging.info(f"Changed decoding strategy to \n{OmegaConf.to_yaml(self.cfg.decoding)}")

-    def _setup_dataloader_from_config(self, config: Optional[Dict]):
+    def _setup_dataloader_from_config(self, config: Optional[Dict], inference=False):
        # Automatically inject args from model config to dataloader config
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='sample_rate')
        audio_to_text_dataset.inject_dataloader_value_from_model_config(self.cfg, config, key='labels')

        if config.get("use_lhotse"):
+            if self.prompt_format:
+                return get_lhotse_dataloader_from_config(
+                    config,
+                    global_rank=self.global_rank,
+                    world_size=self.world_size,
+                    dataset=PromptedAudioToTextLhotseDataset(
+                        tokenizer=self.tokenizer,
+                        prompt_format_fn=get_prompt_format_fn(self.prompt_format),
+                        inference=inference,
+                    ),
+                )
return get_lhotse_dataloader_from_config(
config,
global_rank=self.global_rank,
@@ -414,7 +431,7 @@ def setup_validation_data(self, val_data_config: Optional[Union[DictConfig, Dict
# preserve config
self._update_dataset_config(dataset_name='validation', config=val_data_config)

-        self._validation_dl = self._setup_dataloader_from_config(config=val_data_config)
+        self._validation_dl = self._setup_dataloader_from_config(config=val_data_config, inference=True)

def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
"""
@@ -437,7 +454,7 @@ def setup_test_data(self, test_data_config: Optional[Union[DictConfig, Dict]]):
# preserve config
self._update_dataset_config(dataset_name='test', config=test_data_config)

-        self._test_dl = self._setup_dataloader_from_config(config=test_data_config)
+        self._test_dl = self._setup_dataloader_from_config(config=test_data_config, inference=True)

@property
def input_types(self) -> Optional[Dict[str, NeuralType]]:
@@ -450,6 +467,8 @@ def input_types(self) -> Optional[Dict[str, NeuralType]]:
"input_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
"processed_signal": NeuralType(('B', 'D', 'T'), SpectrogramType(), optional=True),
"processed_signal_length": NeuralType(tuple('B'), LengthsType(), optional=True),
"prompt": NeuralType(('B', 'T'), LengthsType(), optional=True),
"prompt_length": NeuralType(tuple('B'), LengthsType(), optional=True),
"sample_id": NeuralType(tuple('B'), LengthsType(), optional=True),
}

@@ -463,7 +482,7 @@ def output_types(self) -> Optional[Dict[str, NeuralType]]:

@typecheck()
def forward(
-        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None
+        self, input_signal=None, input_signal_length=None, processed_signal=None, processed_signal_length=None, prompt=None, prompt_length=None
):
"""
Forward pass of the model.
@@ -501,10 +520,14 @@ def forward(

if self.spec_augmentation is not None and self.training:
processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

processed_signal, processed_signal_length = self.apply_prompt(processed_signal, processed_signal_length, prompt, prompt_length, loc="preprocessor")
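        # apply_prompt (here and again after the encoder below) is assumed to be a PromptingMixin
        # hook: depending on configuration the prompt is injected either into the pre-encoder
        # features (loc="preprocessor") or into the encoder output (loc="encoder"), and the call
        # is presumably a no-op when prompt is None.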

encoder_output = self.encoder(audio_signal=processed_signal, length=processed_signal_length)
encoded = encoder_output[0]
encoded_len = encoder_output[1]

encoded, encoded_len = self.apply_prompt(encoded, encoded_len, prompt, prompt_length, loc="encoder")
log_probs = self.decoder(encoder_output=encoded)
greedy_predictions = log_probs.argmax(dim=-1, keepdim=False)

@@ -523,13 +546,21 @@ def training_step(self, batch, batch_nb):
if self.is_interctc_enabled():
AccessMixin.set_access_enabled(access_enabled=True, guid=self.model_guid)
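        # With a prompt format configured, the prompted Lhotse dataset returns six tensors
        # (see audio_to_text_lhotse_prompted.py); the collated transcript is [prompt | answer],
        # so the leading prompt columns are stripped before computing the CTC loss.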

-        signal, signal_len, transcript, transcript_len = batch
+        signal, signal_len, transcript, transcript_len = batch[:4]
+        if self.prompt_format:
+            prompt, prompt_length = batch[4:]
+            transcript = transcript[:, prompt.shape[1]:]
+            transcript_len = transcript_len - prompt_length
+        else:
+            prompt, prompt_length = None, None

if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
log_probs, encoded_len, predictions = self.forward(
processed_signal=signal, processed_signal_length=signal_len
)
else:
-            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)
+            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_length=prompt_length)


if hasattr(self, '_trainer') and self._trainer is not None:
log_every_n_steps = self._trainer.log_every_n_steps
@@ -594,13 +625,19 @@ def validation_pass(self, batch, batch_idx, dataloader_idx=0):
if self.is_interctc_enabled():
AccessMixin.set_access_enabled(access_enabled=True, guid=self.model_guid)

-        signal, signal_len, transcript, transcript_len = batch
+        signal, signal_len, transcript, transcript_len = batch[:4]
+        if self.prompt_format:
+            prompt, prompt_length = batch[4:]
+            transcript = transcript[:, prompt.shape[1]:]
+            transcript_len = transcript_len - prompt_length
+        else:
+            prompt, prompt_length = None, None
if isinstance(batch, DALIOutputs) and batch.has_processed_signal:
log_probs, encoded_len, predictions = self.forward(
processed_signal=signal, processed_signal_length=signal_len
)
else:
-            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len)
+            log_probs, encoded_len, predictions = self.forward(input_signal=signal, input_signal_length=signal_len, prompt=prompt, prompt_length=prompt_length)

loss_value = self.loss(
log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
@@ -666,7 +703,10 @@ def test_dataloader(self):
""" Transcription related methods """

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
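        # During inference with prompting enabled, batch[-2] and batch[-1] are the collated
        # prompts and prompt lengths produced by PromptedAudioToTextLhotseDataset.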
-        logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1])
+        if self.prompt_format:
+            logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1], prompt=batch[-2], prompt_length=batch[-1])
+        else:
+            logits, logits_len, greedy_predictions = self.forward(input_signal=batch[0], input_signal_length=batch[1])
output = dict(logits=logits, logits_len=logits_len)
del greedy_predictions
return output
@@ -751,7 +791,7 @@ def _setup_transcribe_dataloader(self, config: Dict) -> 'torch.utils.data.DataLo
if config.get("augmentor"):
dl_config['augmentor'] = config.get("augmentor")

-        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config))
+        temporary_datalayer = self._setup_dataloader_from_config(config=DictConfig(dl_config), inference=True)
return temporary_datalayer

@classmethod