
Commit 361576a

fix label models restoring issue from weighted cross entropy (NVIDIA#4968) (NVIDIA#4975)

Signed-off-by: nithinraok <nithinrao.koluguri@gmail.com>
Co-authored-by: Nithin Rao <nithinrao.koluguri@gmail.com>
Signed-off-by: Hainan Xu <hainanx@nvidia.com>
2 people authored and Hainan Xu committed Nov 29, 2022
1 parent 3f041ba commit 361576a
Showing 1 changed file with 7 additions and 6 deletions.
13 changes: 7 additions & 6 deletions nemo/collections/asr/models/label_models.py
@@ -89,9 +89,15 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
         self.cal_labels_occurrence_train = False
         self.labels_occurrence = None
 
+        if 'num_classes' in cfg.decoder:
+            num_classes = cfg.decoder.num_classes
+        else:
+            num_classes = cfg.decoder.params.num_classes  # to pass test
+
         if 'loss' in cfg:
             if 'weight' in cfg.loss:
                 if cfg.loss.weight == 'auto':
+                    weight = num_classes * [1]
                     self.cal_labels_occurrence_train = True
                 else:
                     weight = cfg.loss.weight
@@ -142,17 +148,12 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
             tmp_loss_cfg = OmegaConf.create(
                 {"_target_": "nemo.collections.common.losses.cross_entropy.CrossEntropyLoss"}
             )
+
         self.loss = instantiate(tmp_loss_cfg)
         self.eval_loss = instantiate(tmp_loss_cfg)
 
         self.task = None
         self._accuracy = TopKClassificationAccuracy(top_k=[1])
-
-        if 'num_classes' in cfg.decoder:
-            num_classes = cfg.decoder.num_classes
-        else:
-            num_classes = cfg.decoder.params.num_classes  # to pass test
-
         self.preprocessor = EncDecSpeakerLabelModel.from_config_dict(cfg.preprocessor)
         self.encoder = EncDecSpeakerLabelModel.from_config_dict(cfg.encoder)
         self.decoder = EncDecSpeakerLabelModel.from_config_dict(cfg.decoder)
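For context: the diff moves the num_classes lookup ahead of the loss setup so that the 'auto' branch can seed a uniform placeholder weight of length num_classes, keeping the restored loss shaped consistently with the decoder (the real class weights are presumably derived later from label occurrence counts, given that cal_labels_occurrence_train is set to True). A minimal sketch of the fixed logic, with a hypothetical config whose values are illustrative and not taken from the commit; only omegaconf is assumed:

    from omegaconf import OmegaConf

    # Hypothetical config; field names mirror the ones the diff reads.
    cfg = OmegaConf.create({"decoder": {"num_classes": 4}, "loss": {"weight": "auto"}})

    # Resolve num_classes first, as in the block the commit moves up.
    if 'num_classes' in cfg.decoder:
        num_classes = cfg.decoder.num_classes
    else:
        num_classes = cfg.decoder.params.num_classes

    # 'auto' now yields a uniform placeholder of the correct length,
    # rather than leaving the weight unshaped for checkpoint restore.
    if 'loss' in cfg and 'weight' in cfg.loss:
        weight = num_classes * [1] if cfg.loss.weight == 'auto' else cfg.loss.weight

    print(weight)  # [1, 1, 1, 1]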
