ref: organize args 3/n (#3449)
williamFalcon authored Sep 10, 2020
1 parent a208d6d commit 3281586
Showing 10 changed files with 131 additions and 83 deletions.
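
As in the earlier parts of this series, the commit moves argument handling out of Trainer.__init__ and into small connector objects, each of which writes its slice of state back onto the trainer from an on_trainer_init method. A minimal standalone sketch of that shape (hypothetical names, not code from this commit):

class BenchmarkConnector:
    """Owns one group of constructor arguments, mirroring the Lightning connectors below."""

    def __init__(self, trainer):
        self.trainer = trainer

    def on_trainer_init(self, benchmark):
        # State lives on the trainer, not on the connector, so the rest of the
        # codebase keeps reading `trainer.benchmark` exactly as before.
        self.trainer.benchmark = benchmark


class MiniTrainer:
    def __init__(self, benchmark=False):
        self.benchmark_connector = BenchmarkConnector(self)
        self.benchmark_connector.on_trainer_init(benchmark)


assert MiniTrainer(benchmark=True).benchmark is True
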
18 changes: 16 additions & 2 deletions pytorch_lightning/accelerators/accelerator_connector.py
@@ -2,7 +2,7 @@
import os
import torch
from pytorch_lightning.utilities import device_parser
from pytorch_lightning.utilities import rank_zero_warn
from pytorch_lightning.utilities import rank_zero_warn, rank_zero_only


class AcceleratorConnector:
@@ -21,8 +21,22 @@ def on_trainer_init(
log_gpu_memory,
sync_batchnorm,
benchmark,
replace_sampler_ddp
replace_sampler_ddp,
deterministic
):
self.trainer.deterministic = deterministic
torch.backends.cudnn.deterministic = self.trainer.deterministic
if self.trainer.deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# init the default rank if exists
# we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
# this way we only show it on rank 0
if 'LOCAL_RANK' in os.environ:
rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

# benchmarking
self.trainer.benchmark = benchmark
torch.backends.cudnn.benchmark = self.trainer.benchmark
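
The deterministic block added here is the same logic removed from Trainer.__init__ in the trainer.py hunk further down, so the user-facing behaviour of the flag should be unchanged. A hedged sketch of its observable effect, assuming this commit's PyTorch Lightning is installed:

import os
import torch
from pytorch_lightning import Trainer

# deterministic=True now reaches these settings via AcceleratorConnector.on_trainer_init.
trainer = Trainer(deterministic=True)
assert torch.backends.cudnn.deterministic is True
assert os.environ.get("HOROVOD_FUSION_THRESHOLD") == "0"  # works around non-deterministic Horovod fusion
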
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/callback_connector.py
@@ -17,8 +17,11 @@ def on_trainer_init(
progress_bar_refresh_rate,
process_position,
default_root_dir,
weights_save_path
weights_save_path,
resume_from_checkpoint
):
self.trainer.resume_from_checkpoint = resume_from_checkpoint

# init folder paths for checkpoint + weights save callbacks
self.trainer._default_root_dir = default_root_dir or os.getcwd()
self.trainer._weights_save_path = weights_save_path or self.trainer._default_root_dir
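
Per the hunk, resume_from_checkpoint is only recorded on the trainer at construction time; nothing is loaded here. A small sketch, assuming this commit's Lightning and a hypothetical checkpoint path:

from pytorch_lightning import Trainer

# "example.ckpt" is a placeholder path; the connector just stores it for later use in the fit path.
trainer = Trainer(resume_from_checkpoint="example.ckpt")
assert trainer.resume_from_checkpoint == "example.ckpt"
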
5 changes: 4 additions & 1 deletion pytorch_lightning/trainer/data_connector.py
@@ -24,7 +24,10 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch):
def on_trainer_init(self, check_val_every_n_epoch, reload_dataloaders_every_epoch, prepare_data_per_node):
self.trainer.datamodule = None
self.trainer.prepare_data_per_node = prepare_data_per_node

self.trainer.check_val_every_n_epoch = check_val_every_n_epoch
self.trainer.reload_dataloaders_every_epoch = reload_dataloaders_every_epoch
self.trainer._is_data_prepared = False
15 changes: 15 additions & 0 deletions pytorch_lightning/trainer/evaluation_loop.py
@@ -26,6 +26,21 @@ def __init__(self, trainer):
self.predictions = None
self.max_batches = None

def on_trainer_init(self):
self.trainer.num_val_batches = []
self.trainer.num_sanity_val_batches = []
self.trainer.num_test_batches = []
self.trainer.test_dataloaders = None
self.trainer.val_dataloaders = None
self.trainer.running_sanity_check = False
self.trainer.testing = False

# when .test() is called, it sets this
self.trainer.tested_ckpt_path = None

# when true, prints test results
self.trainer.verbose_test = True

def get_evaluation_dataloaders(self, max_batches):
# select dataloaders
model = self.trainer.get_model()
pytorch_lightning/trainer/lr_scheduler_connector.py → pytorch_lightning/trainer/optimizer_connector.py (renamed)
@@ -15,11 +15,16 @@
from pytorch_lightning.utilities.exceptions import MisconfigurationException


class LRSchedulerConnector:
class OptimizerConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self):
self.trainer.lr_schedulers = []
self.trainer.optimizers = None
self.trainer.optimizer_frequencies = []

def update_learning_rates(self, interval: str, monitor_metrics=None):
"""Update learning rates.
@@ -52,14 +57,12 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
f' which is not available. Available metrics are: {avail_metrics}.'
' Condition can be set using `monitor` key in lr scheduler dict'
)
if self.trainer.dev_debugger.enabled:
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
lr_scheduler['scheduler'].step(monitor_val)
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
@@ -69,14 +72,12 @@
monitor_key,
)
else:
if self.trainer.dev_debugger.enabled:
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

# update LR
old_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
lr_scheduler['scheduler'].step()
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']

if self.trainer.dev_debugger.enabled:
new_lr = lr_scheduler['scheduler'].optimizer.param_groups[0]['lr']
self.trainer.dev_debugger.track_lr_schedulers_update(
self.trainer.batch_idx,
interval,
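
The reshuffling above only moves the old_lr/new_lr reads out of the dev_debugger guard; the two branches still follow PyTorch's scheduler API, where a scheduler keyed on a monitored metric is stepped with that value and every other scheduler is stepped with no argument. A standalone sketch in plain PyTorch (no Lightning):

import torch
from torch.optim import SGD
from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR

param = torch.nn.Parameter(torch.zeros(1))

plateau_opt = SGD([param], lr=0.1)
plateau = ReduceLROnPlateau(plateau_opt, mode="min", patience=0)
plateau.step(0.5)   # metric-driven, like lr_scheduler['scheduler'].step(monitor_val) above

step_opt = SGD([param], lr=0.1)
step_sched = StepLR(step_opt, step_size=1)
step_opt.step()
step_sched.step()   # interval-driven, like the plain .step() in the else branch
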
26 changes: 26 additions & 0 deletions pytorch_lightning/trainer/profiler_connector.py
@@ -0,0 +1,26 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License
from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler


class ProfilerConnector:

def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, profiler):
# configure profiler
if profiler is True:
profiler = SimpleProfiler()
self.trainer.profiler = profiler or PassThroughProfiler()
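
The new connector reproduces the resolution that the trainer.py hunk below removes from __init__: profiler=True selects a SimpleProfiler, a falsy value falls back to a PassThroughProfiler, and an explicit profiler instance is used as given. A sketch, assuming this commit's Lightning:

from pytorch_lightning import Trainer
from pytorch_lightning.profiler import PassThroughProfiler, SimpleProfiler

assert isinstance(Trainer(profiler=True).profiler, SimpleProfiler)
assert isinstance(Trainer().profiler, PassThroughProfiler)            # default profiler=None
assert isinstance(Trainer(profiler=SimpleProfiler()).profiler, SimpleProfiler)
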
94 changes: 27 additions & 67 deletions pytorch_lightning/trainer/trainer.py
@@ -45,14 +45,15 @@
from pytorch_lightning.trainer.training_loop import TrainLoop
from pytorch_lightning.accelerators.accelerator_connector import AcceleratorConnector
from pytorch_lightning.trainer.logger_connector import LoggerConnector
from pytorch_lightning.trainer.lr_scheduler_connector import LRSchedulerConnector
from pytorch_lightning.trainer.optimizer_connector import OptimizerConnector
from pytorch_lightning.trainer.training_trick_connector import TrainingTricksConnector
from pytorch_lightning.trainer.callback_connector import CallbackConnector
from pytorch_lightning.trainer.model_connector import ModelConnector
from pytorch_lightning.trainer.debugging_connector import DebuggingConnector
from pytorch_lightning import _logger as log
from pytorch_lightning.tuner.tuning import Tuner
from pytorch_lightning.trainer.precision_connector import PrecisionConnector
from pytorch_lightning.trainer.profiler_connector import ProfilerConnector
from pytorch_lightning.trainer.data_connector import DataConnector
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer import docstrings
@@ -154,71 +155,28 @@ def __init__(
):
super().__init__()

self.deterministic = deterministic
torch.backends.cudnn.deterministic = self.deterministic
if self.deterministic:
# fixing non-deterministic part of horovod
# https://github.com/PyTorchLightning/pytorch-lightning/pull/1572/files#r420279383
os.environ["HOROVOD_FUSION_THRESHOLD"] = str(0)

# init the default rank if exists
# we need to call this here or NVIDIA flags and other messaging in init will show on all ranks
# this way we only show it on rank 0
if 'LOCAL_RANK' in os.environ:
rank_zero_only.rank = int(os.environ['LOCAL_RANK'])

# tracks internal state for debugging
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.lr_scheduler_connector = LRSchedulerConnector(self)
self.optimizer_connector = OptimizerConnector(self)
self.accelerator_connector = AcceleratorConnector(self)
self.logger_connector = LoggerConnector(self)
self.model_connector = ModelConnector(self)
self.precision_connector = PrecisionConnector(self)
self.callback_connector = CallbackConnector(self)
self.debugging_connector = DebuggingConnector(self)
self.training_tricks_connector = TrainingTricksConnector(self)

self.profile_connector = ProfilerConnector(self)
self.tuner = Tuner(self)
self.accelerator_backend = None

# loops
self.evaluation_loop = EvaluationLoop(self)
self.train_loop = TrainLoop(self)

# training bookeeping
self.total_batch_idx = 0
self.batch_idx = 0
self.num_training_batches = 0
self.num_val_batches = []
self.num_sanity_val_batches = []
self.num_test_batches = []
self.train_dataloader = None
self.test_dataloaders = None
self.val_dataloaders = None

# when true, prints test results
self.verbose_test = True

# when .test() is called, it sets this
self.tested_ckpt_path = None

# training state
self.weights_summary = weights_summary
self.model = None
self.datamodule = None
self.testing = False
self.prepare_data_per_node = prepare_data_per_node
self.lr_schedulers = []
self.optimizers = None
self.optimizer_frequencies = []
self.global_step = 0
self.current_epoch = 0
self.interrupted = False
self.should_stop = False
self.running_sanity_check = False
self._state = TrainerState.INITIALIZING
self.shown_warnings = set()

# init callbacks
self.callback_connector.on_trainer_init(
@@ -229,20 +187,29 @@ def __init__(
process_position,
default_root_dir,
weights_save_path,
resume_from_checkpoint
)

# init data flags
self.data_connector.on_trainer_init(check_val_every_n_epoch, reload_dataloaders_every_epoch)

# hook
self.on_init_start()

# init optimizer + lr scheduler related flags
self.optimizer_connector.on_trainer_init()

# init data flags
self.data_connector.on_trainer_init(
check_val_every_n_epoch,
reload_dataloaders_every_epoch,
prepare_data_per_node
)

# init training tricks
self.training_tricks_connector.on_trainer_init(
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps
truncated_bptt_steps,
terminate_on_nan
)

# init accelerator related flags
@@ -256,23 +223,19 @@ def __init__(
log_gpu_memory,
sync_batchnorm,
benchmark,
replace_sampler_ddp
replace_sampler_ddp,
deterministic
)

# init train loop related flags
self.train_loop.on_init_start(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.train_loop.on_trainer_init(max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.evaluation_loop.on_trainer_init()

self.auto_lr_find = auto_lr_find
self.auto_scale_batch_size = auto_scale_batch_size

self.resume_from_checkpoint = resume_from_checkpoint
self.terminate_on_nan = terminate_on_nan
self.shown_warnings = set()
# configure tuner
self.tuner.on_trainer_init(auto_lr_find, auto_scale_batch_size)

# configure profiler
if profiler is True:
profiler = SimpleProfiler()
self.profiler = profiler or PassThroughProfiler()
self.profile_connector.on_trainer_init(profiler)

# init logger flags
self.logger_connector.on_trainer_init(logger, log_save_interval, row_log_interval)
@@ -309,9 +272,6 @@ def tune(
# setup data, etc...
self.setup_fit(model, train_dataloader, val_dataloaders, datamodule)

# hook
self.call_hook('on_fit_start', model)

# hook
self.data_connector.prepare_data(model)

@@ -502,7 +462,7 @@ def train(self):
return

# update LR schedulers
self.lr_scheduler_connector.update_learning_rates(interval='epoch')
self.optimizer_connector.update_learning_rates(interval='epoch')

# early stopping
met_min_epochs = epoch >= self.min_epochs - 1
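
None of the relocations above change the public constructor: the same keyword arguments are accepted and are simply routed to their connectors during __init__. A hedged usage sketch, assuming this commit's Lightning:

from pytorch_lightning import Trainer

trainer = Trainer(
    deterministic=True,           # -> AcceleratorConnector
    resume_from_checkpoint=None,  # -> CallbackConnector
    prepare_data_per_node=True,   # -> DataConnector
    terminate_on_nan=False,       # -> TrainingTricksConnector
    profiler=True,                # -> ProfilerConnector
)
assert trainer.resume_from_checkpoint is None
assert trainer.terminate_on_nan is False
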
16 changes: 14 additions & 2 deletions pytorch_lightning/trainer/training_loop.py
@@ -25,6 +25,7 @@
from pytorch_lightning.core.step_result import EvalResult, Result
from pytorch_lightning.utilities.parsing import AttributeDict
from copy import copy, deepcopy
from pytorch_lightning.trainer.states import TrainerState


class TrainLoop:
@@ -38,7 +39,18 @@ def __init__(self, trainer):
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)

def on_init_start(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
def on_trainer_init(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps):
self.trainer.global_step = 0
self.trainer.current_epoch = 0
self.trainer.interrupted = False
self.trainer.should_stop = False
self.trainer._state = TrainerState.INITIALIZING

self.trainer.total_batch_idx = 0
self.trainer.batch_idx = 0
self.trainer.num_training_batches = 0
self.trainer.train_dataloader = None

self.trainer.max_epochs = max_epochs
self.trainer.min_epochs = min_epochs
self.trainer.max_steps = max_steps
@@ -560,7 +572,7 @@ def update_train_loop_lr_schedulers(self, monitor_metrics=None):

if num_accumulated_batches_reached or num_training_batches_reached:
# update lr
self.trainer.lr_scheduler_connector.update_learning_rates(interval='step', monitor_metrics=monitor_metrics)
self.trainer.optimizer_connector.update_learning_rates(interval='step', monitor_metrics=monitor_metrics)

def run_on_epoch_end_hook(self):
self.trainer.call_hook('on_epoch_end')
12 changes: 11 additions & 1 deletion pytorch_lightning/trainer/training_trick_connector.py
@@ -20,7 +20,17 @@ class TrainingTricksConnector:
def __init__(self, trainer):
self.trainer = trainer

def on_trainer_init(self, gradient_clip_val, track_grad_norm, accumulate_grad_batches, truncated_bptt_steps):
def on_trainer_init(
self,
gradient_clip_val,
track_grad_norm,
accumulate_grad_batches,
truncated_bptt_steps,
terminate_on_nan
):

self.trainer.terminate_on_nan = terminate_on_nan

# gradient clipping
self.trainer.gradient_clip_val = gradient_clip_val

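
As with the other connectors, the training-trick arguments end up as plain attributes on the trainer, so existing call sites such as the NaN check and gradient clipping should keep reading them unchanged. A small sketch, assuming this commit's Lightning:

from pytorch_lightning import Trainer

trainer = Trainer(terminate_on_nan=True, gradient_clip_val=0.5)
assert trainer.terminate_on_nan is True
assert trainer.gradient_clip_val == 0.5
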
[1 more changed file not shown in this view]