Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ModelPruning(make_pruning_permanent=True) buffers getting removed when saved during training #6073

Merged
merged 11 commits into from
Mar 3, 2021
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Added


- Added `checkpoint` parameter to the callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


### Changed


Expand All @@ -22,6 +25,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
### Fixed


- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
carmocca marked this conversation as resolved.
Show resolved Hide resolved


## [1.2.0] - 2021-02-18

### Added
Expand Down
24 changes: 19 additions & 5 deletions pytorch_lightning/callbacks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
"""

import abc
from typing import Any
from typing import Any, Dict

from pytorch_lightning.core.lightning import LightningModule

Expand Down Expand Up @@ -177,12 +177,26 @@ def on_keyboard_interrupt(self, trainer, pl_module: LightningModule) -> None:
"""Called when the training is interrupted by ``KeyboardInterrupt``."""
pass

def on_save_checkpoint(self, trainer, pl_module: LightningModule) -> None:
"""Called when saving a model checkpoint, use to persist state."""
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]) -> dict:
    """
    Hook invoked while a model checkpoint is being saved; override it to persist callback state.

    This base implementation does nothing and returns ``None``.

    Args:
        trainer: the current Trainer instance.
        pl_module: the current LightningModule instance.
        checkpoint: the checkpoint dictionary that will be saved.

    Returns:
        The callback state.
    """
    return None

def on_load_checkpoint(self, checkpointed_state) -> None:
"""Called when loading a model checkpoint, use to reload state."""
def on_load_checkpoint(self, callback_state: Dict[str, Any]) -> None:
    """
    Hook invoked when a model checkpoint is loaded; override it to restore callback state.

    This base implementation does nothing.

    Args:
        callback_state: the callback state returned by ``on_save_checkpoint``.
    """
    return None

def on_after_backward(self, trainer, pl_module: LightningModule) -> None:
Expand Down
13 changes: 7 additions & 6 deletions pytorch_lightning/callbacks/early_stopping.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
Monitor a metric and stop training when it stops improving.

"""
from typing import Any, Dict

import numpy as np
import torch
Expand Down Expand Up @@ -140,19 +141,19 @@ def _validate_condition_metric(self, logs):
def monitor_op(self):
return self.mode_dict[self.mode]

def on_save_checkpoint(self, trainer, pl_module):
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """
    Collect the attributes of this callback that are stored alongside a checkpoint.

    Args:
        trainer: unused here.
        pl_module: unused here.
        checkpoint: the checkpoint dictionary being saved; not read or modified here.

    Returns:
        A dict holding the current ``wait_count``, ``stopped_epoch``, ``best_score`` and ``patience``.
    """
    persisted_attrs = ('wait_count', 'stopped_epoch', 'best_score', 'patience')
    return {attr: getattr(self, attr) for attr in persisted_attrs}

def on_load_checkpoint(self, checkpointed_state):
self.wait_count = checkpointed_state['wait_count']
self.stopped_epoch = checkpointed_state['stopped_epoch']
self.best_score = checkpointed_state['best_score']
self.patience = checkpointed_state['patience']
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
    """
    Restore the attributes that ``on_save_checkpoint`` stored.

    Args:
        callback_state: mapping that must contain ``wait_count``, ``stopped_epoch``,
            ``best_score`` and ``patience``; a missing key raises ``KeyError``.
    """
    # assigned one by one, in this order, straight from the saved state
    for attr in ('wait_count', 'stopped_epoch', 'best_score', 'patience'):
        setattr(self, attr, callback_state[attr])

def on_validation_end(self, trainer, pl_module):
if trainer.running_sanity_check:
Expand Down
8 changes: 4 additions & 4 deletions pytorch_lightning/callbacks/model_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ def on_validation_end(self, trainer, pl_module):
"""
self.save_checkpoint(trainer, pl_module)

def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
def on_save_checkpoint(self, trainer, pl_module, checkpoint: Dict[str, Any]) -> Dict[str, Any]:
return {
"monitor": self.monitor,
"best_model_score": self.best_model_score,
Expand All @@ -220,9 +220,9 @@ def on_save_checkpoint(self, trainer, pl_module) -> Dict[str, Any]:
"dirpath": self.dirpath
}

def on_load_checkpoint(self, checkpointed_state: Dict[str, Any]):
self.best_model_score = checkpointed_state["best_model_score"]
self.best_model_path = checkpointed_state["best_model_path"]
def on_load_checkpoint(self, callback_state: Dict[str, Any]):
    """
    Restore the best score and best checkpoint path from a saved callback state.

    Args:
        callback_state: mapping that must contain ``best_model_score`` and
            ``best_model_path``; any other entries are ignored.
    """
    for attr in ("best_model_score", "best_model_path"):
        setattr(self, attr, callback_state[attr])

def save_checkpoint(self, trainer, pl_module):
"""
Expand Down
41 changes: 26 additions & 15 deletions pytorch_lightning/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union, Dict

