It's because Lightning first instantiates the LightningModule and only then loads the saved weights via load_from_checkpoint. Since you have HFModel.from_pretrained in __init__, the pretrained weights get downloaded and loaded on every restore, only to be overwritten by the checkpoint immediately afterwards. There is a way around this:

class HFLightningModule(LightningModule):
    def __init__(self, config, num_classes, model_name=None):
        super().__init__()
        # Keep model_name out of the saved hyperparameters so that
        # load_from_checkpoint falls back to the default (None) and
        # skips the pretrained download on restore.
        self.save_hyperparameters(ignore=['model_name'])
        if model_name is not None:
            # First run: initialize from the pretrained weights.
            self.model = HFModel.from_pretrained(model_name, ...)
        else:
            # Checkpoint restore: build the bare architecture only;
            # the saved weights are loaded right after __init__ returns.
            self.model = HFModel(config, num_classes)


# Fine-tune starting from the pretrained weights.
model = HFLightningModule(config, num_classes, model_name='bert-base-cased')
trainer.fit(model, ...)

# model_name is absent from the saved hparams, so __init__ takes the
# else branch and no weights are downloaded before the restore.
model = HFLightningModule.load_from_checkpoint(...)
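To see why the None default is enough: load_from_checkpoint is conceptually just the following (a simplified sketch of Lightning's behaviour, not its actual implementation):

import torch

# Roughly what LightningModule.load_from_checkpoint does under the hood:
checkpoint = torch.load('path/to/checkpoint.ckpt')
model = HFLightningModule(**checkpoint['hyper_parameters'])  # __init__ runs again here
model.load_state_dict(checkpoint['state_dict'])  # saved weights then replace whatever __init__ built

Since model_name was excluded from the saved hyperparameters, the reconstruction call takes the else branch and never touches from_pretrained.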

There might be a better solution, though.
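One such alternative, for what it's worth: rebuild the architecture from the config alone on restore, so only the small config file is fetched rather than the full weights. Below is a sketch assuming the transformers Auto classes and a sequence-classification head; the pretrained flag and argument names are my own, not from the snippet above.

from pytorch_lightning import LightningModule
from transformers import AutoConfig, AutoModelForSequenceClassification

class HFLightningModule(LightningModule):
    def __init__(self, model_name='bert-base-cased', num_classes=2, pretrained=False):
        super().__init__()
        # pretrained is excluded so a checkpoint restore (where it defaults
        # to False) rebuilds the architecture without the weight download.
        self.save_hyperparameters(ignore=['pretrained'])
        if pretrained:
            # Training run: fetch and load the pretrained weights once.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_name, num_labels=num_classes
            )
        else:
            # Restore: fetch only the (tiny) config, build a randomly
            # initialized model, and let Lightning load the checkpoint
            # state_dict over it.
            config = AutoConfig.from_pretrained(model_name, num_labels=num_classes)
            self.model = AutoModelForSequenceClassification.from_config(config)

model = HFLightningModule('bert-base-cased', num_classes=5, pretrained=True)
trainer.fit(model, ...)

model = HFLightningModule.load_from_checkpoint(...)  # only the config is fetched

A nice side effect is that you no longer have to pass an HF config object through __init__; the model name string stored in the hyperparameters is enough to rebuild the architecture.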
