-
Notifications
You must be signed in to change notification settings - Fork 381
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Trainers: support binary, multiclass, and multilabel tasks #2219
base: main
Are you sure you want to change the base?
Trainers: support binary, multiclass, and multilabel tasks #2219
Conversation
num_classes: int = 1000, | ||
task: str = 'multiclass', | ||
num_classes: int | None = None, | ||
num_labels: int | None = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The defaults here match torchmetrics. If task='multiclass'
, only num_classes
is used. If task='multilabel'
, only num_labels
is used. If task='binary'
, both are ignored. Honestly, we could have a single num_classes
if we want and simply use it for both.
@@ -266,147 +262,3 @@ def predict_step( | |||
x = batch['image'] | |||
y_hat: Tensor = self(x).softmax(dim=-1) | |||
return y_hat | |||
|
|||
|
|||
class MultiLabelClassificationTask(ClassificationTask): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The right thing to do would be to deprecate this first and remove it in 0.7.0. Not sure how widely used it is. Deprecation is kind of annoying because you need to change all tests to acknowledge the warning message.
I think we first need to add a multilabel semantic segmentation dataset to properly test this. |
Alternatively, skip multilabel semantic segmentation and only support multilabel classification. |
Instead of having separate trainers for binary, multiclass, and multilabel, let's create a single trainer that can handle all 3.
This applies to both Classification and Semantic Segmentation but not to our other trainers.
Closes #2205 @robmarkcole
Closes #245 @calebrob6