DDP TensorMetric and NumpyMetric exception #3507

Closed
Vozf opened this issue Sep 15, 2020 · 19 comments

@Vozf
Contributor

Vozf commented Sep 15, 2020

🐛 Bug

When training with metrics that inherit from TensorMetric or NumpyMetric, an exception occurs.
pytorch 1.6.0
pytorch-lightning == 0.9.0

CUDA_VISIBLE_DEVICES: [0,1,2]
Traceback (most recent call last):
  File "/test.py", line 38, in <module>
    main()
  File "/test.py", line 34, in main
    trainer.test(pl_model)
  File "/lib/python3.8/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1355, in test
    results = self.__test_given_model(model, test_dataloaders)
  File "/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1418, in __test_given_model
    results = self.fit(model)
  File "/lib/python3.8/site-packages/pytorch_lightning/trainer/states.py", line 48, in wrapped_fn
    result = fn(self, *args, **kwargs)
  File "/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1052, in fit
    self.accelerator_backend.train(model, nprocs=self.num_processes)
  File "/lib/python3.8/site-packages/pytorch_lightning/accelerators/ddp_spawn_backend.py", line 43, in train
    mp.spawn(self.ddp_train, nprocs=nprocs, args=(self.mp_queue, model,))
  File "/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 200, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 149, in start_processes
    process.start()
  File "/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object '_apply_to_outputs.<locals>.decorator_fn.<locals>.new_func'
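
The qualified name in the last line suggests that the object being pickled holds a function defined inside a decorator. Below is a minimal sketch of that failure mode using only the standard library. The names `apply_to_outputs`, `absolute_error` and `can_pickle` are hypothetical and chosen for illustration; this is not Lightning's actual implementation. The traceback shows `mp.spawn` using `start_method='spawn'`, so whatever is handed to the child processes has to be pickled, and nested functions cannot be pickled.

```python
import pickle


def apply_to_outputs(convert):
    """Hypothetical decorator factory: post-process a function's return value."""
    def decorator_fn(func):
        def new_func(*args, **kwargs):
            # new_func is a local object; its qualified name contains "<locals>".
            return convert(func(*args, **kwargs))
        return new_func
    return decorator_fn


def can_pickle(obj):
    """Return True if obj survives pickle.dumps, False otherwise."""
    try:
        pickle.dumps(obj)
        return True
    except (AttributeError, pickle.PicklingError):
        # The exact exception type differs between Python versions.
        return False


def absolute_error(pred, target):
    return abs(pred - target)


# Wrapping at runtime (not with @ at module level) leaves only the closure.
wrapped = apply_to_outputs(float)(absolute_error)

print(wrapped(3, 1))        # 2.0
print(can_pickle(wrapped))  # False
```

If the spawn path is only taken when more than one GPU is requested, this would also be consistent with the error not appearing on a single GPU.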
@Vozf Vozf added bug Something isn't working help wanted Open to be worked on labels Sep 15, 2020
@Borda Borda added the Metrics label Sep 15, 2020
@Borda
Member

Borda commented Sep 15, 2020

@Vozf mind sharing an example?

@Vozf
Contributor Author

Vozf commented Sep 15, 2020

I can't really share the whole repo, but the error starts reproducing with this small change. If you have specific questions, I'll happily answer.
Evaluation begins without DDP errors with this version (this is the pytorch-lightning implementation of MAE, no changes):

class MAE(Metric):
    """
    Computes the mean absolute loss or L1-loss.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = MAE()
        >>> metric(pred, target)
        tensor(0.2500)

    """

    def __init__(
        self,
        reduction: str = "elementwise_mean",
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
        """
        super().__init__(name="mae")
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the mae loss.
        """
        return mae(pred, target, self.reduction)

The error mentioned above reproduces with this version (the only change is that the superclass is now TensorMetric):

class MAE(TensorMetric):
    """
    Computes the mean absolute loss or L1-loss.

    Example:

        >>> pred = torch.tensor([0., 1, 2, 3])
        >>> target = torch.tensor([0., 1, 2, 2])
        >>> metric = MAE()
        >>> metric(pred, target)
        tensor(0.2500)

    """

    def __init__(
        self,
        reduction: str = "elementwise_mean",
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
        """
        super().__init__(name="mae")
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the mae loss.
        """
        return mae(pred, target, self.reduction)

Basically, all that's done is computing this metric:

    trainer = Trainer(
        gpus=3,
        auto_select_gpus=True,
    )
    trainer.test(pl_model)

And inside test_step

self.metric = MAE()
...
self.metric(preds, y)

MAE is just an example; this reproduces for any metric with TensorMetric or NumpyMetric as its superclass.

@Borda
Member

Borda commented Sep 15, 2020

@Vozf mind testing it on the current master? We have fixed several issues with metrics recently...

@Borda Borda added the priority: 0 High priority task label Sep 15, 2020
@Vozf
Contributor Author

Vozf commented Sep 15, 2020

Well, the error got worse. With DDP and NumpyMetric:

-- Process 1 terminated with the following error:
Traceback (most recent call last):
  File "/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
    fn(i, *args)
  File "/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 164, in ddp_train_tmp
    results = self.train_or_test()
  File "/site-packages/pytorch_lightning/accelerators/base_backend.py", line 34, in train_or_test
    results = self.trainer.run_test()
  File "/site-packages/pytorch_lightning/trainer/trainer.py", line 471, in run_test
    eval_loop_results, _ = self.run_evaluation(test_mode=True)
  File "/site-packages/pytorch_lightning/trainer/trainer.py", line 433, in run_evaluation
    output = self.evaluation_loop.evaluation_step(test_mode, batch, batch_idx, dataloader_idx)
  File "/site-packages/pytorch_lightning/trainer/evaluation_loop.py", line 143, in evaluation_step
    output = self.trainer.accelerator_backend.test_step(args)
  File "/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 53, in test_step
    output = self.training_step(args)
  File "/site-packages/pytorch_lightning/accelerators/ddp_base_backend.py", line 45, in training_step
    output = self.trainer.model(*args)
  File "/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/site-packages/pytorch_lightning/overrides/data_parallel.py", line 172, in forward
    output = self.module.test_step(*inputs[0], **kwargs[0])
  File "/", line 66, in test_step
    return self.validation_step(batch, batch_idx)
  File "/", line 50, in validation_step
    _, metrics = self._step(batch, batch_idx)
  File "/", line 75, in _step
    metrics = self._get_metrics(preds, y, loss)
  File "/", line 83, in _get_metrics
    metrics = {prefix + metric.name: metric(preds[0], y) for metric in metrics_list}
  File "/", line 83, in <dictcomp>
    metrics = {prefix + metric.name: metric(preds[0], y) for metric in metrics_list}
  File "/site-packages/torch/nn/modules/module.py", line 726, in _call_impl
    hook_result = hook(self, input, result)
  File "/site-packages/pytorch_lightning/metrics/metric.py", line 152, in ddp_reduce
    synced = self.ddp_sync(output)
  File "/site-packages/pytorch_lightning/metrics/metric.py", line 133, in ddp_sync
    gathered_tensors = apply_to_collection(tensor, torch.Tensor, gather_all_tensors_if_available, self.reduce_group)
  File "/site-packages/pytorch_lightning/utilities/apply_func.py", line 49, in apply_to_collection
    return function(data, *args, **kwargs)
  File "/site-packages/pytorch_lightning/metrics/converters.py", line 324, in gather_all_tensors_if_available
    torch.distributed.all_gather(gathered_result, result, group)
  File "/site-packages/torch/distributed/distributed_c10d.py", line 1185, in all_gather
    work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Tensors must be CUDA and dense

On a single GPU everything works as in 0.9.0.

@Vozf
Contributor Author

Vozf commented Sep 15, 2020

The error above is for NumpyMetric (edited comment).
For TensorMetric, the incoming pred and target are not GPU tensors. They are CPU tensors, which results in the following error (my TensorMetric code is GPU-only):

  File "/h.py", line 67, in _eval_pr
    y_temp = (y_pred >= thlist[i]).float()
  File "/home//envs/rnd/lib/python3.8/site-packages/torch/tensor.py", line 22, in wrapped
    return f(*args, **kwargs)
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:2 and cpu

This occurs even with a single GPU, so it seems like a major bug in current master (0.9.1rc3).
This was not happening on 0.9.0, where pred and target were GPU tensors.

@SkafteNicki
Member

@Vozf could you move the initialization of the metric to the __init__ of your model? Metrics should behave similarly to nn.Module and move calculations to the correct device if initialized correctly.
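
A minimal sketch of what behaving like nn.Module means here, written in plain PyTorch rather than with Lightning's metric classes. A module assigned as an attribute in `__init__` is registered as a submodule, so calling `.to(device)` on the parent also moves its buffers. `ToyMAE` and `ToyModel` are hypothetical names, and the buffer exists only to make the device visible.

```python
import torch
from torch import nn


class ToyMAE(nn.Module):
    """Hypothetical stand-in for a metric that carries some state."""

    def __init__(self):
        super().__init__()
        # A buffer is module state that follows .to() / .cuda() calls.
        self.register_buffer("eps", torch.tensor(0.0))

    def forward(self, pred, target):
        return (pred - target).abs().mean() + self.eps


class ToyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Assigned in __init__, so it is registered as a submodule.
        self.metric = ToyMAE()
        self.l1 = nn.Linear(4, 1)


# Falls back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToyModel().to(device)

pred = torch.tensor([0.0, 1.0, 2.0, 3.0], device=device)
target = torch.tensor([0.0, 1.0, 2.0, 2.0], device=device)
value = model.metric(pred, target)

print(model.metric.eps.device.type == device.type)  # True
print(value.item())  # 0.25
```

The inputs still have to live on the same device as the metric state; the sketch only shows that state created in `__init__` follows the model.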

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

It's already inside __init__. Sorry for the misunderstanding, I only wanted to clarify what self.metric is.

@SkafteNicki
Member

Could you provide a colab notebook where I can reproduce the errors you are getting?

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

Does Colab have multiple GPUs to reproduce DDP?

@Borda
Member

Borda commented Sep 16, 2020

@Vozf unfortunately just one GPU or TPU

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

So that doesn't seem like an option, unless you are planning to download the notebook and execute it on multiple GPUs.

@SkafteNicki
Member

I can execute on multiple GPUs; I just need some (minimal) code that reproduces your error.

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

Ok, I'll try to build an example

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

import os

import torch
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms

import pytorch_lightning as pl
from pytorch_lightning.metrics import NumpyMetric, Metric, TensorMetric
from pytorch_lightning.metrics.functional.regression import mae, mse, psnr, rmse, rmsle, ssim

# case 1: Metric - works fine with multiple GPUs
# class MAE(Metric):

# case 2: TensorMetric - doesn't work
class MAE(TensorMetric):
    # ForkingPickler(file, protocol).dump(obj)
    # AttributeError: Can't pickle local object '_apply_to_outputs.<locals>.decorator_fn.<locals>.new_func'

    def __init__(
        self,
        reduction: str = "elementwise_mean",
    ):
        """
        Args:
            reduction: a method to reduce metric score over labels (default: takes the mean)
                Available reduction methods:
                - elementwise_mean: takes the mean
                - none: pass array
                - sum: add elements
        """
        super().__init__(name="mae")
        self.reduction = reduction

    def forward(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Actual metric computation

        Args:
            pred: predicted labels
            target: ground truth labels

        Return:
            A Tensor with the mae loss.
        """
        #         print('shape', pred.shape, target.shape)
        #         print(pred[0], target[0])
        return mae(pred, target, self.reduction)


class CoolSystem(pl.LightningModule):
    def __init__(self, classes=10):
        super().__init__()
        self.save_hyperparameters()
        self.metric = MAE()

        # not the best model...
        self.l1 = torch.nn.Linear(28 * 28, self.hparams.classes)

    def forward(self, x):
        #         print('forward shape', x.view(x.size(0), -1).shape, torch.relu(self.l1(x.view(x.size(0), -1))).shape)
        return torch.relu(self.l1(x.view(x.size(0), -1)))

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)
        loss = F.cross_entropy(y_hat, y)

        y_hat_metr = torch.argmax(y_hat, dim=1)
        metric = self.metric(y_hat_metr.float(), y.float())

        tensorboard_logs = {"train_metric": metric}
        return {"loss": loss, "log": tensorboard_logs}

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self(x)

        y_hat_metr = torch.argmax(y_hat, dim=1)
        metric = self.metric(y_hat_metr.float(), y.float())
        #         print(metric)

        loss = F.cross_entropy(y_hat, y)
        return {"val_metric": metric}

    #     def validation_epoch_end(self, outputs):
    #         avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
    #         return {'val_loss': avg_loss}

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=0.001)


from pytorch_lightning import Trainer, seed_everything

if __name__ == "__main__":
    seed_everything(0)

    # data
    mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
    mnist_train = DataLoader(mnist_train, batch_size=32, num_workers=4)
    mnist_val = MNIST(os.getcwd(), train=True, download=True, transform=transforms.ToTensor())
    mnist_val = DataLoader(mnist_val, batch_size=32, num_workers=4)

    # model
    model = CoolSystem()

    # most basic trainer, uses good defaults
    trainer = Trainer(gpus=2, progress_bar_refresh_rate=20, max_epochs=10)
    trainer.fit(model, mnist_train, mnist_val)

Tested with 0.9.0; the error reproduces. The key part is the superclass of MAE: in case 1 it works, in case 2 it raises the DDP exception.

The errors I mentioned with 0.9.1 should reproduce too, but I didn't test with it

@Vozf
Contributor Author

Vozf commented Sep 16, 2020

If changed to a single GPU (gpus=[0]), it works in both cases.

@edenlightning edenlightning added this to the 0.9.x milestone Sep 16, 2020
@SkafteNicki
Member

I took a look at this and cannot reproduce it on master, which means that your problem was most likely solved by PR #3245. Thus, if you install the latest version of master you should be fine.

@Vozf
Contributor Author

Vozf commented Sep 17, 2020

I tried to reproduce this on master and wasn't able to.
Probably I'm wrong and something is wrong with my code, but I still believe there is a problem with DDP metric aggregation.
With my real code both errors still reproduce.
My NumpyMetric still fails during aggregation with RuntimeError: Tensors must be CUDA and dense.
Unfortunately I have no time to create more artificial examples, so I'm closing this.

@Vozf Vozf closed this as completed Sep 17, 2020
@SkafteNicki
Member

@Vozf I promise to take a look at the NumpyMetric issue myself in the next few days.
I am aware that aggregation may be unstable. I have opened PR #3517, where I will add tests for all metrics to make sure that aggregation works for each and every one (however, this may take some time).

@Vozf
Contributor Author

Vozf commented Sep 17, 2020

If it helps:
My intuition while debugging was that the NumpyMetric np.array is converted to a torch.Tensor on the CPU (device=cpu) instead of the GPU, which leads to this error:
work = _default_pg.allgather([tensor_list], [tensor])
RuntimeError: Tensors must be CUDA and dense
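
To illustrate that intuition (a sketch under assumptions, not the code path Lightning actually takes): `torch.from_numpy` always returns a CPU tensor, so a NumPy result has to be moved to the process's device explicitly before a collective call that expects CUDA tensors. The helper name `numpy_result_to_device` is hypothetical, and the sketch falls back to the CPU when no GPU is present.

```python
import numpy as np
import torch


def numpy_result_to_device(result, device):
    """Hypothetical helper: convert a NumPy metric result to a tensor on `device`."""
    return torch.from_numpy(np.asarray(result)).to(device)


# Assumption: each DDP process would use its own CUDA device here.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
result = np.array([0.25], dtype=np.float32)

cpu_tensor = torch.from_numpy(result)           # always lives on the CPU
moved = numpy_result_to_device(result, device)  # lives on the target device

print(cpu_tensor.device.type)            # cpu
print(moved.device.type == device.type)  # True
```

If the converted result stays on the CPU, a gather over a backend that only accepts CUDA tensors would be expected to fail in the way shown above.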
