diff --git a/mambular/base_models/lightning_wrapper.py b/mambular/base_models/lightning_wrapper.py index dbcd9e7..b26c643 100644 --- a/mambular/base_models/lightning_wrapper.py +++ b/mambular/base_models/lightning_wrapper.py @@ -37,7 +37,7 @@ def __init__( lss=False, family=None, loss_fct: callable = None, - **kwargs + **kwargs, ): super().__init__() self.num_classes = num_classes @@ -300,7 +300,7 @@ def configure_optimizers(self): A dictionary containing the optimizer and lr_scheduler configurations. """ optimizer = torch.optim.Adam( - self.parameters(), + self.model.parameters(), lr=self.lr, weight_decay=self.weight_decay, ) diff --git a/mambular/models/sklearn_base_classifier.py b/mambular/models/sklearn_base_classifier.py index 7373701..21e6c25 100644 --- a/mambular/models/sklearn_base_classifier.py +++ b/mambular/models/sklearn_base_classifier.py @@ -9,6 +9,7 @@ from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor import numpy as np +from lightning.pytorch.callbacks import ModelSummary class SklearnBaseClassifier(BaseEstimator): @@ -367,12 +368,16 @@ def fit( ) # Initialize the trainer and train the model - trainer = pl.Trainer( + self.trainer = pl.Trainer( max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], + callbacks=[ + early_stop_callback, + checkpoint_callback, + ModelSummary(max_depth=2), + ], **trainer_kwargs ) - trainer.fit(self.model, self.data_module) + self.trainer.fit(self.model, self.data_module) best_model_path = checkpoint_callback.best_model_path if best_model_path: diff --git a/mambular/models/sklearn_base_lss.py b/mambular/models/sklearn_base_lss.py index 3298cff..5855b2a 100644 --- a/mambular/models/sklearn_base_lss.py +++ b/mambular/models/sklearn_base_lss.py @@ -31,6 +31,7 @@ PoissonDistribution, StudentTDistribution, ) +from lightning.pytorch.callbacks import ModelSummary class SklearnBaseLSS(BaseEstimator): @@ -409,12 +410,16 @@ def fit( ) # Initialize the trainer and train the model - trainer = pl.Trainer( + self.trainer = pl.Trainer( max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], + callbacks=[ + early_stop_callback, + checkpoint_callback, + ModelSummary(max_depth=2), + ], **trainer_kwargs ) - trainer.fit(self.model, self.data_module) + self.trainer.fit(self.model, self.data_module) best_model_path = checkpoint_callback.best_model_path if best_model_path: diff --git a/mambular/models/sklearn_base_regressor.py b/mambular/models/sklearn_base_regressor.py index c909641..19a4560 100644 --- a/mambular/models/sklearn_base_regressor.py +++ b/mambular/models/sklearn_base_regressor.py @@ -8,6 +8,7 @@ from ..base_models.lightning_wrapper import TaskModel from ..data_utils.datamodule import MambularDataModule from ..preprocessing import Preprocessor +from lightning.pytorch.callbacks import ModelSummary class SklearnBaseRegressor(BaseEstimator): @@ -356,12 +357,16 @@ def fit( ) # Initialize the trainer and train the model - trainer = pl.Trainer( + self.trainer = pl.Trainer( max_epochs=max_epochs, - callbacks=[early_stop_callback, checkpoint_callback], + callbacks=[ + early_stop_callback, + checkpoint_callback, + ModelSummary(max_depth=2), + ], **trainer_kwargs ) - trainer.fit(self.model, self.data_module) + self.trainer.fit(self.model, self.data_module) best_model_path = checkpoint_callback.best_model_path if best_model_path: