diff --git a/supirfactor_dynamical/datasets/anndata_backed_dataset.py b/supirfactor_dynamical/datasets/anndata_backed_dataset.py index 51c3e04..cc01548 100644 --- a/supirfactor_dynamical/datasets/anndata_backed_dataset.py +++ b/supirfactor_dynamical/datasets/anndata_backed_dataset.py @@ -137,10 +137,9 @@ def _load_obs_cat(self, obs_col): if not isinstance(obs_col, (tuple, list, pd.Index)): obs_col = [obs_col] - _invalid_cols = [o not in _obs.columns for o in obs_col] - if any(_invalid_cols): + if any([o not in _obs.columns for o in obs_col]): raise ValueError( - f"Key(s) {_invalid_cols} " + f"Key(s) {[o for o in obs_col if o not in _obs.columns]} " f"not present in obs: {_obs.columns}" ) diff --git a/supirfactor_dynamical/models/_model_mixins/training_mixin.py b/supirfactor_dynamical/models/_model_mixins/training_mixin.py index f9c9cf6..c10b187 100644 --- a/supirfactor_dynamical/models/_model_mixins/training_mixin.py +++ b/supirfactor_dynamical/models/_model_mixins/training_mixin.py @@ -103,7 +103,8 @@ def train_model( epochs, validation_dataloader=None, loss_function=torch.nn.MSELoss(), - optimizer=None + optimizer=None, + post_epoch_hook=None ): """ Train this model @@ -188,6 +189,9 @@ def train_model( self.current_epoch = epoch_num + if post_epoch_hook is not None: + post_epoch_hook(self) + to(self, 'cpu') self.eval()