Skip to content

Commit

Permalink
Merge pull request #18 from Genentech/map-location
Browse files Browse the repository at this point in the history
added map_location arg in load_model to load model onto any device
  • Loading branch information
avantikalal authored Jul 17, 2024
2 parents b7336fa + cba15f0 commit f349d29
Showing 1 changed file with 6 additions and 3 deletions.
9 changes: 6 additions & 3 deletions src/grelu/resources/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import importlib_resources
from tempfile import TemporaryDirectory
from pathlib import Path
from typing import Optional, List, Dict, Any
from typing import Optional, List, Dict, Any, Union

import wandb
from grelu.lightning import LightningModel
Expand Down Expand Up @@ -227,13 +227,16 @@ def get_model_by_dataset(dataset_name:str, project:str, host:str=DEFAULT_WANDB_H
return [x.name for x in runs[0].logged_artifacts()]


def load_model(project:str, model_name:str, host:str=DEFAULT_WANDB_HOST, alias:str='latest', checkpoint_file:str='model.ckpt') -> LightningModel:
def load_model(
project:str, model_name:str, device:Union[str, int]='cpu', host:str=DEFAULT_WANDB_HOST, alias:str='latest', checkpoint_file:str='model.ckpt'
) -> LightningModel:
"""
Download and load a model from the model zoo
Args:
project: Name of the project containing the model
model_name: Name of the model
device: Device index on which to load the model.
host: URL of the Weights & Biases host
alias: Alias of the model artifact
checkpoint_file: Name of the checkpoint file contained in the model artifact
Expand All @@ -245,6 +248,6 @@ def load_model(project:str, model_name:str, host:str=DEFAULT_WANDB_HOST, alias:s

with TemporaryDirectory() as d:
art.download(d)
model = LightningModel.load_from_checkpoint(Path(d) / checkpoint_file)
model = LightningModel.load_from_checkpoint(Path(d) / checkpoint_file, map_location=device)

return model

0 comments on commit f349d29

Please sign in to comment.