Specify map location when loading model (#272)
* specify map location when loading model

* more map_location

* use self.device instead

* cleanup

* specify cpu

* put back hparams

* remove comment

* remove newline

* simplifications

* cleanup
ejm714 authored May 12, 2023
1 parent 1ebcd41 commit ee693e7
Showing 5 changed files with 13 additions and 14 deletions.
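
Why the change matters, illustrated with a minimal standalone sketch (not part of this diff; the file name and tensor values below are invented): torch restores each saved tensor onto the device it was saved from, so a checkpoint written on a GPU machine can fail to deserialize on a CPU-only machine unless a map_location is supplied.

import torch

# Hypothetical example, not from the zamba codebase: save a tiny checkpoint and
# reload it with map_location="cpu". If the state dict had been saved from CUDA
# tensors, loading it without map_location on a CPU-only machine would raise an
# error about the CUDA device being unavailable; pinning to "cpu" avoids that.
state = {"state_dict": {"weight": torch.ones(2, 2)}}
torch.save(state, "tiny_checkpoint.pt")

restored = torch.load("tiny_checkpoint.pt", map_location="cpu")
print(restored["state_dict"]["weight"].device)  # cpu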
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -45,7 +45,7 @@ def __init__(
             backbone = torch.nn.Linear(num_frames, num_hidden)
             torch.nn.init.ones_(backbone.weight)
         else:
-            backbone = self.load_from_checkpoint(finetune_from).backbone
+            backbone = self.from_disk(finetune_from).backbone

         for param in backbone.parameters():
             param.requires_grad = False
7 changes: 5 additions & 2 deletions zamba/models/efficientnet_models.py
@@ -17,15 +17,18 @@ class TimeDistributedEfficientNet(ZambaVideoClassificationLightningModule):
     )

     def __init__(
-        self, num_frames=16, finetune_from: Optional[Union[os.PathLike, str]] = None, **kwargs
+        self,
+        num_frames=16,
+        finetune_from: Optional[Union[os.PathLike, str]] = None,
+        **kwargs,
     ):
         super().__init__(**kwargs)

         if finetune_from is None:
             efficientnet = timm.create_model("efficientnetv2_rw_m", pretrained=True)
             efficientnet.classifier = nn.Identity()
         else:
-            efficientnet = self.load_from_checkpoint(finetune_from).base.module
+            efficientnet = self.from_disk(finetune_from).base.module

         # freeze base layers
         for param in efficientnet.parameters():
11 changes: 3 additions & 8 deletions zamba/models/model_manager.py
@@ -62,7 +62,7 @@ def instantiate_model(
             Only used if labels is not None.
         model_name (ModelEnum, optional): Model name used to look up default hparams used for that model.
             Only relevant if training from scratch.
-        use_default_model_labels(bool, optional): Whether to output the full set of default model labels rather than
+        use_default_model_labels (bool, optional): Whether to output the full set of default model labels rather than
             just the species in the labels file. Only used if labels is not None.

     Returns:
@@ -78,9 +78,8 @@ def instantiate_model(

     # predicting
     if labels is None:
-        # predict; load from checkpoint uses associated hparams
         logger.info("Loading from checkpoint.")
-        model = model_class.load_from_checkpoint(checkpoint_path=checkpoint)
+        model = model_class.from_disk(path=checkpoint, **hparams)
         return model

     # get species from labels file
@@ -110,10 +109,8 @@ def instantiate_model(
         return resume_training(
             scheduler_config=scheduler_config,
             hparams=hparams,
-            species=species,
             model_class=model_class,
             checkpoint=checkpoint,
-            labels=labels,
         )

     else:
@@ -157,10 +154,8 @@ def replace_head(scheduler_config, hparams, species, model_class, checkpoint):
 def resume_training(
     scheduler_config,
     hparams,
-    species,
     model_class,
     checkpoint,
-    labels,
 ):
     # resume training; add additional species columns to labels file if needed
     logger.info(
@@ -170,7 +165,7 @@ def resume_training(
     if scheduler_config != "default":
         hparams.update(scheduler_config.dict())

-    model = model_class.load_from_checkpoint(checkpoint_path=checkpoint, **hparams)
+    model = model_class.from_disk(path=checkpoint, **hparams)
     log_schedulers(model)
     return model

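For context on the **hparams forwarding in the diff above: Lightning's load_from_checkpoint merges extra keyword arguments over the hyperparameters stored in the checkpoint, which is how resume_training can swap in a new scheduler configuration at load time. A hedged sketch follows; the checkpoint path and hyperparameter values are placeholders, not taken from this repository.

from zamba.models.efficientnet_models import TimeDistributedEfficientNet

# Placeholder overrides for illustration; in model_manager.py they come from
# scheduler_config.dict() merged into hparams before the call to from_disk.
hparams = {"scheduler": "MultiStepLR", "scheduler_params": {"milestones": [3], "gamma": 0.5}}

# Extra kwargs pass through from_disk to load_from_checkpoint and take
# precedence over the hyperparameters saved in the checkpoint file.
model = TimeDistributedEfficientNet.from_disk("assumed_checkpoint.ckpt", **hparams)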
2 changes: 1 addition & 1 deletion zamba/models/slowfast_models.py
@@ -55,7 +55,7 @@ def __init__(
         if finetune_from is None:
             self.initialize_from_torchub()
         else:
-            model = self.load_from_checkpoint(finetune_from)
+            model = self.from_disk(finetune_from)
             self._backbone_output_dim = model.head.proj.in_features
             self.backbone = model.backbone
             self.base = model.base
5 changes: 3 additions & 2 deletions zamba/pytorch_lightning/utils.py
@@ -303,5 +303,6 @@ def to_disk(self, path: os.PathLike):
         torch.save(checkpoint, path)

     @classmethod
-    def from_disk(cls, path: os.PathLike):
-        return cls.load_from_checkpoint(path)
+    def from_disk(cls, path: os.PathLike, **kwargs):
+        # note: we always load models onto CPU; moving to GPU is handled by `devices` in pl.Trainer
+        return cls.load_from_checkpoint(path, map_location="cpu", **kwargs)
