transfer learning for resnet50-res512-all #92

Open
dev-yue opened this issue Apr 4, 2022 · 3 comments

Comments

@dev-yue

dev-yue commented Apr 4, 2022

Great library! Would you provide the transfer learning code for resnet50-res512-all as well? Thank you so much!

@ieee8023
Member

ieee8023 commented Apr 4, 2022

It should be as simple as using this line:

model = xrv.models.ResNet(weights="resnet50-res512-all")

in this script: https://github.com/mlmed/torchxrayvision/blob/master/scripts/transfer_learning.ipynb

Also change the resizing to xrv.datasets.XRayResizer(512) so the images are 512x512.

@dev-yue
Author

dev-yue commented Apr 4, 2022 via email

@ieee8023
Copy link
Member

ieee8023 commented Apr 4, 2022

Oh, sorry, I responded too fast and didn't test the code. xrv.models.ResNet wraps an internal ResNet model, so the final fully connected layer is located at model.model.fc rather than model.fc.

import torch
import torchxrayvision as xrv

model = xrv.models.ResNet(weights="resnet50-res512-all")
model.op_threshs = None  # prevent pre-trained model calibration
model.model.fc = torch.nn.Linear(2048, 1)  # reinitialize classifier (ResNet-50 features are 2048-dimensional)

optimizer = torch.optim.Adam(model.model.fc.parameters())  # only train classifier
criterion = torch.nn.BCEWithLogitsLoss()  # expects raw logits, so no sigmoid on the output

I tested the above code and it seems to train correctly.


2 participants