Skip to content

Commit

Permalink
FIX Pickling untrained nets with progress bar (#1034)
Browse files Browse the repository at this point in the history
  • Loading branch information
marco-c authored Nov 28, 2023
1 parent 11cc6c3 commit 7f55393
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 1 deletion.
2 changes: 1 addition & 1 deletion skorch/callbacks/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -669,7 +669,7 @@ def __getstate__(self):
# don't save away the temporary pbar_ object which gets created on
# epoch begin anew anyway. This avoids pickling errors with tqdm.
state = self.__dict__.copy()
del state['pbar_']
state.pop('pbar_', None)
return state


Expand Down
12 changes: 12 additions & 0 deletions skorch/tests/callbacks/test_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -770,6 +770,18 @@ def test_pickle(self, net_cls, progressbar_cls, data):
net = pickle.loads(dump)
net.fit(*data)

def test_pickle_without_fit(self, net_cls, progressbar_cls, data):
# pickling should work even if the net hasn't been fit.
# see https://github.com/skorch-dev/skorch/pull/1034.
import pickle

net = net_cls(callbacks=[
progressbar_cls(),
])
dump = pickle.dumps(net)

net = pickle.loads(dump)


@pytest.mark.skipif(
not tensorboard_installed, reason='tensorboard is not installed')
Expand Down

0 comments on commit 7f55393

Please sign in to comment.