GPU RAM usage increases until overflow when using PSNR and SSIM #2597
Comments
Hi! Thanks for your contribution! Great first issue!
@ouioui199 Looking at your example (could you please share the full sample code?), I'm wondering whether you also call compute in the epoch-end hook.
Hi, sorry for the slow reply from my side.
Computational graph
I am not sure where you got this from, but in general the computational graph is not automatically detached from the metric output. It depends on the metric. For example, PSNR is differentiable:
import torch
from torchmetrics.image import PeakSignalNoiseRatio
psnr = PeakSignalNoiseRatio()
print(psnr.is_differentiable) # returns True
preds = torch.tensor([[0.0, 1.0], [2.0, 3.0]], requires_grad=True)
target = torch.tensor([[3.0, 2.0], [1.0, 0.0]])
print(psnr(preds, target)) # returns tensor(2.5527, grad_fn=<CloneBackward0>)
whereas a metric such as Accuracy is not:
from torchmetrics.classification import Accuracy
accuracy = Accuracy(task="multiclass", num_classes=4)
print(accuracy.is_differentiable) # returns False
preds = torch.tensor([0, 1, 2, 3], requires_grad=True, dtype=torch.float)
target = torch.tensor([0, 2, 2, 3])
print(accuracy(preds, target)) # returns tensor(0.7500)
So calling detach() on the predictions before evaluating the metric is the correct approach here.
Memory keeps increasing
The documentation lays out fairly clearly that there are two kinds of metrics in torchmetrics: metrics with constant memory and metrics with increasing memory. Initializing the metrics as follows:
from torchmetrics.image import PeakSignalNoiseRatio, StructuralSimilarityIndexMeasure
psnr = PeakSignalNoiseRatio(reduction="elementwise_mean")
ssim = StructuralSimilarityIndexMeasure(reduction="elementwise_mean")
will solve this issue. Closing the issue.
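The difference between constant-memory and increasing-memory metrics can be sketched in plain Python. This is an illustration under stated assumptions, not torchmetrics code: the classes ConstantMemoryPSNR and ListStatePSNR are hypothetical, and the data range of 3.0 simply matches the PSNR example above. Roughly speaking, in torchmetrics the distinction corresponds to tensor states versus list states registered with add_state; list states keep growing until the metric is reset.

```python
import math
from typing import List, Sequence, Tuple


class ConstantMemoryPSNR:
    """Hypothetical PSNR accumulator with two running scalars (constant memory)."""

    def __init__(self, data_range: float) -> None:
        self.data_range = data_range
        self.sum_squared_error = 0.0
        self.total = 0

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        # Fold the batch into the running sums; the batch itself is not stored.
        for p, t in zip(preds, target):
            self.sum_squared_error += (p - t) ** 2
            self.total += 1

    def compute(self) -> float:
        mse = self.sum_squared_error / self.total
        return 10 * math.log10(self.data_range ** 2 / mse)


class ListStatePSNR:
    """Hypothetical PSNR accumulator that stores every batch (increasing memory)."""

    def __init__(self, data_range: float) -> None:
        self.data_range = data_range
        self.batches: List[Tuple[List[float], List[float]]] = []

    def update(self, preds: Sequence[float], target: Sequence[float]) -> None:
        # Keep a copy of the whole batch; memory grows with every call.
        self.batches.append((list(preds), list(target)))

    def compute(self) -> float:
        errors = [(p - t) ** 2 for preds, target in self.batches for p, t in zip(preds, target)]
        mse = sum(errors) / len(errors)
        return 10 * math.log10(self.data_range ** 2 / mse)

    def reset(self) -> None:
        # Dropping the stored batches is what frees the memory again.
        self.batches.clear()


constant, growing = ConstantMemoryPSNR(data_range=3.0), ListStatePSNR(data_range=3.0)
for _ in range(1000):
    constant.update([0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0])
    growing.update([0.0, 1.0, 2.0, 3.0], [3.0, 2.0, 1.0, 0.0])

print(len(growing.batches))  # 1000 stored batches vs. two scalars in `constant`
print(round(constant.compute(), 4), round(growing.compute(), 4))  # 2.5527 2.5527
growing.reset()
```

Both accumulators report the same value, but only the list-based one holds on to every batch. That is why, for metrics of the second kind, calling compute and then reset at the end of each epoch matters.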
🐛 Bug
Hello all,
I'm implementing CycleGAN with Lightning. I use PSNR and SSIM from torchmetrics for evaluation.
During training, I see that my GPU memory usage increases nonstop until it overflows and the whole training run crashes.
This might be similar to #2481.
To Reproduce
Add this to the init method of the model class:
In the training_step method:
train_metrics = self.train_metrics(fake, real)
In the validation_step method:
valid_metrics = self.valid_metrics(fake, real)
Environment
Proposed easy fix
I tried to debug the code.
When inspecting train_metrics, I get this:
which is weird, because metrics aren't supposed to be attached to the computational graph.
When inspecting valid_metrics, I don't see a grad_fn.
Guessing that this is the issue, I called fake.detach() when computing train_metrics.
Now training is stable and GPU memory no longer keeps increasing.
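The effect of detach() on a differentiable metric can be reproduced with a minimal sketch in plain PyTorch, without torchmetrics or Lightning. Assumptions: the psnr function below is a hypothetical stand-in for a differentiable metric (data range inferred from the target), and the generator output fake is simulated by multiplying a fixed tensor with a trainable weights tensor; fake and real only mirror the names used in the snippets above.

```python
import torch


def psnr(preds: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Hypothetical differentiable metric: PSNR with data range taken from the target.
    data_range = target.max() - target.min()
    mse = torch.mean((preds - target) ** 2)
    return 10 * torch.log10(data_range ** 2 / mse)


# Simulated generator output: depends on a trainable tensor, so it carries a grad_fn.
weights = torch.ones(2, 2, requires_grad=True)
fake = weights * torch.tensor([[0.0, 1.0], [2.0, 3.0]])
real = torch.tensor([[3.0, 2.0], [1.0, 0.0]])

attached = psnr(fake, real)           # result is still part of the autograd graph
detached = psnr(fake.detach(), real)  # graph is cut before the metric sees the tensor

print(attached.requires_grad)  # True
print(detached.requires_grad)  # False
```

Both results hold the same value, but only the attached one keeps a reference to the graph that produced fake. Holding on to such results across training steps can keep those graphs alive, whereas the detached value carries no such reference.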