Skip to content

Commit

Permalink
Add post epoch hook
Browse files Browse the repository at this point in the history
  • Loading branch information
asistradition committed May 27, 2024
1 parent a69cfa8 commit 21ea305
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
5 changes: 2 additions & 3 deletions supirfactor_dynamical/datasets/anndata_backed_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,10 +137,9 @@ def _load_obs_cat(self, obs_col):
if not isinstance(obs_col, (tuple, list, pd.Index)):
obs_col = [obs_col]

_invalid_cols = [o not in _obs.columns for o in obs_col]
if any(_invalid_cols):
if any([o not in _obs.columns for o in obs_col]):
raise ValueError(
f"Key(s) {_invalid_cols} "
f"Key(s) {[o for o in obs_col if o not in _obs.columns]} "
f"not present in obs: {_obs.columns}"
)

Expand Down
6 changes: 5 additions & 1 deletion supirfactor_dynamical/models/_model_mixins/training_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,8 @@ def train_model(
epochs,
validation_dataloader=None,
loss_function=torch.nn.MSELoss(),
optimizer=None
optimizer=None,
post_epoch_hook=None
):
"""
Train this model
Expand Down Expand Up @@ -188,6 +189,9 @@ def train_model(

self.current_epoch = epoch_num

if post_epoch_hook is not None:
post_epoch_hook(self)

to(self, 'cpu')
self.eval()

Expand Down

0 comments on commit 21ea305

Please sign in to comment.