Bounded memory leak caused by trainer.evaluation_loop.outputs #5735

Closed
roytseng-tw opened this issue Feb 1, 2021 · 0 comments · Fixed by #6326
Labels
bug Something isn't working help wanted Open to be worked on priority: 0 High priority task

@roytseng-tw

🐛 Bug

trainer.evaluation_loop.outputs caches the outputs of every validation step in the trainer's run_evaluation(self, max_batches=None) method:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/trainer.py#L659

It's not reset until the start of the next validation epoch:

https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/trainer.py#L621

https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/evaluation_loop.py#L124-L128
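The caching pattern described above can be sketched in plain Python. The EvalLoop class, its method names, and the Output stand-in for a GPU tensor are hypothetical simplifications, not Lightning's actual implementation. A weakref shows that a step's output stays alive after the validation epoch ends, through the following training epoch, until the next validation epoch resets the list.

```python
import weakref


class Output:
    """Hypothetical stand-in for a tensor returned by validation_step."""


class EvalLoop:
    """Hypothetical, simplified model of the caching described in the issue."""

    def __init__(self):
        self.outputs = []

    def on_evaluation_epoch_start(self):
        # The only place the cache is reset: the start of the *next* epoch
        self.outputs = []

    def evaluation_step(self):
        out = Output()
        self.outputs.append(out)  # every step's output is cached
        return out

    def on_evaluation_epoch_end(self):
        pass  # outputs are consumed elsewhere, but the cache is not cleared


loop = EvalLoop()
loop.on_evaluation_epoch_start()
ref = weakref.ref(loop.evaluation_step())
loop.on_evaluation_epoch_end()

# Training resumes here: the output is still referenced by loop.outputs
print("alive during training:", ref() is not None)

loop.on_evaluation_epoch_start()  # next validation epoch
print("alive after reset:", ref() is not None)
```

Because the list is only replaced at the start of the next validation epoch, the retained memory is bounded by one epoch's worth of outputs, which matches the "bounded" leak in the title.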

Please reproduce using the BoringModel

To Reproduce

Sorry, my working environment doesn't allow me to use Google Drive.

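import torch
import pytorch_lightning as pl
from pytorch_lightning import LightningModule
from torch.utils.data import DataLoader, Dataset


class RandomDataset(Dataset):
    # Random data as in the BoringModel template (assumed; not included in the
    # original snippet, added here so that `train` and `val` below are defined)
    def __init__(self, size, length):
        self.len = length
        self.data = torch.randn(length, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def loss(self, batch, prediction):
        # An arbitrary loss to have a loss that updates the model weights during `Trainer.fit` calls
        return torch.nn.functional.mse_loss(prediction, torch.ones_like(prediction))

    def on_train_epoch_start(self):
        # validation_step outputs are still referenced by the trainer here;
        # deleting them shows how much GPU memory they occupy
        print('Before delete:', torch.cuda.memory_allocated())
        for out in self.trainer.evaluation_loop.outputs[0]:
            if 'x' in out:
                del out['x']
        print('After delete:', torch.cuda.memory_allocated())

    def training_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        return {"loss": loss}

    def training_step_end(self, training_step_outputs):
        return training_step_outputs

    def training_epoch_end(self, outputs) -> None:
        torch.stack([x["loss"] for x in outputs]).mean()

    def validation_step(self, batch, batch_idx):
        output = self.layer(batch)
        loss = self.loss(batch, output)
        # This returned tensor is what ends up cached in evaluation_loop.outputs
        return {"x": loss}

    def validation_epoch_end(self, outputs) -> None:
        torch.stack([x['x'] for x in outputs]).mean()

    def configure_optimizers(self):
        optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
        return [optimizer], [lr_scheduler]


def test_x(tmpdir):
    # BoringModel template data (assumed): 10000 samples, batch size 32,
    # i.e. 313 train + 313 val batches = the 626 steps shown in the output
    num_samples = 10000
    train = DataLoader(RandomDataset(32, num_samples), batch_size=32)
    val = DataLoader(RandomDataset(32, num_samples), batch_size=32)

    # init model
    model = BoringModel()

    # Initialize a trainer
    trainer = pl.Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        progress_bar_refresh_rate=20,
        gpus=[0],
    )

    # Train the model ⚡
    trainer.fit(model, train, val)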

Execution

test_x(tmpdir=".")

Output

GPU available: True, used: True
TPU available: None, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name  | Type   | Params
---------------------------------
0 | layer | Linear | 66    
---------------------------------
66        Trainable params
0         Non-trainable params
66        Total params
Validation sanity check: 0%
0/2 [00:00<?, ?it/s]
Epoch 0: 100%
626/626 [00:00<00:00, 653.87it/s, loss=2.5e-14, v_num=6]
Before delete: 2048
After delete: 1024
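The two printed numbers are consistent with the following accounting. It rests on two assumptions not stated in the issue: PyTorch's CUDA caching allocator rounds each allocation up to a multiple of 512 bytes, and at the start of epoch 0 the only GPU tensors are the Linear(32, 2) parameters plus the two scalar losses cached by the 2-batch validation sanity check.

```python
BLOCK = 512   # assumed minimum allocation granularity of the CUDA caching allocator
FLOAT32 = 4   # bytes per float32 element


def allocated(numel):
    """Bytes counted for one float32 tensor of `numel` elements (rounded up)."""
    nbytes = numel * FLOAT32
    return -(-nbytes // BLOCK) * BLOCK  # ceil to a multiple of BLOCK


params = allocated(32 * 2) + allocated(2)  # Linear(32, 2): weight + bias
cached = 2 * allocated(1)                  # two scalar 'x' losses from the sanity check

before = params + cached
after = params
print(before, after)  # 2048 1024
```

Under the same assumptions, a full validation epoch of this repro (313 batches) would leave about 313 × 512 = 160,256 bytes of scalar losses cached during the next training epoch; larger outputs such as per-batch predictions would scale this up accordingly.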

Expected behavior

There shouldn't be any such cached tensors once the validation epoch ends.
They can cause an out-of-memory (OOM) error in cases where OOM could otherwise be avoided.
For example,

  • On the first training epoch, a model that just fits in GPU memory runs without OOM.
  • After the first validation epoch, some GPU tensors are retained and occupy part of the memory.
  • On the second training epoch, the same model hits an OOM error.

All references to those validation output tensors should be cleared at the end of the validation epoch.
Possibly, more specifically, here:
https://github.com/PyTorchLightning/pytorch-lightning/blob/d71659b42a13946b854d49f5bb1bf6e2bcd5b9b2/pytorch_lightning/trainer/evaluation_loop.py#L224-L229
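A minimal sketch of the suggested fix, using a hypothetical, simplified loop (not Lightning's real evaluation loop API, and not necessarily how #6326 resolved it): hand the cached outputs to the epoch-end hook, then drop the loop's own reference in a finally block so that nothing survives into the next training epoch.

```python
import weakref


class Output:
    """Hypothetical stand-in for a tensor returned by validation_step."""


class EvalLoop:
    """Hypothetical loop that releases cached outputs at epoch end."""

    def __init__(self, epoch_end_hook):
        self.outputs = []
        self.epoch_end_hook = epoch_end_hook  # plays the role of validation_epoch_end

    def evaluation_step(self):
        out = Output()
        self.outputs.append(out)
        return out

    def on_evaluation_epoch_end(self):
        try:
            self.epoch_end_hook(self.outputs)
        finally:
            # Drop the loop's reference even if the hook raises
            self.outputs = []


received = []
loop = EvalLoop(epoch_end_hook=lambda outs: received.append(len(outs)))
refs = [weakref.ref(loop.evaluation_step()) for _ in range(3)]
loop.on_evaluation_epoch_end()

print(received)                        # [3]
print(all(r() is None for r in refs))  # True
```

This only removes the trainer-side reference: if a user's validation_epoch_end stores the outputs itself (for example on self), they stay alive regardless.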

Environment

  • CUDA:
    • GPU:
      • Tesla T4
    • available: True
    • version: 10.1
  • Packages:
    • numpy: 1.19.5
    • pyTorch_debug: True
    • pyTorch_version: 1.7.0+cu101
    • pytorch-lightning: 1.1.6
    • tqdm: 4.41.1
  • System:
    • OS: Linux
    • architecture:
      • 64bit
    • processor: x86_64
    • python: 3.6.9
    • version: #1 SMP Thu Jul 23 08:00:38 PDT 2020

Additional context

@roytseng-tw roytseng-tw added bug Something isn't working help wanted Open to be worked on labels Feb 1, 2021
@edenlightning edenlightning added the priority: 0 High priority task label Feb 9, 2021
@tchaton tchaton assigned tchaton and unassigned kaushikb11 Mar 3, 2021