Retrieval Metrics GPU Memory Leak #2481

Closed
astirn opened this issue Mar 29, 2024 · 6 comments
Labels
bug / fix Something isn't working help wanted Extra attention is needed v1.3.x waiting on author

Comments

@astirn

astirn commented Mar 29, 2024

🐛 Bug

Updating a RetrievalMetric in training and/or validation results in runaway GPU memory usage.

To Reproduce

In a LightningModule's def __init__(...), declare

```python
import torchmetrics as tm  # at module level

metrics = tm.MetricCollection({
    'mrr': tm.RetrievalMRR()
})
self.train_retrieval_metrics = metrics.clone(prefix='train_')
self.val_retrieval_metrics = metrics.clone(prefix='val_')
```

Then, in training_step(self, batch, batch_idx) and validation_step(self, batch, batch_idx) respectively, call

```python
self.train_retrieval_metrics(logits, targets, indexes)
```

and

```python
self.val_retrieval_metrics(logits, targets, indexes)
```

where logits is a batch_size x batch_size matrix of logits. You can assume targets and indexes are correct, since I am getting good MRR measurements.
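For anyone trying to reproduce this, below is a minimal sketch of one common way to flatten a batch_size x batch_size in-batch logits matrix into the parallel preds, target, and indexes values that a retrieval metric groups by query. This is an assumption, since the issue does not show how targets and indexes were built: row i is treated as query i, and the diagonal entry as its only relevant document. Plain Python lists stand in for tensors, and the helper name flatten_in_batch is hypothetical.

```python
def flatten_in_batch(logits, query_offset=0):
    """Flatten a B x B score matrix into parallel preds/target/indexes lists.

    Assumption (not from the issue): row i holds the scores of query i against
    every document in the batch, and document i is the only relevant one.
    `query_offset` keeps query ids unique across batches, because retrieval
    metrics group everything they have accumulated by the `indexes` value.
    """
    preds, target, indexes = [], [], []
    for i, row in enumerate(logits):
        for j, score in enumerate(row):
            preds.append(score)
            target.append(i == j)  # diagonal entry is the positive
            indexes.append(query_offset + i)
    return preds, target, indexes


logits = [[2.0, 0.5], [0.1, 1.5]]
preds, target, indexes = flatten_in_batch(logits, query_offset=4)
print(indexes)  # [4, 4, 5, 5]
print(target)   # [True, False, False, True]
```

If every batch reused the ids 0..B-1, queries from different batches would be merged into one group once their states are accumulated over an epoch, which is why the sketch takes an offset (for example batch_idx * B).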

I believe I tracked down the problem. In retrieval/base.py, the RetrievalMetric class initializes

```python
self.add_state("indexes", default=[], dist_reduce_fx=None)
self.add_state("preds", default=[], dist_reduce_fx=None)
self.add_state("target", default=[], dist_reduce_fx=None)
```

using Python lists. Because these are lists, in the Metric class's reset(self) in metric.py

```python
for attr, default in self._defaults.items():
    current_val = getattr(self, attr)
    if isinstance(default, Tensor):
        setattr(self, attr, default.detach().clone().to(current_val.device))
    else:
        setattr(self, attr, [])
```

the setattr(self, attr, default.detach().clone().to(current_val.device)) branch is never reached.
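To make the quoted loop concrete, here is a toy, torch-free imitation of the add_state/reset pattern. The class ToyMetric is hypothetical and only mirrors the loop above, not the full torchmetrics implementation. It shows what the else branch does for list states: each one is rebound to a fresh empty list, so whether the old items are actually freed depends on nothing else still referencing them.

```python
class ToyMetric:
    """Torch-free imitation of the add_state/reset loop quoted above (hypothetical)."""

    def __init__(self):
        self._defaults = {}

    def add_state(self, name, default):
        self._defaults[name] = default
        setattr(self, name, list(default) if isinstance(default, list) else default)

    def update(self, batch_preds):
        self.preds.append(batch_preds)

    def reset(self):
        for attr, default in self._defaults.items():
            if not isinstance(default, list):
                # Stands in for the Tensor branch of the real loop
                setattr(self, attr, default)
            else:
                # List states are rebound to a brand-new empty list; the old
                # list is released only if nothing else references it.
                setattr(self, attr, [])


m = ToyMetric()
m.add_state("preds", default=[])
for step in range(3):
    m.update([0.9, 0.1])
old_state = m.preds  # an outside reference keeps the old list alive
m.reset()
print(len(old_state), len(m.preds))  # 3 0
```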

Expected behavior

GPU memory does not increase over time.

Environment

  • TorchMetrics version: 1.3.2
  • Python & PyTorch Version: 3.11.8 & 2.2.1

Additional context

@astirn astirn added bug / fix Something isn't working help wanted Extra attention is needed labels Mar 29, 2024

Hi! Thanks for your contribution, great first issue!

@lucadiliello
Contributor

Hello @astirn. I think the behaviour is correct: retrieval metrics are designed to be valid only globally, i.e. computed over the whole dataset, so they accumulate results until the end of the epoch.

Please try the following in your LightningModule:

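```python
def training_step(self, batch, batch_idx):
    ...
    # Only accumulate here; no per-batch compute
    self.train_retrieval_metrics.update(logits, targets, indexes)

def on_train_epoch_end(self):
    metrics = self.train_retrieval_metrics.compute()
    self.log_dict(metrics)
    # compute() does not clear the accumulated states; reset them so they
    # do not keep growing from one epoch to the next
    self.train_retrieval_metrics.reset()
```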

With this modification, you will get correct metrics, computed only once at the end of each training epoch. If training is long (e.g. pre-training), consider moving the accumulated results to CPU to save GPU memory: just add compute_on_cpu=True when instantiating the metrics.
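The accumulate/compute/reset lifecycle suggested above can be illustrated with a plain-Python stand-in. The class ListAccumulator is hypothetical, not a torchmetrics API. Like the list states of RetrievalMetric, it grows by one batch's worth of entries per update call until reset runs. Since compute() in torchmetrics does not reset the state on its own, calling compute() manually at epoch end without reset() lets the state keep growing across epochs, which is one possible explanation for memory that rises over epochs rather than only within one.

```python
class ListAccumulator:
    """Hypothetical stand-in for a metric whose states are Python lists."""

    def __init__(self):
        self.preds = []

    def update(self, batch_preds):
        self.preds.extend(batch_preds)  # state grows with every batch

    def compute(self):
        return len(self.preds)  # placeholder: report how many scores are held

    def reset(self):
        self.preds = []


def run_epochs(metric, num_epochs, batches_per_epoch, reset_each_epoch):
    sizes = []
    for _ in range(num_epochs):
        for _ in range(batches_per_epoch):
            metric.update([0.0] * 8)  # 8 scores per batch
        sizes.append(metric.compute())
        if reset_each_epoch:
            metric.reset()
    return sizes


print(run_epochs(ListAccumulator(), 3, 10, reset_each_epoch=True))   # [80, 80, 80]
print(run_epochs(ListAccumulator(), 3, 10, reset_each_epoch=False))  # [80, 160, 240]
```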

Another option applies if your queries do not shard across several batches (basically, no query has entries in more than one batch). In that case, you could use the functional metrics directly:

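```python
from torchmetrics.functional.retrieval import retrieval_reciprocal_rank

def training_step(self, batch, batch_idx):
    ...
    # Scores the whole batch as a single query (see the note below)
    mrr = retrieval_reciprocal_rank(logits, targets)
```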

Be aware that the latter example treats each batch as a single query, i.e. it assumes the indexes would be [0, ..., 0] for batch_0, [1, ..., 1] for batch_1, and so on.
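The difference between the two approaches comes down to how scores are grouped into queries. Below is a hedged, pure-Python sketch of MRR; reciprocal_rank and grouped_mrr are hypothetical helpers, not the torchmetrics implementation. The grouped version mimics what the modular metric does (group accumulated scores by indexes, then average one reciprocal rank per query), while scoring each batch on its own mimics the functional call. The two agree only when each batch holds exactly one query.

```python
from collections import defaultdict


def reciprocal_rank(preds, target):
    """Reciprocal rank of the first relevant document for ONE query (0.0 if none)."""
    order = sorted(range(len(preds)), key=lambda k: preds[k], reverse=True)
    for rank, k in enumerate(order, start=1):
        if target[k]:
            return 1.0 / rank
    return 0.0


def grouped_mrr(preds, target, indexes):
    """Group scores by query id, then average one reciprocal rank per query."""
    groups = defaultdict(lambda: ([], []))
    for p, t, q in zip(preds, target, indexes):
        groups[q][0].append(p)
        groups[q][1].append(t)
    return sum(reciprocal_rank(p, t) for p, t in groups.values()) / len(groups)


# Two batches, each holding exactly one query (indexes [0, 0, 0] and [1, 1, 1]).
batches = [
    ([0.9, 0.2, 0.4], [False, False, True]),  # positive ranked 2nd -> 1/2
    ([0.1, 0.8, 0.3], [False, True, False]),  # positive ranked 1st -> 1
]
per_batch = [reciprocal_rank(p, t) for p, t in batches]
all_preds = [x for p, _ in batches for x in p]
all_target = [x for _, t in batches for x in t]
all_indexes = [0, 0, 0, 1, 1, 1]
print(sum(per_batch) / len(per_batch))                   # 0.75
print(grouped_mrr(all_preds, all_target, all_indexes))  # 0.75
```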

@astirn
Author

astirn commented May 20, 2024

Thanks for your reply. I believe Lightning AI already makes this call. IIRC, the problem I had is that GPU memory usage increases over epochs, not just within an epoch as your reply suggests it should.

I wrote my own MRR code that solves this problem and have been using it since I filed this issue. Feel free to close this issue :)
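The author's implementation was not shared, so the following is only a hypothetical illustration of one way to keep MRR memory constant. If every query's candidates arrive within a single batch (the same assumption as the functional approach above), each batch can be reduced to a running sum of reciprocal ranks plus a query count, so nothing grows with the number of batches or epochs. The class StreamingMRR is invented for this sketch, and it again assumes a B x B score matrix whose diagonal holds the positives.

```python
class StreamingMRR:
    """Hypothetical constant-memory MRR: keeps only a running sum and a count.

    Assumes each row of a B x B score matrix is one query whose only relevant
    document is the diagonal entry, and that no query spans several batches.
    """

    def __init__(self):
        self.rr_sum = 0.0
        self.num_queries = 0

    def update(self, logits):
        for i, row in enumerate(logits):
            # Rank of the positive = 1 + number of strictly higher scores
            # (ties are resolved in the positive's favour).
            rank = 1 + sum(1 for score in row if score > row[i])
            self.rr_sum += 1.0 / rank
            self.num_queries += 1

    def compute(self):
        return self.rr_sum / self.num_queries if self.num_queries else 0.0

    def reset(self):
        self.rr_sum = 0.0
        self.num_queries = 0


mrr = StreamingMRR()
mrr.update([[2.0, 0.5], [1.0, 0.3]])  # query 0 ranks its positive 1st, query 1 ranks it 2nd
print(mrr.compute())  # 0.75
```

The trade-off is that this only works when queries never span batches; the accumulate-then-compute design of the modular metrics exists precisely to handle queries whose candidates are spread over many batches.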

@lucadiliello
Contributor

lucadiliello commented May 21, 2024

Could you please share your implementation of MRR to understand whether there is a bug on our end?

@Borda
Member

Borda commented Aug 2, 2024

@astirn Was it resolved, or could you share sample code to reproduce it, as @lucadiliello suggested?

@astirn
Author

astirn commented Aug 6, 2024

Unfortunately, I do not have time to ablate proprietary code to share with you. I tried to give enough information in my original post. Please feel free to close the issue if you don't think what I found is a problem.

@astirn astirn closed this as completed Aug 6, 2024
4 participants
@Borda @lucadiliello @astirn and others