Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix ModelPruning(make_pruning_permanent=True) buffers getting removed when saved during training #6073

Merged
merged 11 commits into from
Mar 3, 2021
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed multiple early stopping callbacks ([#6197](https://github.com/PyTorchLightning/pytorch-lightning/pull/6197))


- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))


- Fixed incorrect usage of `detach()`, `cpu()`, `to()` ([#6216](https://github.com/PyTorchLightning/pytorch-lightning/pull/6216))


Expand Down
43 changes: 27 additions & 16 deletions pytorch_lightning/callbacks/pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import inspect
from copy import deepcopy
from functools import partial
from typing import Any, Callable, List, Optional, Tuple, Union
from typing import Any, Callable, List, Optional, Tuple, Union, Dict

import torch
import torch.nn.utils.prune as pytorch_prune
Expand All @@ -27,7 +27,7 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_only
from pytorch_lightning.utilities.distributed import rank_zero_only, rank_zero_debug
from pytorch_lightning.utilities.exceptions import MisconfigurationException

_PYTORCH_PRUNING_FUNCTIONS = {
Expand Down Expand Up @@ -246,14 +246,18 @@ def _create_pruning_fn(self, pruning_fn: str, **kwargs) -> Union[Callable, pytor
def _wrap_pruning_fn(pruning_fn, **kwargs):
return partial(pruning_fn, **kwargs)

def make_pruning_permanent(self):
""" Makes ``parameters_to_prune`` current pruning permanent. """
for module, param_name in self._parameters_to_prune:
try:
pytorch_prune.remove(module, param_name)
except ValueError:
# pruning already made permanent
pass
def make_pruning_permanent(self, pl_module: LightningModule):
    """
    Removes pruning buffers from any pruned modules

    Walks every submodule of ``pl_module``. Each forward pre-hook that is a
    ``BasePruningMethod`` has its ``remove`` method called on the owning
    submodule and is then deleted from that submodule's hook registry.

    Adapted from https://github.com/pytorch/pytorch/blob/1.7.1/torch/nn/utils/prune.py#L1176-L1180
    """
    for _, submodule in pl_module.named_modules():
        pre_hooks = submodule._forward_pre_hooks
        # snapshot the keys first: entries are deleted while walking the registry
        for hook_id in tuple(pre_hooks):
            pruning_method = pre_hooks[hook_id]
            if not isinstance(pruning_method, pytorch_prune.BasePruningMethod):
                continue
            pruning_method.remove(submodule)
            del pre_hooks[hook_id]

def _restore_original_weights(self, module: nn.Module, orig_module: nn.Module, tensor_name: str):
trained = getattr(module, tensor_name)
Expand Down Expand Up @@ -351,7 +355,7 @@ def _log_sparsity_stats(
f" {curr_mask_zeros} ({curr_mask_zeros / curr_mask_size:.2%})"
)

def on_before_accelerator_backend_setup(self, trainer, pl_module):
def on_before_accelerator_backend_setup(self, trainer, pl_module: LightningModule):
parameters_to_prune = self.sanitize_parameters_to_prune(
pl_module, self._parameters_to_prune, parameter_names=self._parameter_names
)
Expand All @@ -367,7 +371,7 @@ def on_before_accelerator_backend_setup(self, trainer, pl_module):
self._original_layers.setdefault(id_, {"data": deepcopy(module), "names": []})
self._original_layers[id_]["names"].append((i, name))

def on_train_epoch_end(self, trainer, pl_module, *args):
def on_train_epoch_end(self, trainer, pl_module: LightningModule, outputs):
current_epoch = trainer.current_epoch
prune = self._apply_pruning(current_epoch) if isinstance(self._apply_pruning, Callable) else self._apply_pruning
amount = self.amount(current_epoch) if isinstance(self.amount, Callable) else self.amount
Expand All @@ -381,13 +385,20 @@ def on_train_epoch_end(self, trainer, pl_module, *args):
):
self.apply_lottery_ticket_hypothesis()

def on_train_end(self, trainer, pl_module: LightningModule):
    """Make the pruning of ``pl_module`` permanent when training ends, if the callback is configured to."""
    if not self._make_pruning_permanent:
        return
    rank_zero_debug("`ModelPruning.on_train_end`. Pruning is made permanent for this checkpoint.")
    self.make_pruning_permanent(pl_module)

def on_save_checkpoint(self, trainer, pl_module: LightningModule, checkpoint: Dict[str, Any]):
    """
    Store a permanently-pruned ``state_dict`` in ``checkpoint`` without touching the live model.

    Does nothing unless the callback was configured with ``make_pruning_permanent=True``.
    Otherwise the module is moved to CPU, deep-copied, the copy has its pruning made
    permanent, and the copy's ``state_dict`` replaces ``checkpoint["state_dict"]``.

    Args:
        trainer: the running trainer (unused here).
        pl_module: the module being trained; it is returned to its previous device
            before this hook exits, whether or not an error occurred.
        checkpoint: the checkpoint dictionary about to be saved; its ``"state_dict"``
            entry is overwritten in place.
    """
    if not self._make_pruning_permanent:
        return
    rank_zero_debug("`ModelPruning.on_save_checkpoint`. Pruning is made permanent for this checkpoint.")
    prev_device = pl_module.device
    try:
        # prune a copy so training can continue with the same buffers
        pruned_copy = deepcopy(pl_module.to("cpu"))
        self.make_pruning_permanent(pruned_copy)
        checkpoint["state_dict"] = pruned_copy.state_dict()
    finally:
        # always move the live module back, even if copying or pruning raised,
        # so a failed save does not leave training on the wrong device
        pl_module.to(prev_device)

@staticmethod
def sanitize_parameters_to_prune(
Expand Down
57 changes: 51 additions & 6 deletions tests/callbacks/test_pruning.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import platform
from collections import OrderedDict
from logging import INFO
from unittest import mock

import pytest
import torch
Expand All @@ -24,7 +23,7 @@
from torch.nn import Sequential

from pytorch_lightning import seed_everything, Trainer
from pytorch_lightning.callbacks import ModelPruning
from pytorch_lightning.callbacks import ModelPruning, ModelCheckpoint
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel

Expand All @@ -42,6 +41,10 @@ def __init__(self):
])
)

def training_step(self, batch, batch_idx):
    """Log the negated batch index under the key ``"test"``, then defer to the parent's training step."""
    monitored_value = -batch_idx
    self.log("test", monitored_value)
    step_output = super().training_step(batch, batch_idx)
    return step_output


class TestPruningMethod(pytorch_prune.BasePruningMethod):
PRUNING_TYPE = "unstructured"
Expand Down Expand Up @@ -219,7 +222,6 @@ def apply_lottery_ticket_hypothesis(self):


@pytest.mark.parametrize("make_pruning_permanent", (False, True))
@mock.patch.dict(os.environ, {}, clear=True)
def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
seed_everything(0)
model = TestModel()
Expand All @@ -244,8 +246,9 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
with caplog.at_level(INFO):
trainer.fit(model)

actual = [m.strip() for m in caplog.messages[-9:]]
expected = [
actual = [m.strip() for m in caplog.messages]
actual = [m for m in actual if m.startswith("Applied")]
assert actual == [
"Applied `L1Unstructured`. Pruned: 0/1122 (0.00%) -> 544/1122 (48.48%)",
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 506 (49.41%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 0 (0.00%) -> 38 (59.38%)", # noqa: E501
Expand All @@ -256,11 +259,53 @@ def test_multiple_pruning_callbacks(tmpdir, caplog, make_pruning_permanent):
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=32, bias=True).weight` with amount=0.5. Pruned: 633 (61.82%) -> 828 (80.86%)", # noqa: E501
"Applied `L1Unstructured` to `Linear(in_features=32, out_features=2, bias=True).weight` with amount=0.5. Pruned: 47 (73.44%) -> 56 (87.50%)", # noqa: E501
]
assert actual == expected

filepath = str(tmpdir / "foo.ckpt")
trainer.save_checkpoint(filepath)

model.load_from_checkpoint(filepath, strict=False)
has_pruning = hasattr(model.layer.mlp_1, "weight_orig")
assert not has_pruning if make_pruning_permanent else has_pruning


def test_permanent_when_model_is_saved_multiple_times(tmpdir, caplog):
    """
    When a model is saved multiple times and make_pruning_permanent=True, we need to
    make sure a copy is pruned and not the trained model if we want to continue
    with the same pruning buffers.
    """
    seed_everything(0)

    class TestPruning(ModelPruning):
        def on_save_checkpoint(self, trainer, pl_module, checkpoint):
            super().on_save_checkpoint(trainer, pl_module, checkpoint)
            # the saved weights are permanently pruned ...
            assert "layer.mlp_3.weight_orig" not in checkpoint["state_dict"]
            # ... while the live model keeps its pruning buffers
            assert hasattr(pl_module.layer.mlp_3, "weight_orig")

    model = TestModel()
    pruning_callback = TestPruning(
        "random_unstructured",
        parameters_to_prune=[(model.layer.mlp_3, "weight")],
        verbose=1,
        make_pruning_permanent=True,
    )
    ckpt_callback = ModelCheckpoint(monitor="test", save_top_k=2, save_last=True)
    # write checkpoints and logs under the pytest-provided directory, not the cwd
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[pruning_callback, ckpt_callback],
        max_epochs=3,
        progress_bar_refresh_rate=0,
    )
    with caplog.at_level(INFO):
        trainer.fit(model)

    actual = [m.strip() for m in caplog.messages]
    actual = [m for m in actual if m.startswith("Applied")]
    assert actual == [
        "Applied `RandomUnstructured`. Pruned: 0/66 (0.00%) -> 32/66 (48.48%)",
        "Applied `RandomUnstructured`. Pruned: 32/66 (48.48%) -> 48/66 (72.73%)",
        "Applied `RandomUnstructured`. Pruned: 48/66 (72.73%) -> 56/66 (84.85%)",
    ]

    # removed in on_train_end
    assert not hasattr(model.layer.mlp_3, "weight_orig")

    # `load_from_checkpoint` returns a new instance; inspect that instance so the
    # assertions actually check what was written to disk
    best = TestModel.load_from_checkpoint(trainer.checkpoint_callback.kth_best_model_path)
    assert not hasattr(best.layer.mlp_3, "weight_orig")
    last = TestModel.load_from_checkpoint(trainer.checkpoint_callback.last_model_path)
    assert not hasattr(last.layer.mlp_3, "weight_orig")