This is a PyTorch implementation of the CheSS paper:
@article{Cho2023,
doi = {10.1007/s10278-023-00782-4},
url = {https://doi.org/10.1007/s10278-023-00782-4},
year = {2023},
month = jan,
publisher = {Springer Science and Business Media {LLC}},
author = {Kyungjin Cho and Ki Duk Kim and Yujin Nam and Jiheon Jeong and Jeeyoung Kim and Changyong Choi and Soyoung Lee and Jun Soo Lee and Seoyeon Woo and Gil-Sun Hong and Joon Beom Seo and Namkug Kim},
title = {{CheSS}: Chest X-Ray Pre-trained Model via Self-supervised Contrastive Learning},
journal = {Journal of Digital Imaging}
}
pip install -r requirements.txt
import os

import torch
import torch.nn as nn

# Linear evaluation with CheSS weights: load the MoCo query encoder into a
# ResNet-50, freeze every layer except the classifier head, and swap in a new
# head for the downstream task.
# NOTE(review): `resnet50` is presumably `torchvision.models.resnet50` — import it
# from wherever your project defines it.
model = resnet50(num_classes=1000)
pretrained_model = "CheSS pretrained model path"  # path to the downloaded CheSS checkpoint
num_class = 2  # placeholder: set to the number of classes in your downstream task

if pretrained_model is not None:
    if not os.path.isfile(pretrained_model):
        # Fail loudly: continuing would freeze a randomly initialised backbone
        # and silently train only the linear head on meaningless features.
        raise FileNotFoundError("no checkpoint found at '{}'".format(pretrained_model))
    print("=> loading checkpoint '{}'".format(pretrained_model))
    # torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(pretrained_model, map_location="cpu")

    # Rename MoCo pre-trained keys: keep only the query encoder (encoder_q) up to,
    # but not including, its fc embedding layer, and strip the
    # "module.encoder_q." prefix so the names match a plain ResNet-50.
    prefix = "module.encoder_q."
    state_dict = {
        k[len(prefix):]: v
        for k, v in checkpoint["state_dict"].items()
        if k.startswith(prefix) and not k.startswith(prefix + "fc")
    }
    msg = model.load_state_dict(state_dict, strict=False)
    # Only the classification head should be left unloaded. Use an explicit
    # check rather than `assert`, which is stripped when Python runs with -O.
    if set(msg.missing_keys) != {"fc.weight", "fc.bias"}:
        raise RuntimeError("unexpected missing keys: {}".format(msg.missing_keys))
    print("=> loaded pre-trained model '{}'".format(pretrained_model))

# Freeze all layers but the last fc.
for name, param in model.named_parameters():
    if name not in ("fc.weight", "fc.bias"):
        param.requires_grad = False

# Replace the 1000-way head with a freshly initialised (trainable) one for the
# downstream task; read the feature width from the existing head instead of
# hard-coding 2048.
model.fc = nn.Linear(model.fc.in_features, num_class)
Alternatively, you can download it with gdown from a Jupyter/Colab notebook (the leading `!` runs the command in a shell):
!gdown https://drive.google.com/uc?id=1C_Gis2qcZcA9X3l2NEHR1oS4Gn_bTxTe
Page: https://mi2rl.co
Email: kjcho@amc.seoul.kr