Skip to content

Commit

Permalink
Fix and refactor label models (NVIDIA#4913)
Browse files Browse the repository at this point in the history
* fix testing after training hanging issue

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix cal label occurence

Signed-off-by: fayejf <fayejf07@gmail.com>

* refactor loss in label_models

Signed-off-by: fayejf <fayejf07@gmail.com>

* change langid yaml for refactor

Signed-off-by: fayejf <fayejf07@gmail.com>

* style fix

Signed-off-by: fayejf <fayejf07@gmail.com>

* change speaker yaml files for refactor

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix hang issue for speaker script

Signed-off-by: fayejf <fayejf07@gmail.com>

* reflect nithin's comment and update

Signed-off-by: fayejf <fayejf07@gmail.com>

* update lang id yaml

Signed-off-by: fayejf <fayejf07@gmail.com>

* refactor loss instantiation

Signed-off-by: fayejf <fayejf07@gmail.com>

* add loss to EncDecSpeakerLabelModel test

Signed-off-by: fayejf <fayejf07@gmail.com>

* omegaconf new varible

Signed-off-by: fayejf <fayejf07@gmail.com>

* pop

Signed-off-by: fayejf <fayejf07@gmail.com>

* revert loss in test

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix decoder angular

Signed-off-by: fayejf <fayejf07@gmail.com>

* fix lgtm

Signed-off-by: fayejf <fayejf07@gmail.com>

* remove cls cfg loss check

Signed-off-by: fayejf <fayejf07@gmail.com>

* reflect comment

Signed-off-by: fayejf <fayejf07@gmail.com>

Signed-off-by: fayejf <fayejf07@gmail.com>
Signed-off-by: Matvei Novikov <mattyson.so@gmail.com>
  • Loading branch information
fayejf authored and jubick1337 committed Oct 3, 2022
1 parent 45069e5 commit 9ffc024
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/label_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -502,4 +502,4 @@ def get_batch_embeddings(speaker_model, manifest_filepath, batch_size=32, sample

all_logits, true_labels, all_embs = np.asarray(all_logits), np.asarray(all_labels), np.asarray(all_embs)

return all_embs, all_logits, true_labels, dataset.id2label
return all_embs, all_logits, true_labels, dataset.id2label

0 comments on commit 9ffc024

Please sign in to comment.