Refactor LightningDataParallel (#5670)
* module

* fix model access

* scalar conversion

* refactor

* kwargs

* auto unsqueeze

* refactor code duplication

* clean up

* docs

* update dp docs

* changelog

* generalize test

* test

* rename

* warning cache

* isort

* unsqueezing test

* device

* device

* scalar test

* device

* device

* include coverage of overrides

* clear

* add deprecation test

* docs

* improve coverage

* increase coverage

* fix merge

* extend test

* rename base class

* mention the predict method in docs

* combine iteration over collection

* remove override

* move

* line

* Apply suggestions from code review

* fix running stage

* f401

* fix cyclic import

Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
Co-authored-by: Jirka Borovec <Borda@users.noreply.github.com>
3 people authored Jan 31, 2021
1 parent 5d239cc commit 692f77b
Showing 12 changed files with 349 additions and 349 deletions.
6 changes: 6 additions & 0 deletions CHANGELOG.md
@@ -120,6 +120,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Moved accelerators and plugins to its `legacy` pkg ([#5645](https://github.com/PyTorchLightning/pytorch-lightning/pull/5645))


- Deprecated `LightningDistributedDataParallel` in favor of new wrapper module `LightningDistributedModule` ([#5185](https://github.com/PyTorchLightning/pytorch-lightning/pull/5185))


- Deprecated `LightningDataParallel` in favor of new wrapper module `LightningParallelModule` ([#5670](https://github.com/PyTorchLightning/pytorch-lightning/pull/5670))


### Removed

- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))
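For context on the two deprecation entries above: instead of subclassing `DataParallel`, the `LightningModule` is now wrapped in the new `LightningParallelModule` and handed to the stock `torch.nn.DataParallel`, as the `dp_accelerator.py` hunk below shows. A minimal sketch of the migration, assuming an existing `LightningModule` instance and illustrative GPU ids (the helper name `wrap_for_dp` is hypothetical, not part of the library):

```python
import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


def wrap_for_dp(model: LightningModule, device_ids):
    # Deprecated path (pre-#5670):
    #   from pytorch_lightning.overrides.data_parallel import LightningDataParallel
    #   return LightningDataParallel(model, device_ids=device_ids)
    # New path: the Lightning-specific forward redirection lives in the wrapper
    # module, while parallelization is delegated to the stock torch.nn.DataParallel.
    return torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
```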
10 changes: 6 additions & 4 deletions pytorch_lightning/accelerators/legacy/dp_accelerator.py
@@ -21,7 +21,7 @@
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.distributed import LightningDistributed
-from pytorch_lightning.overrides.data_parallel import LightningDataParallel
+from pytorch_lightning.overrides.data_parallel import LightningParallelModule
from pytorch_lightning.utilities import AMPType
from pytorch_lightning.utilities.exceptions import MisconfigurationException

@@ -74,7 +74,7 @@ def __init_torch_data_parallel(self, model):

# set dp device
torch.cuda.set_device(self.trainer.root_gpu)
-model = LightningDataParallel(model, device_ids=device_ids)
+model = torch.nn.DataParallel(LightningParallelModule(model), device_ids=device_ids)
return model

def __init_half_precision(self, model):
@@ -181,8 +181,10 @@ def reinit_scheduler_properties(self, optimizers: list, schedulers: list):
scheduler.load_state_dict(state)

def get_reference_model(self, model) -> LightningModule:
-if isinstance(model, LightningDataParallel):
-    return model.module
+if isinstance(model, torch.nn.DataParallel):
+    model = model.module
+if isinstance(model, LightningParallelModule):
+    model = model.module
return model

@property
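To recover the original `LightningModule` from the wrapped model, the updated `get_reference_model` above unwraps both layers in turn. A standalone sketch of that unwrapping logic (the function name `unwrap_reference_model` is illustrative only):

```python
import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.overrides.data_parallel import LightningParallelModule


def unwrap_reference_model(model) -> LightningModule:
    # Peel off torch.nn.DataParallel first, then the Lightning wrapper module,
    # mirroring the updated get_reference_model above.
    if isinstance(model, torch.nn.DataParallel):
        model = model.module
    if isinstance(model, LightningParallelModule):
        model = model.module
    return model
```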
2 changes: 2 additions & 0 deletions pytorch_lightning/overrides/__init__.py
@@ -0,0 +1,2 @@
from pytorch_lightning.overrides.data_parallel import LightningParallelModule # noqa: F401
from pytorch_lightning.overrides.distributed import LightningDistributedModule # noqa: F401
63 changes: 63 additions & 0 deletions pytorch_lightning/overrides/base.py
@@ -0,0 +1,63 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any

import torch

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.trainer.states import RunningStage
from pytorch_lightning.utilities.warnings import WarningCache

warning_cache = WarningCache()


class _LightningModuleWrapperBase(torch.nn.Module):

    def __init__(self, pl_module: LightningModule):
        """
        Wraps the user's LightningModule and redirects the forward call to the appropriate
        method, either ``training_step``, ``validation_step`` or ``test_step``.
        If the LightningModule is in none of the states `training`, `testing` or `validation`,
        the inputs will be redirected to the
        :meth:`~pytorch_lightning.core.lightning.LightningModule.predict` method.
        Inheriting classes may also modify the inputs or outputs of forward.

        Args:
            pl_module: the model to wrap
        """
        super().__init__()
        self.module = pl_module

    def forward(self, *inputs, **kwargs):
        running_stage = self.module.running_stage

        if running_stage == RunningStage.TRAINING:
            output = self.module.training_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "training_step")
        elif running_stage == RunningStage.TESTING:
            output = self.module.test_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "test_step")
        elif running_stage == RunningStage.EVALUATING:
            output = self.module.validation_step(*inputs, **kwargs)
            warn_if_output_is_none(output, "validation_step")
        else:
            output = self.module.predict(*inputs, **kwargs)

        return output


def warn_if_output_is_none(output: Any, method_name: str) -> None:
""" Warns user about which method returned None. """
if output is None:
warning_cache.warn(f'Your {method_name} returned None. Did you forget to return an output?')
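The wrapper's `forward` simply dispatches to the step method matching the module's current `running_stage`. A self-contained, plain-PyTorch sketch of that dispatch pattern using toy stand-ins (not the Lightning classes themselves):

```python
import torch


class ToyModule(torch.nn.Module):
    """Toy stand-in for a LightningModule with per-stage step methods."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(4, 2)
        self.running_stage = "train"  # stand-in for RunningStage

    def training_step(self, batch):
        return self.layer(batch).sum()

    def validation_step(self, batch):
        return self.layer(batch).mean()


class ToyWrapper(torch.nn.Module):
    """Redirects forward() to the step method for the current stage,
    analogous to _LightningModuleWrapperBase above."""

    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, *inputs, **kwargs):
        if self.module.running_stage == "train":
            return self.module.training_step(*inputs, **kwargs)
        return self.module.validation_step(*inputs, **kwargs)


toy = ToyModule()
wrapped = ToyWrapper(toy)
batch = torch.randn(8, 4)

toy.running_stage = "train"
print(wrapped(batch))  # routed to training_step

toy.running_stage = "val"
print(wrapped(batch))  # routed to validation_step
```

Because the wrapper is a plain `torch.nn.Module`, the same object can be handed to `torch.nn.DataParallel`, which then scatters the batch and calls the wrapper's `forward` on each replica.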