Skip to content

Commit

Permalink
ref: clean config [1/n] add intermediate setters (#4990)
Browse files Browse the repository at this point in the history
* add intermediate setters

* show inputs

* fix options

* move

* fix

* less talk

* fix

* talk less

* str

* cases

* rename

Co-authored-by: chaton <thomas@grid.ai>
  • Loading branch information
Borda and tchaton authored Dec 9, 2020
1 parent 068502f commit ce91795
Show file tree
Hide file tree
Showing 8 changed files with 228 additions and 27 deletions.
12 changes: 1 addition & 11 deletions pytorch_lightning/accelerators/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,14 +256,4 @@ def block_ddp_plugin_sync_behaviour(self):
yield cm


# TODO: allow user to compare with string even internaly we shall use these Enum to prevent typos...
class BackendType(Enum):
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
# decuple distrib and device
DDP_CPU = 'ddp_cpu'
HOROVOD = 'horovod'
# this is rather device
TPU = 'tpu'

1 change: 1 addition & 0 deletions pytorch_lightning/accelerators/accelerator_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ def set_distributed_mode(self):
self.trainer.use_ddp = True
self.trainer.data_parallel_device_ids = None
self.trainer.on_gpu = False
self.trainer.on_cpu = True
elif self.trainer.distributed_backend == "horovod":
self._set_horovod_backend()

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@


class LoggerStages(str, Enum):
""" Train/validation/test phase in each training step.
>>> # you can math the type with string
>>> LoggerStages.TRAIN == 'train'
True
"""
TRAIN = "train"
VAL = "validation"
TEST = "test"
Expand All @@ -35,7 +41,7 @@ def determine_stage(stage_or_testing: Union[str, bool]) -> 'LoggerStages':
raise RuntimeError(f"Invalid stage {stage_or_testing} of type {type(stage_or_testing)} given")


class ResultStoreType(Enum):
class ResultStoreType(str, Enum):
INSIDE_BATCH_TRAIN_LOOP = "inside_batch_train_loop"
OUTSIDE_BATCH_TRAIN_LOOP = "outside_batch_train_loop"

Expand Down
135 changes: 135 additions & 0 deletions pytorch_lightning/trainer/deprecated_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pytorch_lightning.utilities import rank_zero_warn, DistributedType, DeviceType


class DeprecatedDistDeviceAttributes:

_distrib_type: DistributedType
_device_type: DeviceType
num_gpus: int

@property
def on_cpu(self) -> bool:
# rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.CPU

@on_cpu.setter
def on_cpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_cpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._device_type = DeviceType.CPU

@property
def on_tpu(self) -> bool:
# rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.TPU

@on_tpu.setter
def on_tpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# todo add logic that it cannot be set if TPU is missing
if val:
self._device_type = DeviceType.TPU

@property
def use_tpu(self) -> bool:
# rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.TPU

@use_tpu.setter
def use_tpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_tpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# todo add logic that it cannot be set if TPU is missing
if val:
self._device_type = DeviceType.TPU

@property
def on_gpu(self) -> bool:
# rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._device_type == DeviceType.GPU

@on_gpu.setter
def on_gpu(self, val: bool) -> None:
# rank_zero_warn("Internal: `on_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
# todo add logic that it cannot be set if GPU is missing
if val:
self._device_type = DeviceType.GPU

@property
def use_dp(self) -> bool:
# rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DP

@use_dp.setter
def use_dp(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_dp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DP

@property
def use_ddp(self) -> bool:
# rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP

@use_ddp.setter
def use_ddp(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_ddp` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DDP

@property
def use_ddp2(self) -> bool:
# rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
return self._device_type and self._distrib_type == DistributedType.DDP2

@use_ddp2.setter
def use_ddp2(self, val: bool) -> None:
# rank_zero_warn("Internal: `use_ddp2` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning)
if val:
self._distrib_type = DistributedType.DDP2

@property
def use_horovod(self) -> bool:
# rank_zero_warn(
# "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
# )
return self._device_type and self._distrib_type == DistributedType.HOROVOD

@use_horovod.setter
def use_horovod(self, val: bool) -> None:
# rank_zero_warn(
# "Internal: `use_horovod` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning
# )
if val:
self._distrib_type = DistributedType.HOROVOD

@property
def use_single_gpu(self) -> bool:
# rank_zero_warn(
# "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
# )
# todo, limiting to exclude DDP2 is not clear but it comes from connectors...
return (self._device_type and self._device_type == DeviceType.GPU
and self.num_gpus == 1
and self._distrib_type not in (DistributedType.DDP2, ))

@use_single_gpu.setter
def use_single_gpu(self, val: bool) -> None:
# rank_zero_warn(
# "Internal: `use_single_gpu` is deprecated in v1.1 and will be removed in v1.2.", DeprecationWarning,
# )
if val:
self._device_type = DeviceType.GPU
12 changes: 10 additions & 2 deletions pytorch_lightning/trainer/states.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,17 @@
import pytorch_lightning


class TrainerState(Enum):
class TrainerState(str, Enum):
""" State which is set in the :class:`~pytorch_lightning.trainer.trainer.Trainer`
to indicate what is currently or was executed. """
to indicate what is currently or was executed.
>>> # you can math the type with string
>>> TrainerState.RUNNING == 'RUNNING'
True
>>> # which is case sensitive
>>> TrainerState.FINISHED == 'finished'
False
"""
INITIALIZING = 'INITIALIZING'
RUNNING = 'RUNNING'
FINISHED = 'FINISHED'
Expand Down
12 changes: 7 additions & 5 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,10 @@
from pytorch_lightning import _logger as log
from pytorch_lightning.accelerators.accelerator import Accelerator
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.accelerators.cpu_accelerator import CPUAccelerator
from pytorch_lightning.callbacks import Callback, ModelCheckpoint
from pytorch_lightning.trainer.deprecated_api import DeprecatedDistDeviceAttributes
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.core.datamodule import LightningDataModule
from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.core.memory import ModelSummary
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.loggers import LightningLoggerBase
from pytorch_lightning.plugins.plugin_connector import PluginConnector
Expand All @@ -53,11 +52,11 @@
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
from pytorch_lightning.trainer.optimizers import TrainerOptimizersMixin
from pytorch_lightning.trainer.properties import TrainerProperties
from pytorch_lightning.trainer.states import TrainerState, trainer_state
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.trainer.training_tricks import TrainerTrainingTricksMixin
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, DeviceType
from pytorch_lightning.utilities.cloud_io import load as pl_load
from pytorch_lightning.utilities.debugging import InternalDebugger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand All @@ -78,6 +77,7 @@ class Trainer(
TrainerLoggingMixin,
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
DeprecatedDistDeviceAttributes,
):
@overwrite_by_env_vars
def __init__(
Expand Down Expand Up @@ -284,6 +284,8 @@ def __init__(
handle AMP, TPU, accumulated_gradients, etc..
"""
super().__init__()
self._device_type = DeviceType.CPU
self._distrib_type = None

# init connectors
self.dev_debugger = InternalDebugger(self)
Expand Down
59 changes: 58 additions & 1 deletion pytorch_lightning/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import platform
from distutils.version import LooseVersion
from enum import Enum
from typing import Union

import numpy
import torch
Expand Down Expand Up @@ -66,6 +67,62 @@ def _module_available(module_path: str) -> bool:
FLOAT64_EPSILON = numpy.finfo(numpy.float64).eps


class AMPType(Enum):
class LightningEnum(str, Enum):
""" Type of any enumerator with allowed comparison to string invariant to cases. """

@classmethod
def from_str(cls, value: str) -> 'LightningEnum':
statuses = [status for status in dir(cls) if not status.startswith('_')]
for st in statuses:
if st.lower() == value.lower():
return getattr(cls, st)
return None

def __eq__(self, other: Union[str, Enum]) -> bool:
other = other.value if isinstance(other, Enum) else str(other)
return self.value.lower() == other.lower()


class AMPType(LightningEnum):
"""Type of Automatic Mixed Precission used for training.
>>> # you can math the type with string
>>> AMPType.APEX == 'apex'
True
"""
APEX = 'apex'
NATIVE = 'native'


class DistributedType(LightningEnum):
""" Define type of ditributed computing.
>>> # you can math the type with string
>>> DistributedType.DDP == 'ddp'
True
>>> # which is case invariant
>>> DistributedType.DDP2 == 'DDP2'
True
"""
DP = 'dp'
DDP = 'ddp'
DDP2 = 'ddp2'
DDP_SPAWN = 'ddp_spawn'
HOROVOD = 'horovod'


class DeviceType(LightningEnum):
""" Define Device type byt its nature - acceleatrors.
>>> DeviceType.CPU == DeviceType.from_str('cpu')
True
>>> # you can math the type with string
>>> DeviceType.GPU == 'GPU'
True
>>> # which is case invariant
>>> DeviceType.TPU == 'tpu'
True
"""
CPU = 'CPU'
GPU = 'GPU'
TPU = 'TPU'
16 changes: 9 additions & 7 deletions tests/trainer/test_trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1332,15 +1332,17 @@ def test_num_sanity_val_steps_neg_one(tmpdir, limit_val_batches):
),
],
)
# Todo: mock nb Gpus so all these tests can run on any device
# todo: think about simplification, that the the expected will be just a list use_xxx which shall be true...
def test_trainer_config(trainer_kwargs, expected):
trainer = Trainer(**trainer_kwargs)
assert trainer.use_dp is expected["use_dp"]
assert trainer.use_ddp is expected["use_ddp"]
assert trainer.use_ddp2 is expected["use_ddp2"]
assert trainer.num_gpus == expected["num_gpus"]
assert trainer.on_gpu is expected["on_gpu"]
assert trainer.use_single_gpu is expected["use_single_gpu"]
assert trainer.num_processes == expected["num_processes"]
assert trainer.use_dp is expected["use_dp"], 'for input: %s' % trainer_kwargs
assert trainer.use_ddp is expected["use_ddp"], 'for input: %s' % trainer_kwargs
assert trainer.use_ddp2 is expected["use_ddp2"], 'for input: %s' % trainer_kwargs
assert trainer.num_gpus == expected["num_gpus"], 'for input: %s' % trainer_kwargs
assert trainer.on_gpu is expected["on_gpu"], 'for input: %s' % trainer_kwargs
assert trainer.use_single_gpu is expected["use_single_gpu"], 'for input: %s' % trainer_kwargs
assert trainer.num_processes == expected["num_processes"], 'for input: %s' % trainer_kwargs


def test_trainer_subclassing():
Expand Down

0 comments on commit ce91795

Please sign in to comment.