track_epoch_end_reduce_metrics and memory consumption #7498
Sorry, it is my bad. I can't seem to reproduce this error. Although the memory-consumption behaviour still persists, I am not quite sure whether the aforementioned method causes this issue.
Nope, I do not return any object from either the step methods or from on_*_epoch_end, as shown below :(

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F
from pytorch_lightning.metrics.functional import accuracy
from typing import List, Any, Tuple
from torch.nn.init import xavier_normal_


class BaseKGE(pl.LightningModule):
    def __init__(self, learning_rate=.1):
        super().__init__()
        self.name = 'Not init'
        self.learning_rate = learning_rate

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def loss_function(self, y_hat, y):
        return self.loss(y_hat, y)

    def forward_triples(self, *args, **kwargs):
        raise ValueError(f'MODEL:{self.name} does not have forward_triples function')

    def forward_k_vs_all(self, *args, **kwargs):
        raise ValueError(f'MODEL:{self.name} does not have forward_k_vs_all function')

    def forward(self, x):
        if len(x) == 3:
            h, r, t = x[0], x[1], x[2]
            return self.forward_triples(h, r, t)
        elif len(x) == 2:
            h, y = x[0], x[1]
            # Note that y can be relation or tail entity.
            return self.forward_k_vs_all(h, y)
        else:
            raise ValueError('Not valid input')

    def training_step(self, batch, batch_idx):
        x_batch, y_batch = batch
        train_loss = self.loss_function(self.forward(x_batch), y_batch)
        return {'loss': train_loss}

    # def training_epoch_end(self, outputs) -> None:
    #     """ DBpedia debugging removed."""
    #     # avg_loss = torch.stack([x['loss'] for x in outputs]).mean()
    #     # self.log('avg_loss', avg_loss, on_epoch=False, prog_bar=True)
```

Thanks for the tip! I reckon that the two PRs are closely related to my issue.
I think it is quite expected that whatever is accumulated during each epoch takes some memory, and once everything is computed at the end, the memory drops back to normal... 🐰
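For illustration, a hypothetical module (toy names, not the code from this issue) where training_step returns an extra tensor: because training_epoch_end is overridden, every returned dict is kept until the hook runs at the end of the epoch, which is exactly the accumulate-then-drop pattern described above.

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.nn import functional as F


class AccumulationDemo(pl.LightningModule):
    """Toy module showing step outputs being held until the epoch ends."""

    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(8, 2)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

    def training_step(self, batch, batch_idx):
        x_batch, y_batch = batch
        logits = self.layer(x_batch)
        loss = F.cross_entropy(logits, y_batch)
        # Everything in this dict is collected for training_epoch_end, so the
        # per-batch `logits` tensors pile up in memory over the whole epoch.
        return {'loss': loss, 'logits': logits.detach()}

    def training_epoch_end(self, outputs):
        # Only here are the collected outputs reduced and then released.
        avg_loss = torch.stack([o['loss'] for o in outputs]).mean()
        self.log('avg_loss', avg_loss, prog_bar=True)
```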
🐛 Bug
track_epoch_end_reduce_metrics leads to an increase in memory consumption between epochs.
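A minimal, hypothetical callback (not part of Lightning, assuming CUDA is used) to make the between-epoch growth visible; the *args soaks up the extra outputs argument that some Lightning versions pass to on_train_epoch_end.

```python
import pytorch_lightning as pl
import torch


class EpochMemoryMonitor(pl.Callback):
    """Print allocated CUDA memory at every training-epoch boundary."""

    @staticmethod
    def _report(tag, trainer):
        if torch.cuda.is_available():
            mb = torch.cuda.memory_allocated() / 1024 ** 2
            print(f'epoch {trainer.current_epoch} {tag}: {mb:.1f} MiB allocated')

    def on_train_epoch_start(self, trainer, pl_module):
        self._report('start', trainer)

    def on_train_epoch_end(self, trainer, pl_module, *args):
        self._report('end', trainer)
```

Passing it as `Trainer(callbacks=[EpochMemoryMonitor()])` should show whether the allocated memory really keeps climbing from one epoch to the next.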