
Tensorboard logger error: lightning_logs directory does not exist in multi-node DDP on nodes with rank != 0 #1375

Closed
areshytko opened this issue Apr 4, 2020 · 0 comments · Fixed by #1377
Labels: bug, help wanted

areshytko commented Apr 4, 2020

🐛 Bug

In multi-node DDP training mode, an error appears at the start of training on all nodes except rank 0. It is caused by the TensorBoard logger accessing the lightning_logs directory, which does not exist yet on those nodes.

To Reproduce

Steps to reproduce the behavior:

  1. set up a multi-node cluster (without SLURM)
  2. set environment variables on each node:
export MASTER_ADDR=<rank 0 node IP>
export MASTER_PORT=23456
export RANK=<node id>
export SLURM_NODEID=<node id>
export WORLD_SIZE=<world-size>
  3. install dependencies:
pip install torch torchvision hydra-core pytorch-lightning
  4. copy app.py and conf.yaml to each node
  5. run the script on each node:
python app.py
  6. see the error:
Exception:

-- Process 0 terminated with the following error:
Traceback (most recent call last):
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
    fn(i, *args)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/pytorch_lightning/trainer/distrib_data_parallel.py", line 342, in ddp_train
    self.run_pretrain_routine(model)
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/pytorch_lightning/trainer/trainer.py", line 777, in run_pretrain_routine
    self.configure_checkpoint_callback()
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/pytorch_lightning/trainer/callback_config.py", line 45, in configure_checkpoint_callback
    f'version_{self.logger.version}',
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/pytorch_lightning/loggers/tensorboard.py", line 161, in version
    self._version = self._get_next_version()
  File "/home/ubuntu/anaconda3/envs/nightly_pt/lib/python3.6/site-packages/pytorch_lightning/loggers/tensorboard.py", line 167, in _get_next_version
    for d in os.listdir(root_dir):
FileNotFoundError: [Errno 2] No such file or directory: '/home/ubuntu/pytorch-lightning-intro-guide/outputs/2020-04-04/15-53-26/lightning_logs'
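
The failure comes from `_get_next_version` calling `os.listdir(root_dir)` before that directory has been created on non-rank-0 nodes. As a rough illustration only (not necessarily the fix that landed in #1377, and the `save_dir`/`name` attributes are assumed from the logger), a defensive variant could treat a missing directory as version 0:

import os

def _get_next_version(self):
    # Sketch of a guard: on non-rank-0 nodes the lightning_logs directory may
    # not exist yet, so a missing root_dir simply maps to version 0.
    root_dir = os.path.join(self.save_dir, self.name)
    if not os.path.isdir(root_dir):
        return 0

    existing_versions = []
    for d in os.listdir(root_dir):
        if os.path.isdir(os.path.join(root_dir, d)) and d.startswith("version_"):
            existing_versions.append(int(d.split("_")[1]))

    return max(existing_versions) + 1 if existing_versions else 0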

Code sample

app.py:

import pathlib

import hydra
import pytorch_lightning as pl
import torch
from omegaconf import OmegaConf
from torch.nn import functional as F
from torch.optim import Adam
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms


class LitMNIST(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = torch.nn.Linear(28 * 28, 128)
        self.layer_2 = torch.nn.Linear(128, 256)
        self.layer_3 = torch.nn.Linear(256, 10)

        self.train_dataset = None
        self.val_dataset = None
        self.test_dataset = None

    def forward(self, x):
        batch_size, channels, width, height = x.size()
        x = x.view(batch_size, -1)
        x = self.layer_1(x)
        x = F.relu(x)
        x = self.layer_2(x)
        x = F.relu(x)
        x = self.layer_3(x)
        x = F.log_softmax(x, dim=1)
        return x

    def prepare_data(self):
        # transform
        transform = transforms.Compose(
            [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

        # download
        data_dir = pathlib.Path.home() / 'data'
        mnist_train = datasets.MNIST(data_dir, train=True,
                                     download=True, transform=transform)
        mnist_test = datasets.MNIST(data_dir, train=False,
                                    download=True, transform=transform)

        # train/val split
        mnist_train, mnist_val = random_split(mnist_train, [55000, 5000])

        # assign to use in dataloaders
        self.train_dataset = mnist_train
        self.val_dataset = mnist_val
        self.test_dataset = mnist_test

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=64)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=64)

    def test_dataloader(self):
        return DataLoader(self.test_dataset, batch_size=64)

    def configure_optimizers(self):
        return Adam(self.parameters(), lr=1e-3)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)

        # add logging
        logs = {'loss': loss}
        return {'loss': loss, 'log': logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def validation_epoch_end(self, outputs):
        avg_loss = torch.stack(  # pylint: disable=no-member
            [x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

    def test_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = F.nll_loss(logits, y)
        return {'val_loss': loss}

    def test_epoch_end(self, outputs):
        avg_loss = torch.stack(  # pylint: disable=no-member
            [x['val_loss'] for x in outputs]).mean()
        tensorboard_logs = {'val_loss': avg_loss}
        return {'avg_val_loss': avg_loss, 'log': tensorboard_logs}

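    # Note: overriding init_ddp_connection lets multi-node DDP run without SLURM;
    # init_process_group's default env:// init method picks up MASTER_ADDR and
    # MASTER_PORT from the environment variables exported on each node, while
    # rank and world_size are passed in explicitly.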
    def init_ddp_connection(self, proc_rank: int, world_size: int) -> None:
        torch.distributed.init_process_group(
            'nccl', rank=proc_rank, world_size=world_size)


@hydra.main(config_path='conf.yaml')
def main(conf: OmegaConf):
    model = LitMNIST()

    trainer = pl.Trainer(gpus=conf.gpus,
                         num_nodes=conf.num_nodes,
                         distributed_backend=conf.distributed_backend,
                         max_epochs=3)
    trainer.fit(model)


if __name__ == '__main__':
    main()  # pylint: disable=no-value-for-parameter

conf.yaml:

gpus: 1
num_nodes: 2
distributed_backend: ddp

Expected behavior

Training should run without errors on all nodes.

Environment

cuda:
	GPU:
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
		Tesla K80
	available:           True
	version:             10.1
packages:
	numpy:               1.18.1
	pyTorch_debug:       False
	pyTorch_version:     1.4.0
	pytorch-lightning:   0.7.1
	tensorboard:         2.2.0
	tqdm:                4.45.0
system:
	OS:                  Linux
	architecture:
		64bit

	processor:           x86_64
	python:              3.6.10
	version:             #113-Ubuntu SMP Wed Jan 29 14:54:54 UTC 2020

Additional context
