ctc loss integration
Signed-off-by: andrusenkoau <andrusenkoau@gmail.com>
andrusenkoau committed Jul 30, 2024
1 parent 43a9c03 commit ce9bd0d
Showing 3 changed files with 141 additions and 65 deletions.
90 changes: 85 additions & 5 deletions nemo/collections/multimodal/speech_llm/models/modular_models.py
@@ -58,6 +58,10 @@
from nemo.utils.apex_utils import _reconfigure_microbatch_calculator
from nemo.utils.model_utils import inject_model_parallel_rank

from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCBPEDecoding, CTCBPEDecodingConfig
from nemo.collections.asr.metrics.wer import WER

try:
from megatron.core import InferenceParams, parallel_state, tensor_parallel
from megatron.core.models.gpt import GPTModel as MCoreGPTModel
@@ -78,6 +82,40 @@
class ModularAudioGPTModel(SpeechLLMAdapterMixin, MegatronGPTSFTModel):
"""Modularized speech GPT model."""

def setup_ctc_head(self, cfg):
if 'aux_ctc' not in cfg:
raise ValueError(
"The config need to have a section for the CTC decoder named as aux_ctc for Hybrid models."
)
# self.cfg.aux_ctc.decoder.vocabulary = ListConfig(self.tokenizer.vocab)
# the error arises from token 5303 ("${") in the tokenizer
# use a dummy vocab for now (temporary fix)
self.cfg.aux_ctc.decoder.vocabulary = [1] * len(self.tokenizer.vocab)
self.cfg.aux_ctc.decoder.num_classes = len(self.tokenizer.vocab)

self.ctc_decoder = self.from_config_dict(self.cfg.aux_ctc.decoder)
self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.1)

self.ctc_loss = CTCLoss(
num_classes=self.ctc_decoder.num_classes_with_blank - 1,
zero_infinity=True,
reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"),
)

ctc_decoding_cfg = self.cfg.aux_ctc.get('decoding', None)
if ctc_decoding_cfg is None:
ctc_decoding_cfg = OmegaConf.structured(CTCBPEDecodingConfig)
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = ctc_decoding_cfg

self.ctc_decoding = CTCBPEDecoding(self.cfg.aux_ctc.decoding, tokenizer=self.tokenizer)
self.ctc_wer = WER(
decoding=self.ctc_decoding,
use_cer=self.cfg.aux_ctc.get('use_cer', False),
dist_sync_on_step=True,
log_prediction=self.cfg.get("log_prediction", False),
)

def setup_perception_modules(self, cfg):
if 'target' in cfg.perception:
imported_cls = model_utils.import_class_by_path(cfg.perception.target)
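
For reference, a minimal sketch of the aux_ctc config section that setup_ctc_head consumes. The field names are taken from the code above; the decoder _target_ class and the feat_in value are assumptions, not part of this commit:

from omegaconf import OmegaConf

# Hypothetical aux_ctc section; setup_ctc_head overwrites `vocabulary` and
# `num_classes` from the tokenizer, so their initial values are placeholders.
aux_ctc = OmegaConf.create(
    {
        "ctc_loss_weight": 0.1,         # read via cfg.aux_ctc.get("ctc_loss_weight", 0.1)
        "ctc_reduction": "mean_batch",  # read via cfg.aux_ctc.get("ctc_reduction", "mean_batch")
        "use_cer": False,
        "decoder": {
            "_target_": "nemo.collections.asr.modules.ConvASRDecoder",  # assumed decoder class
            "feat_in": 1024,    # must match the audio encoder output dim (model-dependent)
            "num_classes": -1,  # replaced with len(tokenizer.vocab) in setup_ctc_head
            "vocabulary": [],   # replaced with the dummy vocab in setup_ctc_head
        },
    }
)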
@@ -96,6 +134,10 @@ def __init__(self, cfg: DictConfig, trainer: Trainer):
self.enforce_divisible_batch = False
self.setup_perception_modules(cfg)

### CTC head start:
self.setup_ctc_head(cfg)
### CTC head end.

# print out params in more details
self.summarize(max_depth=2)

@@ -328,7 +370,7 @@ def prepare_llm_input(self, audio_batch):
context_start_idx = audio_batch.get("context_start_idx", None)

# [b, t, c]
encoded, encoded_len = self.perception(
encoded, encoded_len, audio_encoder_outputs = self.perception(
input_signal=input_signal,
input_signal_length=input_signal_length,
processed_signal=None,
Expand All @@ -354,7 +396,7 @@ def prepare_llm_input(self, audio_batch):
loss_mask, input_length, encoded_len, encoder_max_length, pad_token=0
)

return encoder_input, attention_mask, labels, loss_mask, encoder_length
return encoder_input, attention_mask, labels, loss_mask, encoder_length, audio_encoder_outputs
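
prepare_llm_input now returns a sixth value, threading the pre-adapter encoder outputs through to the CTC head. A usage sketch; model and audio_batch are placeholder names:

# audio_encoder_outputs is the tuple (pre-adapter encoded feats, their lengths)
# packed in the perception module (see perception_modules.py below).
(
    encoder_input,
    attention_mask,
    labels,
    loss_mask,
    encoder_length,
    audio_encoder_outputs,
) = model.prepare_llm_input(audio_batch)
pre_adapter_feats, pre_adapter_lens = audio_encoder_outputs  # consumed by the CTC head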

def forward(
self,
@@ -373,7 +415,7 @@ def forward(
rank_zero_only=False,
)

encoder_input, attention_mask, labels, loss_mask, _ = self.prepare_llm_input(audio_batch)
encoder_input, attention_mask, labels, loss_mask, _, audio_encoder_outputs = self.prepare_llm_input(audio_batch)
if self.mcore_gpt:
output = self.model(
input_ids=None,
Expand All @@ -392,7 +434,7 @@ def forward(
checkpoint_activations_all_layers=checkpoint_activations_all_layers,
)

return output, loss_mask
return output, loss_mask, audio_encoder_outputs

def get_forward_output_only_func(self):
def fwd_output_only_func(dataloader_iter, model):
@@ -432,6 +474,7 @@ def fwd_output_only_func(dataloader_iter, model):
):
attention_mask = None

# model returns the standard output and audio_encoder_outputs
output_tensor = model(
input_ids=None,
position_ids=None,
@@ -478,7 +521,10 @@ def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_
if not self.mcore_gpt:
batch['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers

output_tensor, loss_mask = self.forward(
# output_tensor, loss_mask, ctc_head_output = self.forward(
# batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers
# )
output_tensor, loss_mask, audio_encoder_outputs = self.forward(
batch, checkpoint_activations_all_layers=checkpoint_activations_all_layers
)
batch['loss_mask'] = loss_mask
@@ -487,11 +533,44 @@ def loss_func(output_tensor):
# Loss for a micro-batch (ub)
loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor)
cp_size = self.cfg.get('context_parallel_size', 1)

# compute the auxiliary CTC loss on the pre-adapter encoder outputs
ctc_log_probs = self.ctc_decoder(encoder_output=audio_encoder_outputs[0])
ctc_input_lengths = audio_encoder_outputs[1]

ctc_loss = self.ctc_loss(
log_probs=ctc_log_probs,
targets=batch["ctc_tokens"],
input_lengths=ctc_input_lengths,
target_lengths=batch["ctc_tokens_length"],
)
self.log("aux_ctc_loss", ctc_loss, batch_size=1)
self.log("loss_for_ub", loss_for_ub, batch_size=1)

loss_for_ub = (1 - self.ctc_loss_weight) * loss_for_ub + self.ctc_loss_weight * ctc_loss

# logging.warning("*************"*10)
# # logging.warning(f"batch: {batch}")
# # logging.warning(f"ctc_head_output[0].shape: {ctc_head_output[0].shape}")
# logging.warning(f"batch['ctc_tokens']: {batch['ctc_tokens']}")
# logging.warning(f"batch['ctc_tokens'][0]: {self.tokenizer.ids_to_tokens(batch['ctc_tokens'][0].tolist())}")
# logging.warning(f"CTC Loss: {ctc_loss}")
# logging.warning(f"loss_for_ub: {loss_for_ub}")
# # raise NotImplementedError("CTC loss implementation in progress...")

if self.cfg.data.get(
"return_output_tensors", False
): # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare)
loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub
# logging.warning("*************"*10)
# logging.warning(f"loss_for_ub loss: {loss_for_ub}")
# raise NotImplementedError("CTC loss implementation in progress...")
# loss_for_ub = (1 - self.perception.ctc_loss_weight) * loss_for_ub + self.perception.ctc_loss_weight * ctc_loss
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])

pos_cs = average_losses_across_data_parallel_group([pos_cs])
neg_cs = average_losses_across_data_parallel_group([neg_cs])
diff_cs = average_losses_across_data_parallel_group([diff_cs])
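
The weighted combination above is a convex interpolation between the LLM loss and the auxiliary CTC loss. A self-contained sketch with made-up loss values:

import torch

ctc_loss_weight = 0.1         # self.ctc_loss_weight, default 0.1
llm_loss = torch.tensor(2.3)  # loss_for_ub from the LLM head (made-up value)
ctc_loss = torch.tensor(5.1)  # auxiliary CTC loss (made-up value)

# With weight w the total is (1 - w) * llm + w * ctc, so w = 0 recovers
# the original LLM-only objective.
total = (1 - ctc_loss_weight) * llm_loss + ctc_loss_weight * ctc_loss
print(total)  # tensor(2.5800)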
@@ -527,6 +606,7 @@ def loss_func(output_tensor):
return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu}
else:
reduced_loss = average_losses_across_data_parallel_group([loss_for_ub])
# logging.warning(f"reduced_loss: {reduced_loss}")
return loss_for_ub * cp_size, {'avg': reduced_loss}

return output_tensor, loss_func
3 changes: 2 additions & 1 deletion (file path not captured)
@@ -44,7 +44,7 @@ def init_batch(
"""initialize the batch data before the inference steps."""
# Move to GPU.

audio_feats, audio_feat_lens = self.model.perception(
audio_feats, audio_feat_lens, _ = self.model.perception(
input_signal=audio_signal,
input_signal_length=audio_length,
processed_signal=None,
@@ -194,6 +194,7 @@ def init_batch(
context_tokens,
_,
(speech_encoded, speech_encoded_len, extra_outputs),
_,
) = self.model.prepare_llm_input(batch)
self.position_ids = build_position_ids(encoder_input[:, :, 0].transpose(0, 1))
self.extra_outputs = extra_outputs
113 changes: 54 additions & 59 deletions nemo/collections/multimodal/speech_llm/modules/perception_modules.py
@@ -30,9 +30,9 @@
from nemo.utils.decorators import experimental

from omegaconf import DictConfig, OmegaConf, open_dict
from nemo.collections.asr.losses.ctc import CTCLoss
from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
from nemo.collections.asr.metrics.wer import WER
# from nemo.collections.asr.losses.ctc import CTCLoss
# from nemo.collections.asr.parts.submodules.ctc_decoding import CTCDecoding, CTCDecodingConfig
# from nemo.collections.asr.metrics.wer import WER
from nemo.utils import logging

__all__ = ["AudioPerceptionModule", "MultiAudioPerceptionModule"]
@@ -73,60 +73,60 @@ def output_types(self):
}
)

def __init__(self, cfg: DictConfig, tokenizer):
def __init__(self, cfg: DictConfig):
super().__init__()
# Initialize components
self.cfg = cfg
self.preprocessor = self.from_config_dict(cfg.preprocessor)
self.encoder = self.from_config_dict(cfg.encoder)

### CTC head start:
if 'aux_ctc' not in self.cfg:
raise ValueError(
"The config need to have a section for the CTC decoder named as aux_ctc for Hybrid models."
)
# with open_dict(self.cfg.aux_ctc):
# if "feat_in" not in self.cfg.aux_ctc.decoder or (
# not self.cfg.aux_ctc.decoder.feat_in and hasattr(self.encoder, '_feat_out')
# ):
# self.cfg.aux_ctc.decoder.feat_in = self.encoder._feat_out
# if "feat_in" not in self.cfg.aux_ctc.decoder or not self.cfg.aux_ctc.decoder.feat_in:
# raise ValueError("param feat_in of the decoder's config is not set!")

# if self.cfg.aux_ctc.decoder.num_classes < 1 and self.cfg.aux_ctc.decoder.vocabulary is not None:
# logging.info(
# "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
# self.cfg.aux_ctc.decoder.num_classes, len(self.cfg.aux_ctc.decoder.vocabulary)
# )
# )
# self.cfg.aux_ctc.decoder["num_classes"] = len(self.cfg.aux_ctc.decoder.vocabulary)

self.cfg.aux_ctc.decoder.vocabulary = [1]*len(tokenizer.vocab)
self.cfg.aux_ctc.decoder.num_classes = len(tokenizer.vocab)

self.ctc_decoder = self.from_config_dict(self.cfg.aux_ctc.decoder)
self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.2)

self.ctc_loss = CTCLoss(
num_classes=self.ctc_decoder.num_classes_with_blank - 1,
zero_infinity=True,
reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"),
)

ctc_decoding_cfg = self.cfg.aux_ctc.get('decoding', None)
if ctc_decoding_cfg is None:
ctc_decoding_cfg = OmegaConf.structured(CTCDecodingConfig)
with open_dict(self.cfg.aux_ctc):
self.cfg.aux_ctc.decoding = ctc_decoding_cfg

self.ctc_decoding = CTCDecoding(self.cfg.aux_ctc.decoding, vocabulary=self.ctc_decoder.vocabulary)
self.ctc_wer = WER(
decoding=self.ctc_decoding,
use_cer=self.cfg.aux_ctc.get('use_cer', False),
dist_sync_on_step=True,
log_prediction=self.cfg.get("log_prediction", False),
)
### CTC head end.
# ### CTC head start:
# if 'aux_ctc' not in self.cfg:
# raise ValueError(
# "The config need to have a section for the CTC decoder named as aux_ctc for Hybrid models."
# )
# # with open_dict(self.cfg.aux_ctc):
# # if "feat_in" not in self.cfg.aux_ctc.decoder or (
# # not self.cfg.aux_ctc.decoder.feat_in and hasattr(self.encoder, '_feat_out')
# # ):
# # self.cfg.aux_ctc.decoder.feat_in = self.encoder._feat_out
# # if "feat_in" not in self.cfg.aux_ctc.decoder or not self.cfg.aux_ctc.decoder.feat_in:
# # raise ValueError("param feat_in of the decoder's config is not set!")

# # if self.cfg.aux_ctc.decoder.num_classes < 1 and self.cfg.aux_ctc.decoder.vocabulary is not None:
# # logging.info(
# # "\nReplacing placeholder number of classes ({}) with actual number of classes - {}".format(
# # self.cfg.aux_ctc.decoder.num_classes, len(self.cfg.aux_ctc.decoder.vocabulary)
# # )
# # )
# # self.cfg.aux_ctc.decoder["num_classes"] = len(self.cfg.aux_ctc.decoder.vocabulary)

# self.cfg.aux_ctc.decoder.vocabulary = [1]*len(tokenizer.vocab)
# self.cfg.aux_ctc.decoder.num_classes = len(tokenizer.vocab)

# self.ctc_decoder = self.from_config_dict(self.cfg.aux_ctc.decoder)
# self.ctc_loss_weight = self.cfg.aux_ctc.get("ctc_loss_weight", 0.2)

# self.ctc_loss = CTCLoss(
# num_classes=self.ctc_decoder.num_classes_with_blank - 1,
# zero_infinity=True,
# reduction=self.cfg.aux_ctc.get("ctc_reduction", "mean_batch"),
# )

# ctc_decoding_cfg = self.cfg.aux_ctc.get('decoding', None)
# if ctc_decoding_cfg is None:
# ctc_decoding_cfg = OmegaConf.structured(CTCDecodingConfig)
# with open_dict(self.cfg.aux_ctc):
# self.cfg.aux_ctc.decoding = ctc_decoding_cfg

# self.ctc_decoding = CTCDecoding(self.cfg.aux_ctc.decoding, vocabulary=self.ctc_decoder.vocabulary)
# self.ctc_wer = WER(
# decoding=self.ctc_decoding,
# use_cer=self.cfg.aux_ctc.get('use_cer', False),
# dist_sync_on_step=True,
# log_prediction=self.cfg.get("log_prediction", False),
# )
# ### CTC head end.

if cfg.get("use_multi_layer_feat", False) and cfg.get("multi_layer_feat", None):
if "_target_" in cfg.multi_layer_feat.aggregator:
@@ -187,19 +187,14 @@ def forward(
processed_signal = self.spec_augmentation(input_spec=processed_signal, length=processed_signal_length)

encoded, encoded_len = self.encoder(audio_signal=processed_signal, length=processed_signal_length)

### CTC head start:
if self.ctc_loss_weight > 0:
ctc_log_probs = self.ctc_decoder(encoder_output=encoded)
ctc_input_lengths = encoded_len
### CTC head end.

audio_encoder_outputs = (encoded, encoded_len)
encoded, encoded_len = self.modality_adapter(audio_signal=encoded, length=encoded_len)

# b, c, t -> b, t, c
encoded = self.proj(encoded.transpose(1, 2))

return encoded, encoded_len, (ctc_log_probs, ctc_input_lengths)
# return encoded, encoded_len, audio_encoder_outputs
return encoded, encoded_len, audio_encoder_outputs


class MultiFeatureAggregator(nn.Module):
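With this change the perception module no longer owns the CTC head: forward always returns three values, and the CTC decoder is applied to the third one by the model. A hedged usage sketch; perception, audio_signal, and audio_length are placeholder names:

# forward returns (projected feats, lengths, pre-adapter encoder outputs);
# the third value feeds self.ctc_decoder in ModularAudioGPTModel.
encoded, encoded_len, audio_encoder_outputs = perception(
    input_signal=audio_signal,
    input_signal_length=audio_length,
    processed_signal=None,
    processed_signal_length=None,
)
pre_adapter_feats, pre_adapter_lens = audio_encoder_outputs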
