diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 7c6236efc640..5859dbfa40e0 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -89,9 +89,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.cal_labels_occurrence_train = False self.labels_occurrence = None + if 'num_classes' in cfg.decoder: + num_classes = cfg.decoder.num_classes + else: + num_classes = cfg.decoder.params.num_classes # to pass test + if 'loss' in cfg: if 'weight' in cfg.loss: if cfg.loss.weight == 'auto': + weight = num_classes * [1] self.cal_labels_occurrence_train = True else: weight = cfg.loss.weight @@ -142,17 +148,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): tmp_loss_cfg = OmegaConf.create( {"_target_": "nemo.collections.common.losses.cross_entropy.CrossEntropyLoss"} ) + self.loss = instantiate(tmp_loss_cfg) self.eval_loss = instantiate(tmp_loss_cfg) - self.task = None self._accuracy = TopKClassificationAccuracy(top_k=[1]) - if 'num_classes' in cfg.decoder: - num_classes = cfg.decoder.num_classes - else: - num_classes = cfg.decoder.params.num_classes # to pass test - self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor) self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder)