From 361576af2c2f3a183fd890ebecf4f04e7e77bd33 Mon Sep 17 00:00:00 2001 From: "github-actions[bot]" <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 21 Sep 2022 14:19:06 -0700 Subject: [PATCH] fix label models restoring issue from wrighted cross entropy (#4968) (#4975) Signed-off-by: nithinraok Signed-off-by: nithinraok Signed-off-by: nithinraok Co-authored-by: Nithin Rao Signed-off-by: Hainan Xu --- nemo/collections/asr/models/label_models.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 7c6236efc640..5859dbfa40e0 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -89,9 +89,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.cal_labels_occurrence_train = False self.labels_occurrence = None + if 'num_classes' in cfg.decoder: + num_classes = cfg.decoder.num_classes + else: + num_classes = cfg.decoder.params.num_classes # to pass test + if 'loss' in cfg: if 'weight' in cfg.loss: if cfg.loss.weight == 'auto': + weight = num_classes * [1] self.cal_labels_occurrence_train = True else: weight = cfg.loss.weight @@ -142,17 +148,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): tmp_loss_cfg = OmegaConf.create( {"_target_": "nemo.collections.common.losses.cross_entropy.CrossEntropyLoss"} ) + self.loss = instantiate(tmp_loss_cfg) self.eval_loss = instantiate(tmp_loss_cfg) - self.task = None self._accuracy = TopKClassificationAccuracy(top_k=[1]) - if 'num_classes' in cfg.decoder: - num_classes = cfg.decoder.num_classes - else: - num_classes = cfg.decoder.params.num_classes # to pass test - self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor) self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder)