Retrieval Metrics GPU Memory Leak #2481
Comments
Hi! Thanks for your contribution, great first issue!
Hello @astirn. I think the behaviour is correct, as Retrieval Metrics are designed only to be globally valid and to be computed on the whole dataset, so they accumulate results until the end of the epoch. Please try the following in your `LightningModule`:

```python
def training_step(self, batch, batch_idx):
    ...
    self.train_retrieval_metrics.update(logits, targets, indexes)

def on_train_epoch_end(self):
    metrics = self.train_retrieval_metrics.compute()
```

With this modification, you will get correct metrics and compute them only once, at the end of the training epoch. You may also consider moving the results to the CPU. Another case is if your indexes are not sharded across several batches (basically, you have no overlapping queries between different batches). Then you could use the functional version instead:

```python
def training_step(self, batch, batch_idx):
    ...
    mrr = retrieval_reciprocal_rank(logits, targets)
```

Be aware that the latter example assumes the indexes of each batch do not overlap with those of any other batch.
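For the first approach, a sketch of what the epoch-end hook could look like with the results moved off the GPU and the accumulated state cleared. The `last_train_metrics` attribute is only an illustration, and this assumes `train_retrieval_metrics` is a `MetricCollection` whose `compute()` returns a dict:

```python
def on_train_epoch_end(self):
    # compute once per epoch over everything accumulated by update()
    metrics = self.train_retrieval_metrics.compute()
    self.log_dict(metrics)
    # if you keep the values around, detach them and move them off the GPU
    self.last_train_metrics = {k: v.detach().cpu() for k, v in metrics.items()}
    # clear the accumulated per-batch state so it does not grow across epochs
    self.train_retrieval_metrics.reset()
```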
Thanks for your reply. I believe Lightning already makes this call. IIRC, the problem I had is that GPU memory usage increases over epochs, not just within an epoch as your reply suggests it should. I wrote my own MRR code that solves this problem and have been using it since I filed this issue. Feel free to close this issue :)
Could you please share your implementation of MRR so we can understand whether there is a bug on our end?
@astirn was it resolved, or could you share sample code to reproduce it, as @lucadiliello suggested?
Unfortunately, I do not have time to ablate proprietary code to share with you. I tried to give enough information in my original post. Please feel free to close the issue if you don't think what I found is a problem.
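For anyone hitting the same behaviour, one stateless way to get a per-batch MRR in plain PyTorch, keeping nothing alive on the GPU between steps, is sketched below. It assumes the in-batch-negatives setup where row `i`'s positive is column `i`; this is only an illustration, not the reporter's actual code:

```python
import torch

def batch_mrr(logits: torch.Tensor) -> torch.Tensor:
    """Mean reciprocal rank for a (B, B) logits matrix with positives on the diagonal."""
    positives = logits.diagonal().unsqueeze(1)             # (B, 1) score of each row's positive
    # rank of the positive within its row (1 = best); ties are counted pessimistically
    ranks = (logits >= positives).sum(dim=1).clamp(min=1)  # (B,)
    return (1.0 / ranks.float()).mean()
```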
🐛 Bug

Updating a `RetrievalMetric` in training and/or validation results in runaway GPU memory usage.

To Reproduce
In a `LightningModule`'s `def __init__(...)`, declare the retrieval metrics. Then in `training_step(self, batch, batch_idx)` and `validation_step(self, batch, batch_idx)`, respectively call the training and validation metric updates with `(logits, targets, indexes)`, where `logits` is a batch size x batch size matrix of logits. You can assume `targets` and `indexes` are correct, since I am getting good MRR measurements.
I believe I tracked down the problem. In `retrieval/base.py`, the `RetrievalMetric` class initializes its states using Python lists. Because these are lists, in the `Metric` class's `reset(self)` in `metric.py`, the line `setattr(self, attr, default.detach().clone().to(current_val.device))` is never reached.
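A simplified paraphrase of the reset logic I am referring to (from memory, not the exact `metric.py` source) shows why list states never hit that line:

```python
# simplified paraphrase of torchmetrics Metric.reset()
def reset(self):
    for attr, default in self._defaults.items():
        current_val = getattr(self, attr)
        if isinstance(default, torch.Tensor):
            # tensor states are re-created on the original device
            setattr(self, attr, default.detach().clone().to(current_val.device))
        else:
            # list states (what RetrievalMetric uses) take this branch instead
            setattr(self, attr, [])
```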
Expected behavior

GPU memory does not increase over time.
Environment
Additional context