Merge pull request #256 from neuronets/ohinds-model-checkpointing
Processing model checkpointing
satra committed Aug 25, 2023
2 parents 3551c21 + 0185eff commit 99eaf9d
Showing 4 changed files with 187 additions and 17 deletions.
39 changes: 36 additions & 3 deletions nobrainer/processing/base.py
@@ -20,7 +20,13 @@ class BaseEstimator:
state_variables = []
model_ = None

def __init__(self, multi_gpu=False):
def __init__(self, checkpoint_filepath=None, multi_gpu=False):
self.checkpoint_tracker = None
if checkpoint_filepath:
from .checkpoint import CheckpointTracker

self.checkpoint_tracker = CheckpointTracker(self, checkpoint_filepath)

self.strategy = get_strategy(multi_gpu)

@property
@@ -38,7 +44,7 @@ def save(self, save_dir):
# are stored as members, which doesn't leave room for
# parameters that are specific to the runtime context.
# (e.g. multi_gpu).
if key == "multi_gpu":
if key == "multi_gpu" or key == "checkpoint_filepath":
continue
model_info["__init__"][key] = getattr(self, key)
for val in self.state_variables:
@@ -49,7 +55,7 @@ def save(self, save_dir):

@classmethod
def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
"""Saves a trained model"""
"""Loads a trained model from a save directory"""
model_dir = Path(str(model_dir).rstrip(os.pathsep))
assert model_dir.exists() and model_dir.is_dir()
model_file = model_dir / "model_params.pkl"
@@ -70,6 +76,33 @@ def load(cls, model_dir, multi_gpu=False, custom_objects=None, compile=False):
)
return klass

@classmethod
def init_with_checkpoints(cls, model_name, checkpoint_filepath):
"""Initialize a model for training, either from the latest
checkpoint found, or from scratch if no checkpoints are
found. This is useful for long-running model fits that may be
        interrupted or preempted during training and need to pick up
where they left off.
model_name: str or Module in nobrainer.models, the base model
for this estimator.
checkpoint_filepath: str, path to which checkpoints will be
        saved and loaded. Supports the epoch and block formatting
        parameters supported by tensorflow's ModelCheckpoint,
e.g. <path_to_checkpoint_dir>/{epoch:03d}
"""
from .checkpoint import CheckpointTracker

checkpoint_tracker = CheckpointTracker(cls, checkpoint_filepath)
estimator = checkpoint_tracker.load()
if not estimator:
estimator = cls(model_name)
estimator.checkpoint_tracker = checkpoint_tracker
checkpoint_tracker.estimator = estimator
return estimator


class TransformerMixin:
"""Mixin class for all transformers in scikit-learn."""
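Not part of the diff: a minimal resume-from-checkpoint sketch showing how the new `init_with_checkpoints` classmethod is meant to be used. The toy dataset is adapted from the new tests added in this PR; the `checkpoints/{epoch:03d}` path is illustrative only.

import numpy as np
import tensorflow as tf

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Toy dataset mirroring the one built in the new checkpoint tests.
data_shape = (8, 8, 8, 8, 1)
dataset_train = tf.data.Dataset.from_tensors(
    (np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
)
dataset_train.scalar_labels = False
dataset_train.n_volumes = data_shape[0]
dataset_train.volume_shape = data_shape[1:4]

# Resume from the newest checkpoint under "checkpoints/" if any exist,
# otherwise start training from scratch.
model = Segmentation.init_with_checkpoints(
    meshnet,
    checkpoint_filepath="checkpoints/{epoch:03d}",
)
model.fit(dataset_train=dataset_train, epochs=2)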
49 changes: 49 additions & 0 deletions nobrainer/processing/checkpoint.py
@@ -0,0 +1,49 @@
"""Checkpointing utils"""

from glob import glob
import logging
import os

import tensorflow as tf


class CheckpointTracker(tf.keras.callbacks.ModelCheckpoint):
"""Class for saving/loading estimators at/from checkpoints."""

def __init__(self, estimator, file_path, **kwargs):
"""
estimator: BaseEstimator, instance of an estimator (e.g., Segmentation).
file_path: str, directory to/from which to save or load.
"""
self.estimator = estimator
super().__init__(file_path, **kwargs)

def _save_model(self, epoch, batch, logs):
"""Save the current state of the estimator. This overrides the
base class implementation to save `nobrainer` specific info.
epoch: int, the index of the epoch that just finished.
batch: int, the index of the batch that just finished.
logs: dict, logging info passed into on_epoch_end or on_batch_end.
"""
self.save(self._get_file_path(epoch, batch, logs))

def save(self, directory):
"""Save the current state of the estimator.
directory: str, path in which to save the model.
"""
logging.info(f"Saving to dir {directory}")
self.estimator.save(directory)

def load(self):
"""Loads the most-recently created checkpoint from the
checkpoint directory.
"""
checkpoints = glob(os.path.join(os.path.dirname(self.filepath), "*/"))
if not checkpoints:
return None

latest = max(checkpoints, key=os.path.getctime)
self.estimator = self.estimator.load(latest)
logging.info(f"Loaded estimator from {latest}.")
return self.estimator
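Not part of the diff: an illustration of the checkpoint layout that `CheckpointTracker.load` relies on. Each save writes one subdirectory under the parent of `checkpoint_filepath`, and `load()` picks the most recently created one; the `checkpoints/{epoch:03d}` pattern and directory names below are hypothetical.

import os
from glob import glob

# Hypothetical checkpoint pattern; its parent directory holds one
# subdirectory per saved checkpoint, e.g. checkpoints/001/, checkpoints/002/.
checkpoint_filepath = "checkpoints/{epoch:03d}"

# Same selection logic as CheckpointTracker.load(): glob the subdirectories
# and take the most recently created one.
checkpoints = glob(os.path.join(os.path.dirname(checkpoint_filepath), "*/"))
latest = max(checkpoints, key=os.path.getctime) if checkpoints else None
print(latest)  # e.g. "checkpoints/002/" once two checkpoints have been saved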
30 changes: 16 additions & 14 deletions nobrainer/processing/segmentation.py
@@ -1,20 +1,24 @@
import importlib
import os
import logging

import tensorflow as tf

from .base import BaseEstimator
from .. import losses, metrics
from ..dataset import get_steps_per_epoch

logging.getLogger().setLevel(logging.INFO)


class Segmentation(BaseEstimator):
"""Perform segmentation type operations"""

state_variables = ["block_shape_", "volume_shape_", "scalar_labels_"]

def __init__(self, base_model, model_args=None, multi_gpu=False):
super().__init__(multi_gpu=multi_gpu)
def __init__(
self, base_model, model_args=None, checkpoint_filepath=None, multi_gpu=False
):
super().__init__(checkpoint_filepath=checkpoint_filepath, multi_gpu=multi_gpu)

if not isinstance(base_model, str):
self.base_model = base_model.__name__
@@ -31,8 +35,6 @@ def fit(
dataset_train,
dataset_validate=None,
epochs=1,
checkpoint_dir=os.getcwd(),
warm_start=False,
# TODO: figure out whether optimizer args should be flattened
optimizer=None,
opt_args=None,
@@ -73,21 +75,17 @@ def _compile():
metrics=metrics,
)

if warm_start:
if self.model is None:
raise ValueError("warm_start requested, but model is undefined")
with self.strategy.scope():
_compile()
else:
if self.model is None:
mod = importlib.import_module("..models", "nobrainer.processing")
base_model = getattr(mod, self.base_model)
if batch_size % self.strategy.num_replicas_in_sync:
raise ValueError("batch size must be a multiple of the number of GPUs")

with self.strategy.scope():
_create(base_model)
_compile()
print(self.model_.summary())
with self.strategy.scope():
_compile()
self.model_.summary()

train_steps = get_steps_per_epoch(
n_volumes=dataset_train.n_volumes,
@@ -105,13 +103,17 @@ def _compile():
batch_size=batch_size,
)

# TODO add checkpoint
callbacks = []
if self.checkpoint_tracker:
callbacks.append(self.checkpoint_tracker)

self.model_.fit(
dataset_train,
epochs=epochs,
steps_per_epoch=train_steps,
validation_data=dataset_validate,
validation_steps=evaluate_steps,
callbacks=callbacks,
)

return self
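Not part of the diff: a sketch of the non-classmethod route, passing `checkpoint_filepath` directly to the `Segmentation` constructor. `dataset_train` is assumed to be a nobrainer-style dataset built elsewhere, and `ckpts/{epoch:03d}` is a placeholder path.

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation

# Passing checkpoint_filepath wires a CheckpointTracker into the estimator;
# fit() then appends it to the Keras callbacks list, so a checkpoint
# directory is written as training progresses (by default at each epoch end,
# the ModelCheckpoint default).
seg = Segmentation(meshnet, checkpoint_filepath="ckpts/{epoch:03d}")
seg.fit(dataset_train=dataset_train, epochs=2)  # dataset_train built elsewhere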
86 changes: 86 additions & 0 deletions nobrainer/tests/checkpoint_test.py
@@ -0,0 +1,86 @@
"""Tests for `nobrainer.processing.checkpoint`."""

import os

import numpy as np
from numpy.testing import assert_allclose
import tensorflow as tf

from nobrainer.models import meshnet
from nobrainer.processing.segmentation import Segmentation


def _get_toy_dataset():
data_shape = (8, 8, 8, 8, 1)
train = tf.data.Dataset.from_tensors(
(np.random.rand(*data_shape), np.random.randint(0, 1, data_shape))
)
train.scalar_labels = False
train.n_volumes = data_shape[0]
train.volume_shape = data_shape[1:4]
return train


def _assert_model_weights_allclose(model1, model2):
for layer1, layer2 in zip(model1.model.layers, model2.model.layers):
weights1 = layer1.get_weights()
weights2 = layer2.get_weights()
assert len(weights1) == len(weights2)
for index in range(len(weights1)):
assert_allclose(weights1[index], weights2[index], rtol=1e-06, atol=1e-08)


def test_checkpoint(tmp_path):
train = _get_toy_dataset()

checkpoint_filepath = os.path.join(tmp_path, "checkpoint-epoch_{epoch:03d}")
model1 = Segmentation.init_with_checkpoints(
meshnet,
checkpoint_filepath=checkpoint_filepath,
)
model1.fit(
dataset_train=train,
epochs=2,
)

model2 = Segmentation.init_with_checkpoints(
meshnet,
checkpoint_filepath=checkpoint_filepath,
)
_assert_model_weights_allclose(model1, model2)
model2.fit(
dataset_train=train,
epochs=3,
)

model3 = Segmentation.init_with_checkpoints(
meshnet,
checkpoint_filepath=checkpoint_filepath,
)
_assert_model_weights_allclose(model2, model3)


def test_warm_start_workflow(tmp_path):
train = _get_toy_dataset()

checkpoint_dir = os.path.join(tmp_path, "checkpoints")
checkpoint_filepath = os.path.join(checkpoint_dir, "{epoch:03d}")
if not os.path.exists(checkpoint_dir):
os.mkdir(checkpoint_dir)

for iteration in range(2):
bem = Segmentation.init_with_checkpoints(
meshnet,
checkpoint_filepath=checkpoint_filepath,
)
if iteration == 0:
assert bem.model is None
else:
assert bem.model is not None
for layer in bem.model.layers:
for weight_array in layer.get_weights():
assert np.count_nonzero(weight_array)
bem.fit(
dataset_train=train,
epochs=2,
)