import torch
import torch.nn.utils.prune as pytorch_prune
Expand All @@ -27,7 +27,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_PYTORCH_PRUNING_FUNCTIONS = {
Expand Down Expand Up @@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
def _wrap_pruning_fn(pruning_fn, **kwargs):
return partial(pruning_fn, **kwargs)

def make_pruning_permanent(self):
""" Makes ``parameters_to_prune`` current pruning permanent. """
for module, param_name in self._parameters_to_prune:
try:
pytorch_prune.remove(module, param_name)
except ValueError:
# pruning already made permanent
pass
def make_pruning_permanent(self, pl_module: LightningModule):
    """
    Removes pruning buffers from any pruned modules.

    Every forward pre-hook that is a ``BasePruningMethod`` is asked to remove itself
    from its module and is then unregistered; all other hooks are left in place.

    Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180

    Args:
        pl_module: the module whose submodules are processed in place.
    """
    for submodule in pl_module.modules():
        # iterate over a snapshot: entries are deleted from the hook dict as we go
        for hook_id, hook in list(submodule._forward_pre_hooks.items()):
            if not isinstance(hook, pytorch_prune.BasePruningMethod):
                continue
            hook.remove(submodule)
            del submodule._forward_pre_hooks[hook_id]

def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
trained = getattr(module, tensor_name)
Expand Down Expand Up @@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
self._original_layers[id_]["names"].append((i, name))

def on_train_epoch_end(self, trainer, pl_module, *args):
def on_train_epoch_end(self, trainer, pl_module, outputs):
carmocca marked this conversation as resolved.
Show resolved Hide resolved
current_epoch = trainer.current_epoch
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
Expand All @@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
):
self.apply_lottery_ticket_hypothesis()

def on_train_end(self, *args):
def on_train_end(self, trainer, pl_module: LightningModule):
    """
    Make the pruning of ``pl_module`` permanent at the end of training, if enabled.

    Args:
        trainer: unused here.
        pl_module: the module handed to ``make_pruning_permanent``.
    """
    if not self._make_pruning_permanent:
        return
    rank_zero_debug("`on_train_end`. Pruning permanently...")
    self.make_pruning_permanent(pl_module)

def on_save_checkpoint(self, *args):
def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
    """
    Store permanently-pruned weights in the checkpoint while leaving the live module untouched.

    When permanent pruning is enabled, a CPU copy of ``pl_module`` is pruned permanently
    and its ``state_dict`` replaces ``checkpoint["state_dict"]``. The module being trained
    keeps its pruning buffers so pruning can continue in later epochs.

    Args:
        trainer: unused here.
        pl_module: the module being trained; it is moved to CPU for the copy and then
            moved back to its previous device.
        checkpoint: the checkpoint dictionary being saved; its ``"state_dict"`` entry
            is overwritten in place.
    """
    if not self._make_pruning_permanent:
        return
    rank_zero_debug("`on_save_checkpoint`. Pruning permanently...")
    prev_device = pl_module.device
    try:
        # prune a copy so training can continue with the same buffers
        pruned_copy = deepcopy(pl_module.to("cpu"))
    finally:
        # always put the live module back on its device, even if the copy fails,
        # so a failed save cannot leave training running on the wrong device
        pl_module.to(prev_device)
    self.make_pruning_permanent(pruned_copy)
    checkpoint["state_dict"] = pruned_copy.state_dict()

