
Fix top_k for multiclass-f1score #2839

Merged (52 commits) Dec 21, 2024

Conversation

@rittik9 (Contributor) commented Nov 21, 2024

What does this PR do?

Fixes #1653

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues, there's a high chance it will not be merged.

Did you have fun?

Make sure you had fun coding 🙃


📚 Documentation preview 📚: https://torchmetrics--2839.org.readthedocs.build/en/2839/

@eneserdo commented Nov 22, 2024

Thanks for the effort. IMHO, you could have handled the "refining" directly on the preds tensor, using something like this:

preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
preds_top1 = preds_topk[:, 0]
preds = torch.where((target.view(-1, 1) == preds_topk).sum(dim=-1).bool(), target, preds_top1)

This is a more compact way of doing the same job. (Cloning and reshaping are omitted here.)
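For readers following along, the refinement above can be run end-to-end as a small self-contained sketch (the `preds`/`target` values here are made up for illustration; only the three refinement lines come from the suggestion):

```python
import torch

# Hypothetical data: 4 samples, 3 classes, top_k = 2 (values made up).
preds = torch.tensor([[0.6, 0.3, 0.1],   # top-2 classes: 0, 1
                      [0.2, 0.5, 0.3],   # top-2 classes: 1, 2
                      [0.1, 0.2, 0.7],   # top-2 classes: 2, 1
                      [0.3, 0.5, 0.2]])  # top-2 classes: 1, 0
target = torch.tensor([1, 0, 2, 2])
top_k = 2

# Indices of the top-k highest scores per sample, highest first.
preds_topk = torch.argsort(preds, dim=-1, descending=True)[:, :top_k]
preds_top1 = preds_topk[:, 0]
# If the target is among the top-k predictions, replace the prediction
# with the target (count it as a hit); otherwise fall back to top-1.
refined = torch.where((target.view(-1, 1) == preds_topk).sum(dim=-1).bool(),
                      target, preds_top1)
print(refined.tolist())  # [1, 1, 2, 1]
```

Samples 1 and 3 have their target inside the top-2, so they become hits; samples 2 and 4 fall back to the top-1 prediction.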

Also, these changes will break the current tests for all top_k-related classes/functions, e.g. for recall, accuracy, f1, and so on, so I think it is important to rewrite those tests. Additionally, for top-k accuracy you could take scikit-learn's top_k_accuracy_score as a reference.
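As a reference point, the semantics of top-k accuracy mentioned above can be re-implemented in a few lines of PyTorch (a sketch; the `top_k_accuracy` helper and the sample tensors are illustrative, not torchmetrics or scikit-learn API):

```python
import torch

def top_k_accuracy(preds: torch.Tensor, target: torch.Tensor, k: int) -> float:
    """Fraction of samples whose true class is among the k highest scores
    (the same semantics as scikit-learn's top_k_accuracy_score)."""
    topk_idx = preds.topk(k, dim=-1).indices             # (N, k) class indices
    hits = (topk_idx == target.view(-1, 1)).any(dim=-1)  # true class in top-k?
    return hits.float().mean().item()

# Hypothetical scores (made up): 3 of the 4 targets fall inside the top-2.
preds = torch.tensor([[0.6, 0.3, 0.1],
                      [0.2, 0.5, 0.3],
                      [0.1, 0.2, 0.7],
                      [0.3, 0.5, 0.2]])
target = torch.tensor([0, 1, 2, 2])
print(top_k_accuracy(preds, target, k=2))  # 0.75
```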

@rittik9 (Contributor, Author) commented Nov 22, 2024

Thanks for your suggestions. I've noticed that some of the tests have failed, and I'm working on them. I am also comparing the results with other libraries' implementations. I'll keep posting updates here.

@rittik9 rittik9 force-pushed the rittik/multiclassf1_topk branch from fe3ac74 to 6d40bb4 Compare November 22, 2024 22:33
@rittik9 rittik9 force-pushed the rittik/multiclassf1_topk branch from 4574d2f to a949834 Compare November 22, 2024 23:22
@rittik9 rittik9 marked this pull request as ready for review November 23, 2024 13:26
@mergify mergify bot removed the has conflicts label Dec 11, 2024
CHANGELOG.md (resolved)
CHANGELOG.md (outdated, resolved)
tests/unittests/classification/test_stat_scores.py (outdated, resolved)
tests/unittests/classification/test_stat_scores.py (outdated, resolved)
tests/unittests/classification/test_specificity.py (outdated, resolved)
@lantiga (Contributor) commented Dec 21, 2024

Thanks @rittik9!

@mergify mergify bot removed the ready label Dec 21, 2024
@mergify mergify bot added the ready label Dec 21, 2024
@Borda Borda added the bug / fix Something isn't working label Dec 21, 2024
tests/unittests/classification/test_f_beta.py (outdated, resolved)
tests/unittests/classification/test_precision_recall.py (outdated, resolved)
tests/unittests/classification/test_specificity.py (outdated, resolved)
tests/unittests/classification/test_stat_scores.py (outdated, resolved)
tests/unittests/classification/test_stat_scores.py (outdated, resolved)
@Borda Borda merged commit 4d9c843 into Lightning-AI:master Dec 21, 2024
49 of 58 checks passed

Successfully merging this pull request may close these issues.

top_k for multiclassf1score is not working correctly
4 participants