Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix logger creating directory structure too early in DDP #6380

Merged
merged 11 commits into from
Mar 9, 2021
4 changes: 2 additions & 2 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))


- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))


- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
Expand All @@ -107,7 +107,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))


- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))


## [1.2.2] - 2021-03-02
Expand Down
23 changes: 7 additions & 16 deletions pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -381,21 +381,6 @@ def __init__(
# Callback system
self.on_init_end()

def setup_trainer(self, model: LightningModule):
"""
Sanity check a few things before starting actual training or testing.

Args:
model: The model to run sanity test on.
"""

# log hyper-parameters
if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
self.logger.log_hyperparams(model.hparams_initial)
self.logger.log_graph(model)
self.logger.save()

def fit(
self,
model: LightningModule,
Expand Down Expand Up @@ -444,7 +429,6 @@ def fit(
self.call_setup_hook(model)
self.call_hook("on_before_accelerator_backend_setup", model)
self.accelerator.setup(self, model)
self.setup_trainer(model)

# ----------------------------
# INSPECT THE CORE LOOPS
Expand Down Expand Up @@ -509,6 +493,13 @@ def fit(
def pre_dispatch(self):
self.accelerator.pre_dispatch()

# log hyper-parameters
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we wrap this into a function ?

if self.logger is not None:
# save exp to get started (this is where the first experiment logs are written)
self.logger.log_hyperparams(self.lightning_module.hparams_initial)
self.logger.log_graph(self.lightning_module)
self.logger.save()

def post_dispatch(self):
self.accelerator.post_dispatch()
self.accelerator.teardown()
Expand Down
41 changes: 39 additions & 2 deletions tests/trainer/logging_/test_distributed_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from unittest import mock
from unittest.mock import Mock

from pytorch_lightning import Trainer
from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel
from tests.helpers.runif import RunIf

Expand Down Expand Up @@ -66,3 +67,39 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
weights_summary=None,
)
trainer.fit(model)


def test_first_logger_call_in_subprocess(tmpdir):
"""
Test that the Trainer does not call the logger too early. Only when the worker processes are initialized
do we have access to the rank and know which one is the main process.
"""

class LoggerCallsObserver(Callback):

def on_fit_start(self, trainer, pl_module):
# this hook is executed directly before Trainer.pre_dispatch
# logger should not write any logs until this point
assert not trainer.logger.method_calls
assert not os.listdir(trainer.logger.save_dir)

def on_train_start(self, trainer, pl_module):
assert trainer.logger.method_call
trainer.logger.log_hyperparams.assert_called_once()
trainer.logger.log_graph.assert_called_once()

logger = Mock()
logger.version = "0"
logger.name = "name"
logger.save_dir = tmpdir

model = BoringModel()
trainer = Trainer(
default_root_dir=tmpdir,
limit_train_batches=1,
limit_val_batches=1,
max_epochs=1,
logger=logger,
callbacks=[LoggerCallsObserver()]
)
trainer.fit(model)