Refactor LightningDataParallel #5670

Merged: 44 commits merged into release/1.2-dev from refactor/lightning-dp on Jan 31, 2021
Changes from all commits (44 commits)
65b971c  module (awaelchli, Jan 27, 2021)
4568a86  fix model access (awaelchli, Jan 27, 2021)
eb05f02  scalar conversion (awaelchli, Jan 27, 2021)
ed6874b  refactor (awaelchli, Jan 27, 2021)
f4519c0  kwargs (awaelchli, Jan 27, 2021)
ebe8b73  auto unsqueeze (awaelchli, Jan 27, 2021)
71ec469  refactor code duplication (awaelchli, Jan 27, 2021)
fbf3f32  clean up (awaelchli, Jan 27, 2021)
fb7d6ce  docs (awaelchli, Jan 27, 2021)
1f5c89a  update dp docs (awaelchli, Jan 27, 2021)
e96b218  changelog (awaelchli, Jan 27, 2021)
237b4d8  generalize test (awaelchli, Jan 27, 2021)
795f9f0  test (awaelchli, Jan 27, 2021)
8e2792e  rename (awaelchli, Jan 27, 2021)
d3cbdc4  warning cache (awaelchli, Jan 27, 2021)
e4bd878  isort (awaelchli, Jan 27, 2021)
099a64e  unsqueezing test (awaelchli, Jan 27, 2021)
b2deeab  device (awaelchli, Jan 27, 2021)
04d5c61  device (awaelchli, Jan 27, 2021)
74e67da  scalar test (awaelchli, Jan 27, 2021)
692a4f2  device (awaelchli, Jan 27, 2021)
a7347c0  device (awaelchli, Jan 27, 2021)
18a2dce  include coverage of overrides (awaelchli, Jan 27, 2021)
2e7b43c  clear (awaelchli, Jan 27, 2021)
b5db753  add deprecation test (awaelchli, Jan 27, 2021)
599bb5a  docs (awaelchli, Jan 27, 2021)
73237c9  improve coverage (awaelchli, Jan 27, 2021)
d016a34  increase coverage (awaelchli, Jan 28, 2021)
c3a21f4  Merge branch 'release/1.2-dev' into refactor/lightning-dp (awaelchli, Jan 28, 2021)
dc9a802  fix merge (awaelchli, Jan 28, 2021)
0fb7c51  extend test (awaelchli, Jan 28, 2021)
64ca198  Branch was auto-updated. (github-actions[bot], Jan 28, 2021)
a0593da  rename base class (awaelchli, Jan 28, 2021)
8efc012  mention the predict method in docs (awaelchli, Jan 28, 2021)
8b92f8e  combine iteration over collection (awaelchli, Jan 29, 2021)
488c508  remove override (awaelchli, Jan 29, 2021)
43878b4  move (awaelchli, Jan 29, 2021)
6d1e5f6  line (awaelchli, Jan 29, 2021)
8655c8c  Merge branch 'release/1.2-dev' into refactor/lightning-dp (awaelchli, Jan 29, 2021)
238e0f2  Apply suggestions from code review (Borda, Jan 29, 2021)
c66ae77  fix running stage (awaelchli, Jan 31, 2021)
71774fe  Merge branch 'release/1.2-dev' into refactor/lightning-dp (awaelchli, Jan 31, 2021)
38f837b  f401 (awaelchli, Jan 31, 2021)
6a22d1d  fix cyclic import (awaelchli, Jan 31, 2021)
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645))


+- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))
+
+
+- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670))
+
+
### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
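In practice, the deprecation amounts to replacing the old DataParallel subclass with a plain torch.nn.DataParallel around the new wrapper module. A minimal migration sketch, assuming `pl_module` is a placeholder for the user's LightningModule instance (the name is illustrative, not part of this diff):

import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule

pl_module = LightningModule()  # placeholder; any user LightningModule works here

# Before (deprecated):
#     dp_model = LightningDataParallel(pl_module, device_ids=[0, 1])
# After: wrap the LightningModule first, then hand it to vanilla DataParallel.
dp_model = torch.nn.DataParallel(LightningParallelModule(pl_module), device_ids=[0, 1])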
10 changes: 6 additions & 4 deletions pytorch_lightning/accelerators/legacy/dp_accelerator.py
@@ -21,7 +21,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
-from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -74,7 +74,7 @@ def __init_torch_data_parallel(self, model):

        # set dp device
        torch.cuda.set_device(self.trainer.root_gpu)
-       model = LightningDataParallel(model, device_ids=device_ids)
+       model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
        return model

    def __init_half_precision(self, model):
@@ -181,8 +181,10 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
                    scheduler.load_state_dict(state)

    def get_reference_model(self, model) -> LightningModule:
-       if isinstance(model, LightningDataParallel):
-           return model.module
+       if isinstance(model, torch.nn.DataParallel):
+           model = model.module
+       if isinstance(model, LightningParallelModule):
+           model = model.module
        return model

    @property
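With the refactor, `get_reference_model` has to peel off two layers instead of one. A rough round-trip sketch of the unwrapping in the diff above, using a throwaway LightningModule instance (`pl_module` is an illustrative stand-in):

import torch
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule

pl_module = LightningModule()  # stand-in for the user's model
wrapped = torch.nn.DataParallel(LightningParallelModule(pl_module), device_ids=[0])

model = wrapped
if isinstance(model, torch.nn.DataParallel):
    model = model.module  # strip DataParallel -> LightningParallelModule
if isinstance(model, LightningParallelModule):
    model = model.module  # strip the wrapper -> original LightningModule
assert model is pl_module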
2 changes: 2 additions & 0 deletions pytorch_lightning/overrides/__init__.py
@@ -0,0 +1,2 @@
from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401
from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401
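With this new package __init__, both wrapper classes become importable from the package namespace directly, for example:

from pytorch_lightning.overrides import LightningDistributedModule, LightningParallelModule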
63 changes: 63 additions & 0 deletions pytorch_lightning/overrides/base.py
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class _LightningModuleWrapperBase(torch.nn.Module):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the forward call to the appropriate
        method, either ``training_step``, ``validation_step`` or ``test_step``.
        If the LightningModule is in none of the states `training`, `testing` or `validation`,
        the inputs will be redirected to the
        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
        Inheriting classes may also modify the inputs or outputs of forward.

        Args:
            pl_module: the model to wrap
        """
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        running_stage = self.module.running_stage

        if running_stage == RunningStage.TRAINING:
            output = self.module.training_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "training_step")
        elif running_stage == RunningStage.TESTING:
            output = self.module.test_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "test_step")
        elif running_stage == RunningStage.EVALUATING:
            output = self.module.validation_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "validation_step")
        else:
            output = self.module.predict(*inputs, **kwargs)

        return output


def warn_if_output_is_none(output: Any, method_name: str) -> None:
    """ Warns user about which method returned None. """
    if output is None:
        warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
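To see the redirection in action, here is a minimal sketch loosely in the spirit of this PR's tests: a mocked LightningModule whose running_stage is set to TESTING should have its test_step called by the wrapper's forward. The mock-based setup is an assumption made for this sketch, not code from the diff:

from unittest.mock import MagicMock

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.base import _LightningModuleWrapperBase
from pytorch_lightning.trainer.states import RunningStage

# Mocked LightningModule whose running_stage says we are in the testing stage.
pl_module = MagicMock(spec=LightningModule)
pl_module.running_stage = RunningStage.TESTING

wrapper = _LightningModuleWrapperBase(pl_module)
wrapper("batch", 0)  # forward() reads running_stage and redirects the call...

pl_module.test_step.assert_called_with("batch", 0)  # ...to test_step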