From 5d239ccd704c7b639b6e83fdff9460fa9dd5e790 Mon Sep 17 00:00:00 2001 From: Justus Schock <12886177+justusschock@users.noreply.github.com> Date: Sat, 30 Jan 2021 20:55:28 +0100 Subject: [PATCH] Base classes for accelerator refactoring (#5715) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * add basic accelerator class. Co-Authored with @awaelchi * Add base plugin class. Co-authored with @awaelchi * add basic trainign type plugin. Co-Authored with @awaelchi * add basic precision plugin. Co-Authored with @awaelchi * Add missing inits. Co-authored with @awaelchi * pep8 Co-authored-by: @awaelchi * ignore flake8 * coverage omit * imports in init * lost * imports * flake8 * . * . * chlog * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec * Update pytorch_lightning/plugins/training_type/training_type_plugin.py Co-authored-by: Jirka Borovec Co-authored-by: Adrian Wälchli Co-authored-by: Jirka Borovec Co-authored-by: Jirka Borovec --- CHANGELOG.md | 5 + pytorch_lightning/accelerators/accelerator.py | 375 ++++++++++++++++++ pytorch_lightning/plugins/__init__.py | 3 + pytorch_lightning/plugins/base_plugin.py | 57 +++ .../plugins/precision/__init__.py | 1 + .../plugins/precision/precision_plugin.py | 108 +++++ .../plugins/training_type/__init__.py | 1 + .../training_type/training_type_plugin.py | 112 ++++++ setup.cfg | 17 +- 9 files changed, 678 insertions(+), 1 deletion(-) create mode 100644 pytorch_lightning/accelerators/accelerator.py create mode 100644 pytorch_lightning/plugins/base_plugin.py create mode 100644 pytorch_lightning/plugins/precision/__init__.py create mode 100644 pytorch_lightning/plugins/precision/precision_plugin.py create mode 100644 pytorch_lightning/plugins/training_type/__init__.py create mode 100644 pytorch_lightning/plugins/training_type/training_type_plugin.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 1e56ed79b9525..a041defcd0028 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -107,6 +107,11 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Changed the default value for the `progress_bar_refresh_rate` Trainer argument in Google COLAB notebooks to 20 ([#5516](https://github.com/PyTorchLightning/pytorch-lightning/pull/5516)) +- Refactored Accelerators and Plugins ( + [#5715](https://github.com/PyTorchLightning/pytorch-lightning/pull/5715), + ) + + ### Deprecated - Function `stat_scores_multiple_classes` is deprecated in favor of `stat_scores` ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839)) diff --git a/pytorch_lightning/accelerators/accelerator.py b/pytorch_lightning/accelerators/accelerator.py new file mode 100644 index 0000000000000..c5c77d4711e6a --- /dev/null +++ b/pytorch_lightning/accelerators/accelerator.py @@ -0,0 +1,375 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Callable, Iterable, Optional, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins import TrainingTypePlugin +from pytorch_lightning.utilities.apply_func import move_data_to_device +from pytorch_lightning.utilities.enums import LightningEnum + + +class Accelerator(object): + """ + The Accelerator Base Class. + An Accelerator is meant to deal with one type of Hardware. + + Currently there are accelerators for: + - CPU + - GPU + - TPU + + Each Accelerator gets two plugins upon initialization: + One to handle differences from the training routine and one to handle different precisions. + + """ + + def __init__( + self, + precision_plugin, #: PrecisionPlugin # fixme + training_type_plugin: TrainingTypePlugin, + ) -> None: + """ + + Args: + precision_plugin: the plugin to handle precision-specific parts + training_type_plugin: the plugin to handle different training routines + """ + self.precision_plugin = precision_plugin + self.training_type_plugin = training_type_plugin + + self.optimizers = None + self.lr_schedulers = None + self.optimizer_frequencies = None + + def setup(self, trainer: "Trainer", model: LightningModule) -> None: + """ + Connects the plugins to the training process, creates optimizers + + Args: + trainer: the trainer instance to connect to + model: the model to train + """ + self.connect_training_type_plugin(self.training_type_plugin, model) + self.setup_optimizers(trainer, model) + self.connect_precision_plugin(self.precision_plugin) + self.optimizers = trainer.convert_to_lightning_optimizers(self.optimizers) + + @property + def model(self) -> torch.nn.Module: + """Returns the model. This can also be a wrapped LightningModule. + For retrieving the pure LightningModule use :attr:`Accelerator.lightning_module` + + """ + return self.training_type_plugin.model + + @model.setter + def model(self, new_model: torch.nn.Module) -> None: + self.training_type_plugin.model = new_model + + @property + def lightning_module(self) -> LightningModule: + """Returns the pure LightningModule. + To get the potentially wrapped model use :attr:`Accelerator.model` + + """ + return self.training_type_plugin.lightning_module + + @property + def root_device(self) -> torch.device: + return self.training_type_plugin.root_device + + def teardown(self): + """This method is called to teardown the training process. + It is the right place to release memory and free other ressources. + """ + pass + + def batch_to_device(self, batch: Any, device: torch.device) -> Any: + """Moves the batch to the correct device. + The returned batch is of the same type as the input batch, just having all tensors on the correct device. + + Args: + batch: The batch of samples to move to the correct device + device: The target device + """ + model = self.lightning_module + if model is not None: + return model.transfer_batch_to_device(batch, device) + return move_data_to_device(batch, device) + + def on_train_start(self): + """Hook to do something upon the training start""" + pass + + def training_step(self, args): + """The actual training step. + + Args: + args: the arguments for the models training step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): Integer displaying index of this batch + optimizer_idx (int): When using multiple optimizers, this argument will also be present. + hiddens(:class:`~torch.Tensor`): Passed in if + :paramref:`~pytorch_lightning.trainer.trainer.Trainer.truncated_bptt_steps` > 0. + + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.train_step_context(): + with self.training_type_plugin.train_step_context(): + return self.lightning_module.training_step(*args) + + def validation_step(self, args): + """The actual validation step. + + Args: + args: the arguments for the models validation step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple val dataloaders used) + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.val_step_context(): + with self.training_type_plugin.val_step_context(): + return self.lightning_module.validation_step(*args) + + def test_step(self, args): + """The actual test step. + + Args: + args: the arguments for the models test step. Can consist of the following: + batch (:class:`~torch.Tensor` | (:class:`~torch.Tensor`, ...) | [:class:`~torch.Tensor`, ...]): + The output of your :class:`~torch.utils.data.DataLoader`. A tensor, tuple or list. + batch_idx (int): The index of this batch. + dataloader_idx (int): The index of the dataloader that produced this batch + (only if multiple test dataloaders used). + """ + batch = self.to_device(args[0]) + + args[0] = batch + + with self.precision_plugin.test_step_context(): + with self.training_type_plugin.test_step_context(): + return self.lightning_module.test_step(*args) + + def training_step_end(self, output): + """A hook to do something at the end of the training step + + Args: + output: the output of the training step + """ + return output + + def test_step_end(self, output): + """A hook to do something at the end of the test step + + Args: + output: the output of the test step + """ + return output + + def validation_step_end(self, output): + """A hook to do something at the end of the validation step + + Args: + output: the output of the validation step + """ + return output + + def process_dataloader( + self, dataloader: Union[Iterable, torch.utils.data.DataLoader] + ) -> Union[Iterable, torch.utils.data.DataLoader]: + """Wraps the dataloader if necessary + + Args: + dataloader: iterable. Ideally of type: :class:`torch.utils.data.DataLoader` + """ + return dataloader + + def backward( + self, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ) -> torch.Tensor: + """Forwards backward-calls to the precision plugin. + + Args: + closure_loss: a tensor holding the loss value to backpropagate + optimizer: the optimizer to do the step later on. + opt_idx: the index of the optimizer + should_accumulate: whether to accumulate gradients + """ + output = self.precision_plugin.backward( + self.lightning_module, closure_loss, optimizer, opt_idx, should_accumulate, *args, **kwargs + ) + + # TODO: this is a hack, find a better solution for this (hook?) + # fixme: uncomment when this class is added + # if isinstance(self.training_type_plugin, HorovodPlugin): + # optimizer.synchronize() + + return output + + def optimizer_step( + self, + optimizer: torch.optim.Optimizer, + current_epoch: int, + batch_idx: int, + opt_idx: int, + lambda_closure: Callable, + ): + """performs the actual optimizer step. + + Args: + optimizer: the optimizer performing the step + current_epoch: current training epoch + batch_idx: index of the current batch + opt_idx: index of the current optimizer + lambda_closure: closure calculating the loss value + + """ + model_ref = self.lightning_module + is_lbfgs = isinstance(optimizer, torch.optim.LBFGS) + # fixme: uncomment when this class is added + # is_native_amp = ( + # isinstance(self.precision_plugin, MixedPrecisionPlugin) and self.precision_plugin.backend == AMPType.NATIVE + # ) + is_native_amp = False + + self.precision_plugin.pre_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.pre_optimizer_step(optimizer, opt_idx) + + # model hook + res = model_ref.optimizer_step( + epoch=current_epoch, + batch_idx=batch_idx, + optimizer=optimizer, + optimizer_idx=opt_idx, + optimizer_closure=lambda_closure, + on_tpu=False, # TPUAccelerator class sets this as True + using_native_amp=is_native_amp, + using_lbfgs=is_lbfgs, + ) + + self.precision_plugin.post_optimizer_step(optimizer, opt_idx) + self.training_type_plugin.post_optimizer_step(optimizer, opt_idx) + return res + + def optimizer_zero_grad( + self, current_epoch: int, batch_idx: int, optimizer: torch.optim.Optimizer, opt_idx: int + ) -> None: + """Zeros all model parameter's gradients""" + model_ref = self.lightning_module + model_ref.optimizer_zero_grad(current_epoch, batch_idx, optimizer, opt_idx) + + def clip_gradients(self, optimizer: torch.optim.Optimizer, clip_val: Union[int, float]) -> None: + """clips all the optimizer parameters to the given value""" + + self.precision_plugin.clip_gradients(optimizer, clip_val) + + def on_train_epoch_end(self, outputs) -> None: + """Hook to do something on the end of an training epoch + + Args: + outputs: the outputs of the training steps + """ + pass + + def on_train_end(self) -> None: + """Hook to do something at the end of the training""" + pass + + def setup_optimizers(self, trainer: "Trainer", model: LightningModule): + """creates optimizers and schedulers + + Args: + trainer: the Trainer, these optimizers should be connected to + model: the model to be optimized by the created optimizers + """ + if trainer.testing is True: + return + optimizers, lr_schedulers, optimizer_frequencies = trainer.init_optimizers(model) + self.optimizers = optimizers + self.lr_schedulers = lr_schedulers + self.optimizer_frequencies = optimizer_frequencies + + def connect_training_type_plugin(self, plugin: TrainingTypePlugin, model: LightningModule) -> None: + """Attaches the training type plugin to the accelerator. + Also transfers ownership of the model to this plugin + + """ + plugin.connect(model) + + def connect_precision_plugin(self, plugin): #: PrecisionPlugin # fixme + """Attaches the precision plugin to the accelerator""" + model, optimizers, schedulers = plugin.connect(self.model, self.optimizers, self.lr_schedulers) + self.model = model + self.optimizers = optimizers + self.schedulers = schedulers + + def to_device(self, batch: Any) -> Any: + """Pushes the batch to the root device""" + return self.batch_to_device(batch, self.root_device) + + @property + def amp_backend(self) -> Optional[LightningEnum]: + # fixme: uncomment when this class is added + # if isinstance(self.precision_plugin, ApexMixedPrecisionPlugin): + # return AMPType.APEX + # elif isinstance(self.precision_plugin, NativeMixedPrecisionPlugin): + # return AMPType.NATIVE + # return None + pass + + @property + def precision(self) -> int: + return self.precision_plugin.precision + + @property + def scaler(self): + if hasattr(self.precision_plugin, "scaler"): + return self.precision_plugin.scaler + + return None + + @property + def rpc_enabled(self) -> bool: + return self.training_type_plugin.rpc_enabled + + def optimizer_state(self, optimizer: Optimizer) -> dict: + """ + Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom + plugins. + """ + if self.training_type_plugin and hasattr(self.training_type_plugin, "optimizer_state"): + return self.training_type_plugin.optimizer_state(optimizer) + return optimizer.state_dict() + + def on_save(self, checkpoint): + return checkpoint \ No newline at end of file diff --git a/pytorch_lightning/plugins/__init__.py b/pytorch_lightning/plugins/__init__.py index e69de29bb2d1d..a17d5127edfc6 100644 --- a/pytorch_lightning/plugins/__init__.py +++ b/pytorch_lightning/plugins/__init__.py @@ -0,0 +1,3 @@ +from pytorch_lightning.plugins.base_plugin import Plugin # noqa: F401 +from pytorch_lightning.plugins.precision.precision_plugin import PrecisionPlugin # noqa: F401 +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin # noqa: F401 diff --git a/pytorch_lightning/plugins/base_plugin.py b/pytorch_lightning/plugins/base_plugin.py new file mode 100644 index 0000000000000..c4eeff52751a6 --- /dev/null +++ b/pytorch_lightning/plugins/base_plugin.py @@ -0,0 +1,57 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import contextlib + +import torch + + +class Plugin(object): + """Basic Plugin class to derive precision and training type plugins from.""" + + def connect(self, model: torch.nn.Module, *args, **kwargs): + """Connects the plugin with the accelerator (and thereby with trainer and model). + Will be called by the accelerator. + """ + pass + + def pre_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int): + """Hook to do something before each optimizer step.""" + pass + + def post_optimizer_step(self, optimizer: torch.optim.Optimizer, optimizer_idx: int): + """Hook to do something after each optimizer step.""" + pass + + def pre_training(self): + """Hook to do something before the training starts.""" + pass + + def post_training(self): + """Hook to do something after the training finishes.""" + pass + + @contextlib.contextmanager + def train_step_context(self): + """A contextmanager for the trainstep""" + yield + + @contextlib.contextmanager + def val_step_context(self): + """A contextmanager for the validation step""" + yield + + @contextlib.contextmanager + def test_step_context(self): + """A contextmanager for the teststep""" + yield \ No newline at end of file diff --git a/pytorch_lightning/plugins/precision/__init__.py b/pytorch_lightning/plugins/precision/__init__.py new file mode 100644 index 0000000000000..8b137891791fe --- /dev/null +++ b/pytorch_lightning/plugins/precision/__init__.py @@ -0,0 +1 @@ + diff --git a/pytorch_lightning/plugins/precision/precision_plugin.py b/pytorch_lightning/plugins/precision/precision_plugin.py new file mode 100644 index 0000000000000..0ff54bf1e8515 --- /dev/null +++ b/pytorch_lightning/plugins/precision/precision_plugin.py @@ -0,0 +1,108 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from typing import Generator, Union + +import torch +from torch.optim import Optimizer + +from pytorch_lightning.core import LightningModule +from pytorch_lightning.plugins.base_plugin import Plugin + + +class PrecisionPlugin(Plugin): + EPSILON = 1e-6 + precision = 32 + + def master_params(self, optimizer: torch.optim.Optimizer) -> Generator[torch.Tensor, None, None]: + """The master params of the model. Returns the plain model params here. + Maybe different in other precision plugins. + + """ + for group in optimizer.param_groups: + for p in group["params"]: + yield p + + def connect(self, model: torch.nn.Module, optimizers, lr_schedulers): + """Connects this plugin to the accelerator and the training process""" + return model, optimizers, lr_schedulers + + def backward( + self, + model: LightningModule, + closure_loss: torch.Tensor, + optimizer: torch.optim.Optimizer, + opt_idx: int, + should_accumulate: bool, + *args, + **kwargs, + ): + """performs the actual backpropagation + + Args: + model: the model to be optimized + closure_loss: the loss value obtained from the closure + optimizer: the optimizer to perform the step lateron + opt_idx: the optimizer's index + should_accumulate: whether to accumulate gradients or not + + """ + automatic_optimization = model.automatic_optimization + + # do backward pass + if automatic_optimization: + model.backward(closure_loss, optimizer, opt_idx) + else: + closure_loss.backward(*args, **kwargs) + + # once backward has been applied, release graph + closure_loss = closure_loss.detach() + + return closure_loss + + def clip_gradients(self, optimizer: Optimizer, clip_val: Union[int, float], norm_type: float = float(2.0)): + """Clips the gradients to a specific value""" + # TODO: separate TPU case from here + if clip_val is None: + return + + grad_clip_val = float(clip_val) + + if grad_clip_val <= 0: + return + + parameters = self.master_params(optimizer) + + max_norm = grad_clip_val + + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + parameters = list(filter(lambda p: p.grad is not None, parameters)) + + device = parameters[0].device + + if norm_type == math.inf: + total_norm = max(p.grad.data.abs().max() for p in parameters) + else: + out = torch.empty(len(parameters), device=device) + for i, p in enumerate(parameters): + torch.norm(p.grad.data.to(device), norm_type, out=out[i]) + total_norm = torch.norm(out, norm_type) + + eps = self.EPSILON + + clip_coef = torch.tensor(max_norm, device=device) / (total_norm + eps) + clip_coef = torch.min(clip_coef, torch.ones_like(clip_coef)) + for p in parameters: + p.grad.data.mul_(clip_coef.to(p.grad.data.device)) diff --git a/pytorch_lightning/plugins/training_type/__init__.py b/pytorch_lightning/plugins/training_type/__init__.py new file mode 100644 index 0000000000000..329f6347b17c3 --- /dev/null +++ b/pytorch_lightning/plugins/training_type/__init__.py @@ -0,0 +1 @@ +from pytorch_lightning.plugins.training_type.training_type_plugin import TrainingTypePlugin diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py new file mode 100644 index 0000000000000..d1e7907d5d97f --- /dev/null +++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py @@ -0,0 +1,112 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from abc import ABC, abstractmethod +from typing import Optional + +import torch + +from pytorch_lightning import _logger as log +from pytorch_lightning.core.lightning import LightningModule +from pytorch_lightning.plugins.base_plugin import Plugin + + +class TrainingTypePlugin(Plugin, ABC): + """A Plugin to change the behaviour of the training, validation and test-loop.""" + + def __init__(self): + self._model = None + self._results = None + self.global_rank = 0 + + @property + @abstractmethod + def on_gpu(self) -> bool: + """Returns whether the current process is done on GPU""" + + @property + @abstractmethod + def root_device(self) -> torch.device: + """Returns the root device""" + + @abstractmethod + def model_to_device(self): + """Moves the model to the correct device""" + + @property + @abstractmethod + def is_global_zero(self) -> bool: + """Whether the current process is the rank zero process not only on the local node, but for all nodes.""" + + @abstractmethod + def reduce(self, output, *args, **kwargs): + """Reduces the given output (e.g. across GPUs/Processes)""" + + @abstractmethod + def barrier(self, name: Optional[str] = None): + """Forces all possibly joined processes to wait for each other""" + + @abstractmethod + def broadcast(self, obj: object, src: int = 0) -> object: + """Broadcasts an object to all processes""" + + # TODO method this is currently unused. Check after complete refactors are pushed + def set_nvidia_flags(self, is_slurm_managing_tasks, device_ids): + if device_ids is None: + return + + # set the correct cuda visible devices (using pci order) + os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID" + all_gpu_ids = ",".join([str(x) for x in range(torch.cuda.device_count())]) + devices = os.environ.get("CUDA_VISIBLE_DEVICES", all_gpu_ids) + log.info(f"LOCAL_RANK: {self.trainer.local_rank} - CUDA_VISIBLE_DEVICES: [{devices}]") + + def reduce_early_stopping_decision(self, should_stop: bool) -> bool: + """Reduce the early stopping decision across all possibly spawned processes""" + return should_stop + + @property + def model(self) -> torch.nn.Module: + """Returns the potentially wrapped LightningModule""" + return self._model + + @model.setter + def model(self, new_model: torch.nn.Module): + self._model = new_model + + @property + def lightning_module(self) -> LightningModule: + """Returns the pure LightningModule without potential wrappers""" + return self._model + + @property + def results(self): + """ + The results of the last training/testing run will be cached here. + In distributed training, we make sure to transfer the results to the appropriate master process. + """ + # TODO: improve these docs + return self._results + + @property + def rpc_enabled(self) -> bool: + return False + + def start_training(self, trainer: "Trainer") -> None: + # double dispatch to initiate the training loop + self._results = trainer.train() + + def start_testing(self, trainer: "Trainer") -> None: + # double dispatch to initiate the test loop + self._results = trainer.run_test() diff --git a/setup.cfg b/setup.cfg index 20a53751e1dc2..dc351647a9986 100644 --- a/setup.cfg +++ b/setup.cfg @@ -53,12 +53,27 @@ omit = pytorch_lightning/utilities/xla_device_utils.py pytorch_lightning/utilities/distributed.py pytorch_lightning/tuner/auto_gpu_select.py + # TODO: temporary, until accelerator refactor is finished + pytorch_lightning/accelerators/accelerator.py + pytorch_lightning/plugins/training_type/*.py + pytorch_lightning/plugins/precision/*.py + pytorch_lightning/plugins/base_plugin.py [flake8] # TODO: this should be 88 or 100 according PEP8 max-line-length = 120 -exclude = .tox,*.egg,build,temp +exclude = + .tox, + *.egg + build + temp + # TODO: temporary until accelerator refactor finished + pytorch_lightning/accelerators/accelerator.py + pytorch_lightning/plugins/training_type + pytorch_lightning/plugins/precision + pytorch_lightning/plugins/base_plugin.py + select = E,W,F doctests = True verbose = 2