Merge pull request #256 from neuronets/ohinds-model-checkpointing
Processing model checkpointing
Showing 4 changed files with 187 additions and 17 deletions.
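This change adds a `CheckpointTracker` callback, built on `tf.keras.callbacks.ModelCheckpoint`, that saves and reloads `nobrainer` estimators at checkpoints, plus tests covering resuming from checkpoints and a warm-start training workflow.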
@@ -0,0 +1,49 @@
```python
"""Checkpointing utils"""

from glob import glob
import logging
import os

import tensorflow as tf


class CheckpointTracker(tf.keras.callbacks.ModelCheckpoint):
    """Class for saving/loading estimators at/from checkpoints."""

    def __init__(self, estimator, file_path, **kwargs):
        """
        estimator: BaseEstimator, instance of an estimator (e.g., Segmentation).
        file_path: str, directory to/from which to save or load.
        """
        self.estimator = estimator
        super().__init__(file_path, **kwargs)

    def _save_model(self, epoch, batch, logs):
        """Save the current state of the estimator. This overrides the
        base class implementation to save `nobrainer`-specific info.

        epoch: int, the index of the epoch that just finished.
        batch: int, the index of the batch that just finished.
        logs: dict, logging info passed into on_epoch_end or on_batch_end.
        """
        self.save(self._get_file_path(epoch, batch, logs))

    def save(self, directory):
        """Save the current state of the estimator.

        directory: str, path in which to save the model.
        """
        logging.info(f"Saving to dir {directory}")
        self.estimator.save(directory)

    def load(self):
        """Load the most recently created checkpoint from the
        checkpoint directory.
        """
        checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
        if not checkpoints:
            return None

        # max() over creation time picks the newest checkpoint directory.
        latest = max(checkpoints, key=os.path.getctime)
        self.estimator = self.estimator.load(latest)
        logging.info(f"Loaded estimator from {latest}.")
        return self.estimator
```
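For orientation, here is a minimal sketch of driving `CheckpointTracker` by hand rather than through the `Segmentation.init_with_checkpoints` route the tests below take. The `Segmentation(meshnet)` constructor call and the checkpoint path are illustrative assumptions, not part of this diff.

```python
# Hypothetical usage sketch: the Segmentation constructor call and the
# checkpoint path are assumptions, not part of this diff.
from nobrainer.models import meshnet
from nobrainer.processing.checkpoint import CheckpointTracker
from nobrainer.processing.segmentation import Segmentation

estimator = Segmentation(meshnet)  # assumed constructor signature
tracker = CheckpointTracker(estimator, "checkpoints/{epoch:03d}")

# load() globs the sibling directories of the filepath pattern and restores
# the newest one by creation time; it returns None if nothing was saved yet.
resumed = tracker.load()
if resumed is not None:
    estimator = resumed
```

Because `_save_model` overrides the Keras `ModelCheckpoint` hook, passing the tracker in a `callbacks` list during training saves the whole estimator via `estimator.save`, not just the raw model weights.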
@@ -0,0 +1,86 @@
```python
"""Tests for `nobrainer.processing.checkpoint`."""

import os

import numpy as np
from numpy.testing import assert_allclose
import tensorflow as tf

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation


def _get_toy_dataset():
    # One batch of random volumes; randint(0, 1) yields all-zero labels.
    data_shape = (8, 8, 8, 8, 1)
    train = tf.data.Dataset.from_tensors(
        (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
    )
    train.scalar_labels = False
    train.n_volumes = data_shape[0]
    train.volume_shape = data_shape[1:4]
    return train


def _assert_model_weights_allclose(model1, model2):
    for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
        weights1 = layer1.get_weights()
        weights2 = layer2.get_weights()
        assert len(weights1) == len(weights2)
        for index in range(len(weights1)):
            assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)


def test_checkpoint(tmp_path):
    train = _get_toy_dataset()

    checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
    model1 = Segmentation.init_with_checkpoints(
        meshnet,
        checkpoint_filepath=checkpoint_filepath,
    )
    model1.fit(
        dataset_train=train,
        epochs=2,
    )

    # A second estimator built on the same checkpoint path should pick up
    # the weights saved by the first.
    model2 = Segmentation.init_with_checkpoints(
        meshnet,
        checkpoint_filepath=checkpoint_filepath,
    )
    _assert_model_weights_allclose(model1, model2)
    model2.fit(
        dataset_train=train,
        epochs=3,
    )

    model3 = Segmentation.init_with_checkpoints(
        meshnet,
        checkpoint_filepath=checkpoint_filepath,
    )
    _assert_model_weights_allclose(model2, model3)


def test_warm_start_workflow(tmp_path):
    train = _get_toy_dataset()

    checkpoint_dir = os.path.join(tmp_path, "checkpoints")
    checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}")
    if not os.path.exists(checkpoint_dir):
        os.mkdir(checkpoint_dir)

    for iteration in range(2):
        bem = Segmentation.init_with_checkpoints(
            meshnet,
            checkpoint_filepath=checkpoint_filepath,
        )
        if iteration == 0:
            # No checkpoints exist yet: the estimator starts without a model.
            assert bem.model is None
        else:
            # Second pass: the model was reloaded with trained, nonzero weights.
            assert bem.model is not None
            for layer in bem.model.layers:
                for weight_array in layer.get_weights():
                    assert np.count_nonzero(weight_array)
        bem.fit(
            dataset_train=train,
            epochs=2,
        )
```
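Taken together, the tests pin down the warm-start contract: `Segmentation.init_with_checkpoints` leaves `model` as `None` when the checkpoint directory is empty, and otherwise restores the newest checkpoint so that a subsequent `fit` continues from the saved weights.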