
Commit

checkpoint as a model config parameter for warmup cosine learning rates (#66)

* add optional ckpt_path to the config file to use the warmup cosine learning rate

* fix the loading of weights

* remove an extra blank line

---------

Co-authored-by: Ziwen Liu <67518483+ziw-liu@users.noreply.github.com>
edyoshikun and ziw-liu committed Jun 12, 2024
1 parent 6c3132c commit d54c278
Showing 2 changed files with 8 additions and 0 deletions.
1 change: 1 addition & 0 deletions examples/configs/fit_example.yml
@@ -23,6 +23,7 @@ model:
   loss_function: null
   lr: 0.001
   schedule: Constant
+  ckpt_path: null
   log_batches_per_epoch: 8
   log_samples_per_batch: 1
 data:
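A usage sketch for the new key (the checkpoint path below is a placeholder, not part of this commit): leaving `ckpt_path: null` trains from scratch, while pointing it at an earlier run warm-starts the weights under a fresh schedule:

```yaml
model:
  lr: 0.001
  schedule: WarmupCosine
  # placeholder path; only the weights are read from this checkpoint
  ckpt_path: /path/to/previous/run/checkpoints/last.ckpt
```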
7 changes: 7 additions & 0 deletions viscy/light/engine.py
@@ -96,6 +96,7 @@ class VSUNet(LightningModule):
     :param float lr: learning rate in training, defaults to 1e-3
     :param Literal['WarmupCosine', 'Constant'] schedule:
         learning rate scheduler, defaults to "Constant"
+    :param str ckpt_path: path to a checkpoint from which only the model weights are loaded, defaults to None
     :param int log_batches_per_epoch:
         number of batches to log each training/validation epoch,
         has to be smaller than steps per epoch, defaults to 8
@@ -121,6 +122,7 @@ def __init__(
         loss_function: Union[nn.Module, MixedLoss] = None,
         lr: float = 1e-3,
         schedule: Literal["WarmupCosine", "Constant"] = "Constant",
+        ckpt_path: str = None,
         log_batches_per_epoch: int = 8,
         log_samples_per_batch: int = 1,
         example_input_yx_shape: Sequence[int] = (256, 256),
@@ -161,6 +163,11 @@ def __init__(
         self.test_cellpose_diameter = test_cellpose_diameter
         self.test_evaluate_cellpose = test_evaluate_cellpose

+        if ckpt_path is not None:
+            self.load_state_dict(
+                torch.load(ckpt_path)["state_dict"]
+            )  # loading only weights
+
     def forward(self, x) -> torch.Tensor:
         return self.model(x)
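Why a model-level `ckpt_path` instead of resuming through the trainer: `Trainer.fit(ckpt_path=...)` restores optimizer and LR-scheduler state along with the weights, so a WarmupCosine schedule would pick up mid-curve rather than warming up. Loading only the checkpoint's `state_dict`, as the constructor now does, starts the schedule from step zero with pretrained weights. Below is a minimal sketch of that restore, assuming a standard Lightning checkpoint (a dict with a `state_dict` key); the function name and path are illustrative, not from this repository:

```python
import torch
from torch import nn


def load_weights_only(model: nn.Module, ckpt_path: str) -> None:
    """Restore model weights from a Lightning checkpoint while leaving
    optimizer and LR-scheduler state behind, so a fresh warmup can run."""
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    model.load_state_dict(checkpoint["state_dict"])
```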
