Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trainer fix #84

Merged
merged 2 commits into from
Jul 25, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
depth=2 for summary and self.trainer attribute
AnFreTh committed Jul 25, 2024
commit be1879d3c09776f9b4a5489b349aa10cfa8c0620
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_classifier.py
Original file line number Diff line number Diff line change
@@ -9,6 +9,7 @@
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
import numpy as np
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseClassifier(BaseEstimator):
@@ -367,12 +368,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_lss.py
Original file line number Diff line number Diff line change
@@ -31,6 +31,7 @@
PoissonDistribution,
StudentTDistribution,
)
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseLSS(BaseEstimator):
@@ -409,12 +410,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path:
11 changes: 8 additions & 3 deletions mambular/models/sklearn_base_regressor.py
Original file line number Diff line number Diff line change
@@ -8,6 +8,7 @@
from ..base_models.lightning_wrapper import TaskModel
from ..data_utils.datamodule import MambularDataModule
from ..preprocessing import Preprocessor
from lightning.pytorch.callbacks import ModelSummary


class SklearnBaseRegressor(BaseEstimator):
@@ -356,12 +357,16 @@ def fit(
)

# Initialize the trainer and train the model
trainer = pl.Trainer(
self.trainer = pl.Trainer(
max_epochs=max_epochs,
callbacks=[early_stop_callback, checkpoint_callback],
callbacks=[
early_stop_callback,
checkpoint_callback,
ModelSummary(max_depth=2),
],
**trainer_kwargs
)
trainer.fit(self.model, self.data_module)
self.trainer.fit(self.model, self.data_module)

best_model_path = checkpoint_callback.best_model_path
if best_model_path: