Commit 0d90d53
ref: moving train loop to own object 2/n (intermediate steps) (#3313)
* ref: moving train loop to own object 2/n (intermediate steps)
williamFalcon authored Sep 2, 2020
1 parent a5288fe commit 0d90d53
Showing 6 changed files with 59 additions and 44 deletions.
19 changes: 19 additions & 0 deletions pytorch_lightning/trainer/data_connector.py
@@ -24,6 +24,25 @@ class DataConnector(object):
def __init__(self, trainer):
self.trainer = trainer

def get_profiled_train_dataloader(self, train_dataloader):
profiled_dl = self.trainer.profiler.profile_iterable(
enumerate(self._with_is_last(train_dataloader)),
"get_train_batch"
)
return profiled_dl

def _with_is_last(self, iterable):
"""Pass through values from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
it = iter(iterable)
last = next(it)
for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True

def prepare_data(self, model):
# on multi-gpu jobs we only want to manipulate (download, etc) on node_rank=0, local_rank=0
# or in the case where each node needs to do its own manipulation in which case just local_rank=0
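For context, a minimal standalone sketch (not part of the diff) of what the new `DataConnector` helpers yield: `_with_is_last` pairs every item with a flag that is True only for the final one, and `get_profiled_train_dataloader` wraps that generator in `enumerate` plus the profiler, so the epoch loop receives `(batch_idx, (batch, is_last_batch))` tuples. The strings below are illustrative stand-ins for batches, not Lightning objects.

```python
# Hedged sketch: the same generator logic as _with_is_last above, run on toy data.
def _with_is_last(iterable):
    it = iter(iterable)
    last = next(it)
    for val in it:
        yield last, False   # more items follow
        last = val
    yield last, True        # final item

for batch_idx, (batch, is_last_batch) in enumerate(_with_is_last(["b0", "b1", "b2"])):
    print(batch_idx, batch, is_last_batch)
# 0 b0 False
# 1 b1 False
# 2 b2 True
```
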
42 changes: 10 additions & 32 deletions pytorch_lightning/trainer/training_loop.py
@@ -183,6 +183,7 @@ def training_step(self, batch, batch_idx):
from pytorch_lightning.utilities.parsing import AttributeDict
from pytorch_lightning.utilities.model_utils import is_overridden
from pytorch_lightning.trainer.training_loop_temp import TrainLoop
from pytorch_lightning.trainer.data_connector import DataConnector

try:
from apex import amp
@@ -264,6 +265,7 @@ class TrainerTrainLoopMixin(ABC):
accelerator_backend: ...
val_dataloaders: ...
train_loop: TrainLoop
data_connector: DataConnector

# Callback system
callbacks: List[Callback]
@@ -443,10 +445,10 @@ def run_training_epoch(self):
# track epoch output
epoch_output = [[] for _ in range(self.train_loop.num_optimizers)]

# run epoch
for batch_idx, (batch, is_last_batch) in self.profiler.profile_iterable(
enumerate(_with_is_last(train_dataloader)), "get_train_batch"
):
# enable profiling for the dataloader
train_dataloader = self.data_connector.get_profiled_train_dataloader(train_dataloader)
dataloader_idx = 0
for batch_idx, (batch, is_last_batch) in train_dataloader:
# stop epoch if we limited the number of training batches
if batch_idx >= self.num_training_batches:
break
@@ -457,7 +459,7 @@
# ------------------------------------
# TRAINING_STEP + TRAINING_STEP_END
# ------------------------------------
batch_output = self.run_training_batch(batch, batch_idx)
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)

# only track outputs when user implements training_epoch_end
# otherwise we will build up unnecessary memory
@@ -467,12 +469,8 @@
self.train_loop.checkpoint_accumulator
)

# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)
# hook
self.train_loop.on_train_batch_end(epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx)

# when returning -1 from train_step, we end epoch early
self.should_stop = batch_output.signal == -1
@@ -748,7 +746,7 @@ def should_check_val(self, batch_idx, is_last_batch):

return should_check_val

def run_training_batch(self, batch, batch_idx):
def run_training_batch(self, batch, batch_idx, dataloader_idx):
# track grad norms
grad_norm_dic = {}

@@ -767,7 +765,6 @@ def run_training_batch(self, batch, batch_idx):
return AttributeDict(signal=0, grad_norm_dic=grad_norm_dic)

# hook
dataloader_idx = 0
response = self.call_hook('on_batch_start')
if response == -1:
return AttributeDict(signal=-1, grad_norm_dic=grad_norm_dic)
@@ -859,12 +856,6 @@ def run_training_batch(self, batch, batch_idx):
# reset for next set of accumulated grads
self.batch_loss_value.reset()

# hook
self.call_hook('on_batch_end')

# hook
self.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

# collapse all metrics into one dict
batch_log_metrics = {k: v for d in batch_log_metrics for k, v in d.items()}

@@ -1186,16 +1177,3 @@ def update_learning_rates(self, interval: str, monitor_metrics=None):
scheduler_idx,
old_lr, new_lr
)


def _with_is_last(iterable):
"""Pass through values from the given iterable with an added boolean indicating if this is the last item.
See `https://stackoverflow.com/a/1630350 <https://stackoverflow.com/a/1630350>`_"""
it = iter(iterable)
last = next(it)
for val in it:
# yield last and has next
yield last, False
last = val
# yield last, no longer has next
yield last, True
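Two things leave this file in the hunks above: the inline `_with_is_last` + `profile_iterable` wiring (now owned by `DataConnector.get_profiled_train_dataloader`, consumed in `run_training_epoch`) and the per-batch hook calls at the end of `run_training_batch` (now fired by `TrainLoop.on_train_batch_end`, added in the next file). A hedged standalone sketch with stub objects, not Lightning's real classes, showing that the hook order `on_batch_end` then `on_train_batch_end` is preserved after the move:

```python
# Hedged sketch with stub objects (not Lightning's classes): the per-batch hooks
# previously fired inside run_training_batch are now fired by the TrainLoop object.
class StubTrainer:
    def call_hook(self, name, *args):
        print("hook:", name)

class TrainLoopSketch:
    def __init__(self, trainer):
        self.trainer = trainer

    def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
        # the real method first tracks epoch_end_outputs for the epoch-end reduction
        self.trainer.call_hook('on_batch_end')
        self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

TrainLoopSketch(StubTrainer()).on_train_batch_end([], [], batch=None, batch_idx=0, dataloader_idx=0)
# hook: on_batch_end
# hook: on_train_batch_end
```
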
18 changes: 18 additions & 0 deletions pytorch_lightning/trainer/training_loop_temp.py
@@ -1,5 +1,6 @@
from pytorch_lightning.trainer.supporters import Accumulator
import numpy as np
from pytorch_lightning.core.step_result import Result


class TrainLoop:
@@ -27,6 +28,23 @@ def on_train_epoch_start(self):
self.early_stopping_accumulator = Accumulator()
self.checkpoint_accumulator = Accumulator()

def on_train_batch_end(self, epoch_output, epoch_end_outputs, batch, batch_idx, dataloader_idx):
# figure out what to track for epoch end
self.track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs)

# hook
self.trainer.call_hook('on_batch_end')
self.trainer.call_hook('on_train_batch_end', batch, batch_idx, dataloader_idx)

def track_epoch_end_reduce_metrics(self, epoch_output, epoch_end_outputs):
# track the outputs to reduce at the end of the epoch
for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
# with 1 step (no tbptt) don't use a sequence at epoch end
if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
opt_outputs = opt_outputs[0]
epoch_output[opt_idx].append(opt_outputs)


def get_optimizers_iterable(self):
"""
Generates an iterable with (idx, optimizer) for each optimizer.
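A small, self-contained illustration (not part of the diff) of the unwrapping rule in `track_epoch_end_reduce_metrics`: a one-element list produced by a single step without truncated BPTT is unwrapped before being tracked, while `Result` objects and multi-step sequences are kept as sequences. Plain dicts stand in for `training_step` outputs, and a trivial class stands in for Lightning's `Result`.

```python
# Hedged, standalone sketch of the epoch-end tracking rule above.
# `Result` here is a stand-in class, not pytorch_lightning.core.step_result.Result.
class Result(dict):
    pass

def track_epoch_end_reduce_metrics(epoch_output, epoch_end_outputs):
    for opt_idx, opt_outputs in enumerate(epoch_end_outputs):
        # with a single step (no truncated BPTT), drop the one-element sequence
        if isinstance(opt_outputs, list) and len(opt_outputs) == 1 and not isinstance(opt_outputs[0], Result):
            opt_outputs = opt_outputs[0]
        epoch_output[opt_idx].append(opt_outputs)

epoch_output = [[], []]  # one bucket per optimizer
track_epoch_end_reduce_metrics(
    epoch_output,
    [
        [{'loss': 0.5}],                  # single step -> unwrapped to the dict itself
        [{'loss': 0.4}, {'loss': 0.3}],   # multiple tbptt steps -> kept as a list
    ],
)
print(epoch_output)
# [[{'loss': 0.5}], [[{'loss': 0.4}, {'loss': 0.3}]]]
```
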
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_dict_return.py
@@ -30,7 +30,7 @@ def test_training_step_dict(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
@@ -76,7 +76,7 @@ def training_step_with_step_end(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -117,7 +117,7 @@ def test_full_training_loop_dict(tmpdir):
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 14.0
assert out.batch_log_metrics['log_acc2'] == 9.0
@@ -204,7 +204,7 @@ def test_train_step_epoch_end(tmpdir):
# make sure training outputs what is expected
batch_idx, batch = 0, next(iter(model.train_dataloader()))

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics['log_acc1'] == 12.0
assert out.batch_log_metrics['log_acc2'] == 7.0
8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_result_return.py
@@ -69,7 +69,7 @@ def test_training_step_result_log_step_only(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert out.batch_log_metrics[f'step_log_and_pbar_acc1_b{batch_idx}'] == 11.0
assert out.batch_log_metrics[f'step_log_acc2_b{batch_idx}'] == 12.0
@@ -144,7 +144,7 @@ def test_training_step_result_log_epoch_only(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0

@@ -277,7 +277,7 @@ def test_training_step_result_log_step_and_epoch(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2

@@ -356,7 +356,7 @@ def test_training_step_epoch_end_result(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 2

8 changes: 4 additions & 4 deletions tests/trainer/test_trainer_steps_scalar_return.py
@@ -31,7 +31,7 @@ def test_training_step_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -68,7 +68,7 @@ def training_step_scalar_with_step_end(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -115,7 +115,7 @@ def test_full_training_loop_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
@@ -158,7 +158,7 @@ def test_train_step_epoch_end_scalar(tmpdir):
for batch_idx, batch in enumerate(model.train_dataloader()):
break

out = trainer.run_training_batch(batch, batch_idx)
out = trainer.run_training_batch(batch, batch_idx, 0)
assert out.signal == 0
assert len(out.batch_log_metrics) == 0 and isinstance(out.batch_log_metrics, dict)
assert len(out.grad_norm_dic) == 0 and isinstance(out.grad_norm_dic, dict)
