Refactor CropModel to infer num_classes from checkpoint during loading #966

Open · naxatra2 wants to merge 1 commit into main
Conversation

naxatra2
Contributor

Fixes #960

This PR makes num_classes optional in the CropModel class and infers it from the checkpoint when loading via load_from_checkpoint. This improves usability, especially for users who do not know or remember the original number of classes at the time of inference.

Changes Made

  1. Made num_classes an optional argument in the CropModel constructor.
  2. In the constructor, if num_classes is not provided (None), model/metrics initialization is deferred until on_load_checkpoint; otherwise, they are initialized immediately.
  3. Added an on_load_checkpoint method that reads num_classes from the checkpoint if it is still None, then sets up the model and metrics (see the sketch after this list).
  4. Added tests to verify that:
    • We can load from a checkpoint without passing num_classes.
    • Users can still pass num_classes explicitly if they want.
    • Calling forward without a valid model raises a clear error.
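
A minimal sketch of the pattern described above, not the actual CropModel implementation: it assumes num_classes is recorded under the checkpoint's "hyper_parameters" key (which requires save_hyperparameters(), the addition requested in the review below), and _create_model and the Linear layer are placeholders.

```python
# Hypothetical sketch of the deferred-initialization pattern; internals are placeholders.
from typing import Optional

import torch.nn as nn
from pytorch_lightning import LightningModule


class CropModel(LightningModule):
    def __init__(self, num_classes: Optional[int] = None):
        super().__init__()
        self.save_hyperparameters()  # records num_classes in every checkpoint
        self.num_classes = num_classes
        self.model = None
        if num_classes is not None:
            self._create_model()  # initialize immediately when num_classes is known

    def _create_model(self):
        # Placeholder for the real model/metrics setup.
        self.model = nn.Linear(128, self.num_classes)

    def on_load_checkpoint(self, checkpoint):
        # Infer num_classes from the checkpoint when it was not passed explicitly.
        if self.num_classes is None:
            self.num_classes = checkpoint["hyper_parameters"]["num_classes"]
            self._create_model()

    def forward(self, x):
        if self.model is None:
            raise ValueError(
                "Model not initialized: pass num_classes or load from a checkpoint."
            )
        return self.model(x)
```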

@bw4sz
Collaborator

bw4sz commented Mar 12, 2025

This is important and should be merged, with two additions. First, we need to use

self.save_hyperparameters()

at the end of __init__, as documented here: https://lightning.ai/docs/pytorch/1.6.2/common/hyperparameters.html#lightningmodule-hyperparameters
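
For context, a minimal illustration of save_hyperparameters() (the Example class and its arguments are made up): once called in __init__, the constructor arguments are stored on self.hparams and serialized into checkpoints under "hyper_parameters".

```python
from pytorch_lightning import LightningModule


class Example(LightningModule):
    def __init__(self, num_classes: int = 2, lr: float = 1e-3):
        super().__init__()
        # Stores num_classes and lr on self.hparams and in every checkpoint,
        # so load_from_checkpoint can re-create the module without re-passing them.
        self.save_hyperparameters()


model = Example(num_classes=5)
assert model.hparams.num_classes == 5
```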

The second is the label dict. The crop model uses a torchvision dataset that loads classes from folders on disk, which could easily disagree with a label_dict that a user optionally passes. We should remove label_dict as an argument from __init__ and instead set self.label_dict = <> here.

I believe it's under self.train_ds.class_to_idx.

https://pytorch.org/vision/main/generated/torchvision.datasets.DatasetFolder.html#torchvision.datasets.DatasetFolder
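
Something along these lines, assuming self.train_ds is an ImageFolder-style dataset (the root path here is made up):

```python
from torchvision.datasets import ImageFolder

train_ds = ImageFolder(root="crops/train")     # hypothetical path
label_dict = train_ds.class_to_idx             # e.g. {"Bird": 0, "Mammal": 1}
numeric_to_label_dict = {v: k for k, v in label_dict.items()}
```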

Now, the question I have: if a user creates a model, saves the hyperparameters, loads a dataset, saves a checkpoint, and then reloads the model, is the label_dict still there?

bw4sz self-requested a review on March 12, 2025 at 17:49
@bw4sz
Collaborator

bw4sz commented Mar 12, 2025

Can we also fix a small bug related to the crop label?

This is poor logic; not sure what I was thinking. We should never have hidden conditional formatting. The CropModel should always return the same dtype, and it should insist on a label dict; this explains issue #705.

results["cropmodel_label"] = crop_model.numeric_to_label_dict[label]

There should not be an if statement here. CropModels should have label dicts, and they should return labels, not numeric indices.

The function also seems to expect batch_size 1, which is wrong; it should be

results["cropmodel_label"] = [crop_model.numeric_to_label_dict[x] for x in label]

so that it works for larger batch sizes. Let's confirm this with tests. @naxatra2, if this is too much I can take over; this PR is important.
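
A hypothetical test along those lines, with a stand-in numeric_to_label_dict rather than a real trained model:

```python
def test_cropmodel_label_handles_batches():
    numeric_to_label_dict = {0: "Bird", 1: "Mammal"}
    label = [0, 1, 0]  # a batch of predicted class indices (batch_size > 1)
    result = [numeric_to_label_dict[x] for x in label]
    assert result == ["Bird", "Mammal", "Bird"]
```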

@naxatra2
Contributor Author

If it is not urgent, can I try doing this, @bw4sz? It would be a great learning opportunity. I will update my progress here.

Successfully merging this pull request may close this issue: num_classes should not be a required argument for CropModel (#960).