Loading from checkpoints re-downloads pre-trained BERT model #9236
-
I am defining a simple multi-class BERT classification model and training it with pytorch-lightning. The code is in https://colab.research.google.com/drive/1os9mz7w7gmLBL_ZDvZ9K1saz9UA3rmD7?usp=sharing under `class BertForMulticlassSequenceClassification(BertPreTrainedModel)`. The issue is that after training, when I load the classifier with `model = ClassTaggerModel.load_from_checkpoint(checkpoint_file)`, the pre-trained BERT model is downloaded again.
The reason is probably something in how the model is constructed, but I am not sure how to avoid the re-download.
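The notebook code isn't reproduced in the thread, but judging from the reply below, the pattern that causes the re-download is a `from_pretrained` call inside the LightningModule's `__init__`. A minimal sketch of that pattern (the class names, backbone, and `bert-base-cased` here are illustrative assumptions, not the actual notebook code):

```python
from pytorch_lightning import LightningModule
from transformers import BertForSequenceClassification


class ClassTaggerModel(LightningModule):
    def __init__(self, num_classes: int):
        super().__init__()
        # from_pretrained() runs inside __init__, so it runs again whenever the
        # module is re-instantiated, including by load_from_checkpoint(); that
        # is what triggers the repeated download of the BERT weights.
        self.model = BertForSequenceClassification.from_pretrained(
            "bert-base-cased", num_labels=num_classes
        )
```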
-
It's because Lightning instantiates the LightningModule and then loads the weights with `load_from_checkpoint`; since you call `HFModel.from_pretrained` in `__init__`, the pre-trained weights are downloaded every time the module is constructed. There is a way around this:

```python
class HFLightningModule(LightningModule):
    def __init__(self, ..., model_name=None):
        super().__init__()
        if model_name is not None:
            # training: start from the pre-trained weights
            self.model = HFModel.from_pretrained(model_name, ...)
        else:
            # restoring from a checkpoint: only build the architecture;
            # load_from_checkpoint fills in the trained weights afterwards
            self.model = HFModel(config, num_classes)


model = HFLightningModule(..., model_name='bert-base-cased')
trainer.fit(model, ...)
model = HFLightningModule.load_from_checkpoint(...)
```

Although there might be a better solution.
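For a self-contained version of that idea, here is a minimal sketch (my own, not code from the thread) that assumes `transformers`' `BertConfig`/`BertForSequenceClassification` as the backbone. It deviates slightly from the snippet above: the model name stays in the saved hyperparameters so the matching architecture can always be rebuilt, and a `pretrained` flag decides whether the heavy weight download happens; only the small config file is fetched in both cases.

```python
from pytorch_lightning import LightningModule
from transformers import BertConfig, BertForSequenceClassification


class HFLightningModule(LightningModule):
    def __init__(self, model_name: str, num_classes: int, pretrained: bool = True):
        super().__init__()
        self.save_hyperparameters()  # stored in the checkpoint
        # Fetches only the config (a few KB), never the model weights.
        config = BertConfig.from_pretrained(model_name, num_labels=num_classes)
        if pretrained:
            # Fresh training run: load the pre-trained BERT weights once.
            self.model = BertForSequenceClassification.from_pretrained(
                model_name, config=config
            )
        else:
            # Restoring from a checkpoint: build a randomly initialized model
            # with the same architecture; load_from_checkpoint then overwrites
            # its weights with the ones stored in the checkpoint.
            self.model = BertForSequenceClassification(config)

    def forward(self, **inputs):
        return self.model(**inputs)


# Training: downloads bert-base-cased the first time, then uses the cache.
model = HFLightningModule("bert-base-cased", num_classes=5, pretrained=True)
# trainer.fit(model, datamodule)  # trainer/datamodule set up elsewhere

# Restoring: pretrained=False overrides the saved hyperparameter, so __init__
# skips from_pretrained and only the checkpoint weights are loaded.
model = HFLightningModule.load_from_checkpoint(
    "path/to/checkpoint.ckpt", pretrained=False  # placeholder path
)
```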