remove trainer hidden state | sanity refactor [2 / n] #7507

Merged (12 commits) on May 17, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -30,6 +30,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Refactored Loops
* Moved attributes `global_step`, `current_epoch`, `max/min_steps`, `max/min_epochs`, `batch_idx`, and `total_batch_idx` to TrainLoop ([#7437](https://github.com/PyTorchLightning/pytorch-lightning/pull/7025))
* Refactored result handling in training loop ([#7506](https://github.com/PyTorchLightning/pytorch-lightning/pull/7506))
* Moved attributes `hiddens` and `split_idx` to TrainLoop ([#7507](https://github.com/PyTorchLightning/pytorch-lightning/pull/7507))

- `DataModule`s now avoid duplicate `{setup,teardown,prepare_data}` calls for the same stage ([#7238](https://github.com/PyTorchLightning/pytorch-lightning/pull/7238))

2 changes: 1 addition & 1 deletion pytorch_lightning/core/lightning.py
@@ -1619,7 +1619,7 @@ def get_progress_bar_dict(self):
module_tbptt_enabled = self.truncated_bptt_steps > 0
trainer_tbptt_enabled = self.trainer.truncated_bptt_steps is not None and self.trainer.truncated_bptt_steps > 0
if module_tbptt_enabled or trainer_tbptt_enabled:
tqdm_dict["split_idx"] = self.trainer.split_idx
tqdm_dict["split_idx"] = self.trainer.train_loop.split_idx
Contributor: Is `get_progress_bar_dict` still needed on the LightningModule? Wouldn't things come from `self.log`?

Contributor Author: It's for users to override and customize the default elements in the progress bar, like how the version number is displayed or, apparently, the split index here.

Contributor Author: It's not something one could customize through `self.log` directly, I would say.

Contributor: Is it customization we would want to push to a custom progress bar callback instead of being part of the core module?

Contributor (@carmocca, May 12, 2021):

> custom progress bar callback instead of being part of the core module?

IMO yes, for maximum separation of concerns. But people might complain about having to define a custom progress bar just for this.

Contributor Author: One could argue that maybe the split_idx is not very useful to display in the progress bar, but I would still keep the hook.
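For readers skimming the thread: `get_progress_bar_dict` is the LightningModule hook under discussion. A minimal sketch of the kind of user-side override it enables is below; the extra metric name and attribute are hypothetical, while the default keys (`v_num`, and `split_idx` when truncated BPTT is active) come from the implementation shown in this diff.

```python
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def get_progress_bar_dict(self):
        # start from the defaults (e.g. "v_num", and "split_idx" when TBPTT is enabled)
        items = super().get_progress_bar_dict()
        items.pop("v_num", None)  # hide the logger version number
        items["my_metric"] = f"{getattr(self, 'some_tracked_value', 0.0):.3f}"  # hypothetical value
        return items
```

The alternative floated above would push this kind of customization into a progress bar callback instead; as noted, that separates concerns more cleanly but forces users to define a custom progress bar for small tweaks.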


if self.trainer.logger is not None and self.trainer.logger.version is not None:
version = self.trainer.logger.version
3 changes: 2 additions & 1 deletion pytorch_lightning/trainer/connectors/data_connector.py
@@ -24,8 +24,9 @@

class DataConnector(object):

def __init__(self, trainer):
def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"):
self.trainer = trainer
self.multiple_trainloader_mode = multiple_trainloader_mode

def on_trainer_init(
self, check_val_every_n_epoch: int, reload_dataloaders_every_epoch: bool, prepare_data_per_node: bool
@@ -126,7 +126,6 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_ste
self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
self.trainer.log_every_n_steps = log_every_n_steps
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
self.trainer.split_idx = None

@property
def should_flush_logs(self):
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/data_loading.py
@@ -259,7 +259,7 @@ def reset_train_dataloader(self, model: LightningModule) -> None:
apply_to_collection(self.train_dataloader, DataLoader, self.auto_add_worker_init_fn)

# wrap the sequence of train loaders to a CombinedLoader object for computing the num_training_batches
self.train_dataloader = CombinedLoader(self.train_dataloader, self._multiple_trainloader_mode)
self.train_dataloader = CombinedLoader(self.train_dataloader, self.data_connector.multiple_trainloader_mode)

self.num_training_batches = len(self.train_dataloader) if has_len(self.train_dataloader) else float('inf')
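For context on this change: `multiple_trainloader_mode` controls how `CombinedLoader` cycles multiple training dataloaders ("max_size_cycle", the default, keeps cycling the shorter loaders until the longest is exhausted; "min_size" stops at the shortest). A rough sketch of the user-facing side, with hypothetical datasets and sizes:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset
import pytorch_lightning as pl


class MultiLoaderModule(pl.LightningModule):
    def train_dataloader(self):
        # two loaders of different (hypothetical) lengths; each training batch becomes {"a": ..., "b": ...}
        loader_a = DataLoader(TensorDataset(torch.randn(64, 10)), batch_size=8)
        loader_b = DataLoader(TensorDataset(torch.randn(32, 10)), batch_size=8)
        return {"a": loader_a, "b": loader_b}


# the mode is passed through the Trainer; after this PR it lives on the DataConnector
trainer = pl.Trainer(multiple_trainloader_mode="min_size", max_epochs=1)
```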

6 changes: 2 additions & 4 deletions pytorch_lightning/trainer/trainer.py
@@ -313,7 +313,7 @@ def __init__(
# init connectors
self.dev_debugger = InternalDebugger(self)
self.config_validator = ConfigValidator(self)
self.data_connector = DataConnector(self)
self.data_connector = DataConnector(self, multiple_trainloader_mode)
self.optimizer_connector = OptimizerConnector(self)

self.accelerator_connector = AcceleratorConnector(
@@ -329,9 +329,7 @@
self.checkpoint_connector = CheckpointConnector(self)
self.slurm_connector = SLURMConnector(self)
self.tuner = Tuner(self)
self.train_loop = TrainLoop(
self, multiple_trainloader_mode, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps
)
self.train_loop = TrainLoop(self, max_epochs, min_epochs, max_steps, min_steps, num_sanity_val_steps)
self.evaluation_loop = EvaluationLoop(self)
self.predict_loop = PredictLoop(self)

26 changes: 12 additions & 14 deletions pytorch_lightning/trainer/training_loop.py
@@ -39,7 +39,6 @@ class TrainLoop:
def __init__(
self,
trainer,
multiple_trainloader_mode: str,
max_epochs: Optional[int],
min_epochs: Optional[int],
max_steps: Optional[int],
@@ -51,17 +50,21 @@ def __init__(
self.warning_cache = WarningCache()
self._teardown_already_run = False
self.running_loss = TensorRunningAccum(window_length=20)
self._multiple_trainloader_mode = multiple_trainloader_mode
self._skip_backward = False
self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
self._optimizer_freq_cumsum = None
self._hiddens = None

self.global_step = 0
self.current_epoch = 0
self.trainer.should_stop = False

# the total batch index across all epochs
self.total_batch_idx = 0
# the current batch index in the loop that runs over the dataloader(s)
self.batch_idx = 0
# the current split index when the batch gets split into chunks in truncated backprop through time
self.split_idx = None
Contributor:
  • Is this meant to be writable from the outside? Should these be made available as properties instead?
  • While going through this, could you add a comment explaining what split_idx means, for contributors who go through this code later?

Contributor Author: No, but for now I follow the above pattern; the major goal is to define this state strictly on the loop, not on the trainer anymore.

Contributor Author:

> while going through this, could you add a comment for what split_idx means for contributors who go through this code later?

Yes.

Contributor Author: Okay, added comments.


self.trainer.num_training_batches = 0
self.trainer.train_dataloader = None
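To make the `split_idx` comment above concrete: with truncated backpropagation through time, each batch is split into chunks along the time dimension (Lightning's default `tbptt_split_batch` behaves roughly like the `torch.split` below), and `split_idx` is the index of the chunk currently being processed. Shapes here are hypothetical:

```python
import torch

# a (hypothetical) batch of 8 sequences with 128 time steps and 10 features, truncated_bptt_steps = 32
batch = torch.randn(8, 128, 10)
truncated_bptt_steps = 32

splits = torch.split(batch, truncated_bptt_steps, dim=1)
for split_idx, split_batch in enumerate(splits):
    # split_idx is what TrainLoop.split_idx now tracks (and what the progress bar can display)
    print(split_idx, tuple(split_batch.shape))  # (8, 32, 10) for each of the 4 chunks
```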

@@ -336,7 +339,7 @@ def _process_training_step_output(self, training_step_output, split_batch):

# map to results under the hood
result.minimize = loss
self.trainer.hiddens = hiddens
self._hiddens = hiddens

# track batch for manual reduction with result
result.track_batch_size(len(split_batch))
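For context on the `hiddens` being stashed here: when truncated BPTT is enabled, `training_step` receives the hidden state carried over from the previous split and is expected to return the new one under the `"hiddens"` key, which is what the loop now stores in `self._hiddens`. A minimal sketch with a hypothetical model and shapes (the explicit `detach` is a precaution so the previous split's graph is not retained):

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl


class TBPTTModule(pl.LightningModule):  # hypothetical example module
    def __init__(self):
        super().__init__()
        self.truncated_bptt_steps = 32  # enables batch splitting and hiddens passing
        self.rnn = nn.LSTM(10, 20, batch_first=True)
        self.head = nn.Linear(20, 1)

    def training_step(self, batch, batch_idx, hiddens):
        x, y = batch                         # a (B, T, ...) chunk produced by tbptt_split_batch
        out, hiddens = self.rnn(x, hiddens)  # hiddens is None for the first split
        loss = nn.functional.mse_loss(self.head(out), y)
        hiddens = tuple(h.detach() for h in hiddens)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```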
@@ -478,7 +481,6 @@ def run_training_epoch(self):

for batch_idx, (batch, is_last_batch) in train_dataloader:
self.batch_idx = batch_idx
self.trainer.is_last_batch = is_last_batch

# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
@@ -655,7 +657,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
grad_norm_dict = {}

# bookkeeping
self.trainer.hiddens = None
self._hiddens = None

optimizers = self.prepare_optimizers()

@@ -684,6 +686,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
splits = self._tbptt_split_batch(batch)

for split_idx, split_batch in enumerate(splits):
self.split_idx = split_idx

# create an iterable for optimizers and loop over them
for opt_idx, optimizer in optimizers:
@@ -702,9 +705,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
# automatic_optimization=True: perform dpp sync only when performing optimizer_step
# automatic_optimization=False: don't block synchronization here
with self.block_ddp_sync_behaviour():
result = self.training_step_and_backward(
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
)
self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

# ------------------------------
# BACKWARD PASS
@@ -716,15 +717,15 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
def train_step_and_backward_closure():
nonlocal result
result = self.training_step_and_backward(
split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
split_batch, batch_idx, opt_idx, optimizer, self._hiddens
)
return None if result is None else result.loss

# optimizer step
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

else:
result = self.training_step(split_batch, batch_idx, opt_idx, self.trainer.hiddens)
result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

if not result:
# user decided to skip optimization
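The closure built above mirrors plain PyTorch: the optimizer is handed a function that re-runs forward and backward and returns the loss, which optimizers such as LBFGS call (possibly several times) inside `step`. A standalone sketch of the same pattern with a hypothetical model and data:

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 1)
optimizer = torch.optim.LBFGS(model.parameters(), lr=0.1)
x, y = torch.randn(16, 10), torch.randn(16, 1)


def closure():
    # analogous to train_step_and_backward_closure: forward + backward, return the loss
    optimizer.zero_grad()
    loss = nn.functional.mse_loss(model(x), y)
    loss.backward()
    return loss


optimizer.step(closure)
```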
@@ -967,9 +968,6 @@ def prepare_optimizers(self):
return optimizers

def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
# set split_idx to trainer for tracking
self.trainer.split_idx = split_idx

# make sure only the gradients of the current optimizer's parameters are calculated
# in the training step to prevent dangling gradients in multiple-optimizer setup.
if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
8 changes: 6 additions & 2 deletions tests/trainer/loops/test_evaluation_loop_flow.py
@@ -81,7 +81,11 @@ def backward(self, loss, optimizer, optimizer_idx):

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch,
batch_idx,
0,
trainer.optimizers[0],
hiddens=None,
)
assert opt_closure_result['loss'].item() == 171

@@ -150,7 +154,7 @@ def backward(self, loss, optimizer, optimizer_idx):

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
assert opt_closure_result['loss'].item() == 171

8 changes: 6 additions & 2 deletions tests/trainer/loops/test_training_loop_flow_scalar.py
@@ -165,7 +165,11 @@ def backward(self, loss, optimizer, optimizer_idx):

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch,
batch_idx,
0,
trainer.optimizers[0],
hiddens=None,
)
assert opt_closure_result['loss'].item() == 171

@@ -241,7 +245,7 @@ def backward(self, loss, optimizer, optimizer_idx):

# make sure the optimizer closure returns the correct things
opt_closure_result = trainer.train_loop.training_step_and_backward(
batch, batch_idx, 0, trainer.optimizers[0], trainer.hiddens
batch, batch_idx, 0, trainer.optimizers[0], hiddens=None
)
assert opt_closure_result['loss'].item() == 171