@staticmethod
def sanitize_parameters_to_prune(
Expand Down
9 changes: 4 additions & 5 deletions pytorch_lightning/trainer/callback_hook.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

from abc import ABC
from copy import deepcopy
from typing import List
from typing import List, Dict, Any, Type

from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.lightning import LightningModule
Expand Down Expand Up @@ -197,14 +197,13 @@ def on_keyboard_interrupt(self):
for callback in self.callbacks:
callback.on_keyboard_interrupt(self, self.lightning_module)

def on_save_checkpoint(self):
def on_save_checkpoint(self, checkpoint: Dict[str, Any]) -> Dict[Type, dict]:
"""Called when saving a model checkpoint."""
callback_states = {}
for callback in self.callbacks:
callback_class = type(callback)
state = callback.on_save_checkpoint(self, self.lightning_module)
state = callback.on_save_checkpoint(self, self.lightning_module, checkpoint)
if state:
callback_states[callback_class] = state
callback_states[type(callback)] = state
return callback_states

def on_load_checkpoint(self, checkpoint):
Expand Down
14 changes: 5 additions & 9 deletions pytorch_lightning/trainer/connectors/checkpoint_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,17 +270,18 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
if not has_reached_max_steps:
current_epoch += 1

model = self.trainer.lightning_module

checkpoint = {
'epoch': current_epoch,
'global_step': global_step,
'pytorch-lightning_version': pytorch_lightning.__version__,
'state_dict': model.state_dict(),
}

if not weights_only:

# dump callbacks
callback_states = self.trainer.on_save_checkpoint()
checkpoint['callbacks'] = callback_states
checkpoint['callbacks'] = self.trainer.on_save_checkpoint(checkpoint)

optimizer_states = []
for i, optimizer in enumerate(self.trainer.optimizers):
Expand All @@ -305,12 +306,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
elif self.trainer.amp_backend == AMPType.APEX:
checkpoint['amp_scaling_state'] = amp.state_dict()

# add the hyper_parameters and state_dict from the model
model = self.trainer.lightning_module

# dump the module_arguments and state_dict from the model
checkpoint['state_dict'] = model.state_dict()

# dump hyper-parameters
if model.hparams:
if hasattr(model, '_hparams_name'):
checkpoint[LightningModule.CHECKPOINT_HYPER_PARAMS_NAME] = model._hparams_name
Expand Down
52 changes: 46 additions & 6 deletions tests/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import platform
from collections import OrderedDict
from logging import INFO
from unittest import mock

import pytest
import torch
Expand All @@ -24,7 +23,7 @@
from torch.nn import Sequential

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelPruning
from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel

Expand All @@ -42,6 +41,10 @@ def __init__(self):
])
)

def training_step(self, batch, batch_idx):
    """Log the negated batch index under the key ``"test"``, then run the parent's training step."""
    self.log("test", -batch_idx)
    step_output = super().training_step(batch, batch_idx)
    return step_output


class TestPruningMethod(pytorch_prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
Expand Down Expand Up @@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):


@pytest.mark.parametrize("make_pruning_permanent", (False, True))
@mock.patch.dict(os.environ, {}, clear=True)
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
seed_everything(0)
model = TestModel()
Expand All @@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
with caplog.at_level(INFO):
trainer.fit(model)

actual = [m.strip() for m in caplog.messages[-9:]]
expected = [
actual = [m.strip() for m in caplog.messages]
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501
Expand All @@ -256,11 +259,48 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501
]
assert actual == expected

filepath = str(tmpdir / "foo.ckpt")
trainer.save_checkpoint(filepath)

model.load_from_checkpoint(filepath, strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning


def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
    """
    Saving several checkpoints during training must not strip the pruning buffers from the live model.

    Each saved ``state_dict`` has to contain permanently-pruned weights, while the module
    being trained keeps ``weight_orig`` until ``on_train_end`` makes the pruning permanent.
    """
    seed_everything(0)

    class TestPruning(ModelPruning):

        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            # the checkpoint holds permanently-pruned weights ...
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
            # ... but the module being trained still carries its pruning buffers
            assert hasattr(pl_module.layer.mlp_3, "weight_orig")

    model = TestModel()
    pruning_callback = TestPruning(
        "random_unstructured",
        parameters_to_prune=[(model.layer.mlp_3, "weight")],
        verbose=1,
        make_pruning_permanent=True
    )
    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
    trainer = Trainer(
        # keep checkpoints and logs inside the pytest temporary directory
        default_root_dir=tmpdir,
        callbacks=[pruning_callback, ckpt_callback],
        max_epochs=3,
        progress_bar_refresh_rate=0,
    )
    with caplog.at_level(INFO):
        trainer.fit(model)

    actual = [m.strip() for m in caplog.messages]
    actual = [m for m in actual if m.startswith("Applied")]
    assert actual == [
        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
    ]

    # removed on_train_end
    assert not hasattr(model.layer.mlp_3, "weight_orig")

    # `load_from_checkpoint` is a classmethod that returns a NEW instance, so the
    # assertions must inspect the returned model, not the one trained above
    checkpoint_paths = (
        trainer.checkpoint_callback.kth_best_model_path,
        trainer.checkpoint_callback.last_model_path,
    )
    for checkpoint_path in checkpoint_paths:
        loaded = TestModel.load_from_checkpoint(checkpoint_path)
        assert not hasattr(loaded.layer.mlp_3, "weight_orig")