remove trainer hidden state | sanity refactor [2 / n] #7507
Changes from 11 commits
@@ -39,7 +39,6 @@ class TrainLoop:
    def __init__(
        self,
        trainer,
        multiple_trainloader_mode: str,
        max_epochs: Optional[int],
        min_epochs: Optional[int],
        max_steps: Optional[int],
@@ -53,17 +52,21 @@ def __init__(
        self.running_loss = TensorRunningAccum(window_length=20)
        self._curr_step_result = None
        self._cur_grad_norm_dict = None
        self._multiple_trainloader_mode = multiple_trainloader_mode
        self._skip_backward = False
        self.trainer._multiple_trainloader_mode = multiple_trainloader_mode
        self._optimizer_freq_cumsum = None
        self._hiddens = None
        self.global_step = 0
        self.current_epoch = 0
        self.trainer.should_stop = False

        # the total batch index across all epochs
        self.total_batch_idx = 0
        # the current batch index in the loop that runs over the dataloader(s)
        self.batch_idx = 0
        # the current split index when the batch gets split into chunks in truncated backprop through time
        self.split_idx = None
Review comment: No, but for now I follow the above pattern, and the major goal is to define this state strictly on the loop, not on the trainer anymore.
Review comment: yes
Review comment: Okay, added comments.
        self.trainer.num_training_batches = 0
        self.trainer.train_dataloader = None
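The review thread above states the goal of this refactor series: progress and bookkeeping state should live strictly on the loop, not on the trainer. A minimal, purely illustrative sketch of that pattern follows; the class names and the delegating property are hypothetical, not the actual Lightning classes.

```python
# Illustrative only: hypothetical names, not the real Lightning classes.


class LoopSketch:
    def __init__(self) -> None:
        # progress / bookkeeping state owned by the loop itself
        self.global_step = 0
        self.current_epoch = 0
        self.total_batch_idx = 0
        self.batch_idx = 0
        self.split_idx = None  # set while iterating truncated-BPTT splits
        self._hiddens = None   # hidden state carried between TBPTT splits


class TrainerSketch:
    def __init__(self) -> None:
        self.train_loop = LoopSketch()

    @property
    def global_step(self) -> int:
        # read-only delegation keeps user-facing `trainer.global_step` working
        return self.train_loop.global_step
```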
@@ -338,7 +341,7 @@ def _process_training_step_output(self, training_step_output, split_batch):

        # map to results under the hood
        result.minimize = loss
-       self.trainer.hiddens = hiddens
+       self._hiddens = hiddens

        # track batch for manual reduction with result
        result.track_batch_size(len(split_batch))
@@ -480,7 +483,6 @@ def run_training_epoch(self):

        for batch_idx, (batch, is_last_batch) in train_dataloader:
            self.batch_idx = batch_idx
            self.trainer.is_last_batch = is_last_batch

            # ------------------------------------
            # TRAINING_STEP + TRAINING_STEP_END
@@ -657,7 +659,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
        grad_norm_dic = {}

        # bookkeeping
-       self.trainer.hiddens = None
+       self._hiddens = None

        optimizers = self.prepare_optimizers()
@@ -686,6 +688,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
        splits = self._tbptt_split_batch(batch)

        for split_idx, split_batch in enumerate(splits):
+           self.split_idx = split_idx

            # create an iterable for optimizers and loop over them
            for opt_idx, optimizer in optimizers:
@@ -703,9 +706,7 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):
                    # automatic_optimization=True: perform dpp sync only when performing optimizer_step
                    # automatic_optimization=False: don't block synchronization here
                    with self.block_ddp_sync_behaviour():
-                       self.training_step_and_backward(
-                           split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
-                       )
+                       self.training_step_and_backward(split_batch, batch_idx, opt_idx, optimizer, self._hiddens)

                batch_outputs = self._process_closure_result(
                    batch_outputs=batch_outputs,
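The `block_ddp_sync_behaviour()` context in the hunk above implements the comment it carries: with automatic optimization, the DDP gradient all-reduce should only run on the call that actually steps the optimizer. As a point of reference, the same idea in plain PyTorch can be sketched with `DistributedDataParallel.no_sync()`; this is an illustrative sketch under that assumption, not Lightning's implementation.

```python
import contextlib

from torch.nn.parallel import DistributedDataParallel as DDP


def accumulate_then_step(model: DDP, optimizer, micro_batches, loss_fn):
    """Run several micro-batches, but only all-reduce gradients on the last one."""
    optimizer.zero_grad()
    last = len(micro_batches) - 1
    for i, (x, y) in enumerate(micro_batches):
        # no_sync() skips DDP's gradient synchronization for this backward pass;
        # the final micro-batch uses a null context so gradients are reduced once.
        ctx = model.no_sync() if i != last else contextlib.nullcontext()
        with ctx:
            loss = loss_fn(model(x), y)
            loss.backward()
    optimizer.step()
```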
@@ -722,17 +723,15 @@ def run_training_batch(self, batch, batch_idx, dataloader_idx):

                    def train_step_and_backward_closure():
                        result = self.training_step_and_backward(
-                           split_batch, batch_idx, opt_idx, optimizer, self.trainer.hiddens
+                           split_batch, batch_idx, opt_idx, optimizer, self._hiddens
                        )
                        return None if result is None else result.loss

                    # optimizer step
                    self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)

                else:
-                   self._curr_step_result = self.training_step(
-                       split_batch, batch_idx, opt_idx, self.trainer.hiddens
-                   )
+                   self._curr_step_result = self.training_step(split_batch, batch_idx, opt_idx, self._hiddens)

                if self._curr_step_result is None:
                    # user decided to skip optimization
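For context on the `hiddens` argument threaded through `training_step` above: with truncated backpropagation through time, the loop splits each batch into chunks and carries the hidden state returned by one chunk into the next call. A minimal sketch of the user-facing contract, assuming the module-level `truncated_bptt_steps` attribute available around Lightning 1.3:

```python
import torch
import pytorch_lightning as pl


class TBPTTSketch(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.rnn = torch.nn.LSTM(input_size=8, hidden_size=16, batch_first=True)
        self.head = torch.nn.Linear(16, 1)
        # split each batch along the time dimension into chunks of 10 steps
        self.truncated_bptt_steps = 10

    def training_step(self, batch, batch_idx, hiddens):
        # `hiddens` is None for the first split of a batch, otherwise whatever the
        # previous split returned (the state the loop now stores in `self._hiddens`)
        x, y = batch
        out, hiddens = self.rnn(x, hiddens)
        loss = torch.nn.functional.mse_loss(self.head(out), y)
        # detach so gradients do not flow across split boundaries
        hiddens = tuple(h.detach() for h in hiddens)
        return {"loss": loss, "hiddens": hiddens}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)
```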
@@ -983,9 +982,6 @@ def prepare_optimizers(self):
        return optimizers

    def run_train_split_start(self, split_idx, split_batch, opt_idx, optimizer):
-       # set split_idx to trainer for tracking
-       self.trainer.split_idx = split_idx
-
        # make sure only the gradients of the current optimizer's parameters are calculated
        # in the training step to prevent dangling gradients in multiple-optimizer setup.
        if self.trainer.lightning_module.automatic_optimization and len(self.trainer.optimizers) > 1:
Review comment: Is get_progress_bar_dict still needed on the lightning module? Wouldn't things come from self.log?
Review comment: It's for users to override and customize the default elements in the progress bar, like how the version number is displayed or, apparently, the split index here.
Review comment: It's not something one could customize through self.log directly, I would say.
Review comment: Is it customization we would want to push to a custom progress bar callback instead of being part of the core module?
Review comment: IMO yes, for maximum separation of concerns. But people might complain about having to define a custom progress bar just for this.
Review comment: One could argue maybe the split_idx is not very useful to display in the progress bar, but I would still keep the hook.
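For reference on the hook being discussed: get_progress_bar_dict can be overridden on the LightningModule to change the default progress bar entries (for example, hiding the version number). A minimal sketch, assuming the hook as it existed around Lightning 1.3:

```python
import pytorch_lightning as pl


class ProgressBarSketch(pl.LightningModule):
    def get_progress_bar_dict(self):
        # start from the default entries (loss, v_num, split_idx when using TBPTT, ...)
        items = super().get_progress_bar_dict()
        # drop the experiment version from the progress bar display
        items.pop("v_num", None)
        return items
```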