From 57c0c5c076fa97cfc03e0411d1117fb7601350a1 Mon Sep 17 00:00:00 2001 From: nithinraok Date: Tue, 20 Sep 2022 18:00:20 -0700 Subject: [PATCH] fix label models restoring issue from weighted cross entropy Signed-off-by: nithinraok --- nemo/collections/asr/models/label_models.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/nemo/collections/asr/models/label_models.py b/nemo/collections/asr/models/label_models.py index 7c6236efc640..5859dbfa40e0 100644 --- a/nemo/collections/asr/models/label_models.py +++ b/nemo/collections/asr/models/label_models.py @@ -89,9 +89,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): self.cal_labels_occurrence_train = False self.labels_occurrence = None + if 'num_classes' in cfg.decoder: + num_classes = cfg.decoder.num_classes + else: + num_classes = cfg.decoder.params.num_classes # to pass test + if 'loss' in cfg: if 'weight' in cfg.loss: if cfg.loss.weight == 'auto': + weight = num_classes * [1] self.cal_labels_occurrence_train = True else: weight = cfg.loss.weight @@ -142,17 +148,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): tmp_loss_cfg = OmegaConf.create( {"_target_": "nemo.collections.common.losses.cross_entropy.CrossEntropyLoss"} ) + self.loss = instantiate(tmp_loss_cfg) self.eval_loss = instantiate(tmp_loss_cfg) - self.task = None self._accuracy = TopKClassificationAccuracy(top_k=[1]) - if 'num_classes' in cfg.decoder: - num_classes = cfg.decoder.num_classes - else: - num_classes = cfg.decoder.params.num_classes # to pass test - self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor) self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder) self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder)