
Out-of-memory on GPU due to the "weak_script" decorators #20588

Closed
zhangguanheng66 opened this issue May 16, 2019 · 5 comments
Assignees: zhangguanheng66
Labels: high priority, oncall: jit (Add this issue/PR to JIT oncall triage queue)

Comments

@zhangguanheng66 (Contributor) commented May 16, 2019

🐛 Bug

The issue has been resolved by a recently merged PR (#20563). This report is kept for the record and for future benchmarking. The issue is related to the local scope of a weak-scripted function, which causes a memory leak.

We hit an out-of-memory error when running the nn.MultiheadAttention module on CUDA. This started after we split the forward function of the nn.MultiheadAttention module and moved the main computation into torch.nn.functional.py.

To fix the issue in the merged PR, we had to remove the "weak_script" decorators from the multi_head_attention_forward() function.

To Reproduce

Steps to reproduce the behavior:

  1. Make sure you are on commit "6e82b1c77d36386ba738af3287693105b4bbafe2"
  2. Use the following script on a GPU to reproduce the OOM error message.

import torch
import torch.nn as nn

d_model = 512
nhead = 16
bptt = 10
batch_size = 15
device = torch.device("cuda")

norm = nn.LayerNorm(d_model).to(device)
self_attn = nn.MultiheadAttention(d_model, nhead).to(device)
src_seq = torch.rand((bptt, batch_size, d_model)).to(device)

for _ in range(200000):
    src = norm(src_seq)
    output = self_attn(src, src, src)

Expected behavior

RuntimeError: CUDA out of memory. Tried to allocate 2.00 MiB (GPU 0; 15.90 GiB total capacity; 13.49 GiB already allocated; 1.56 MiB free; 1.87 GiB cached)
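
As a side note, here is a minimal sketch for watching the leak per iteration (not part of the original report; the iteration count and print interval are arbitrary). On an affected build, torch.cuda.memory_allocated() should keep growing across iterations; on a fixed build it should plateau.

import torch
import torch.nn as nn

d_model, nhead, bptt, batch_size = 512, 16, 10, 15
device = torch.device("cuda")

norm = nn.LayerNorm(d_model).to(device)
self_attn = nn.MultiheadAttention(d_model, nhead).to(device)
src_seq = torch.rand((bptt, batch_size, d_model)).to(device)

for i in range(2000):
    src = norm(src_seq)
    output, _ = self_attn(src, src, src)
    if i % 200 == 0:
        # On an affected build this value keeps growing; on a fixed build it plateaus.
        print(i, torch.cuda.memory_allocated(device))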

Environment

Collecting environment information...
PyTorch version: 1.1.0a0+1d33ab8
Is debug build: No
CUDA used to build PyTorch: 9.2.88

OS: Ubuntu 18.04.1 LTS
GCC version: (Ubuntu 7.3.0-27ubuntu1~18.04) 7.3.0
CMake version: version 3.12.2

Python version: 3.7
Is CUDA available: Yes
CUDA runtime version: 9.2.88
GPU models and configuration:
GPU 0: Quadro GP100
GPU 1: Quadro GP100

Nvidia driver version: 410.79
cuDNN version: Could not collect

Versions of relevant libraries:
[pip] numpy==1.15.4
[pip] numpydoc==0.8.0
[pip] torch==1.1.0a0+1d33ab8
[conda] blas 1.0 mkl
[conda] magma-cuda90 2.5.0 1 pytorch
[conda] mkl 2019.1 144
[conda] mkl-include 2019.3 199
[conda] mkl-service 1.1.2 py37he904b0f_5
[conda] mkl_fft 1.0.6 py37hd81dba3_0
[conda] mkl_random 1.0.2 py37hd81dba3_0
[conda] torch 1.1.0a0+1d33ab8 dev_0

Additional context

@zhangguanheng66 zhangguanheng66 self-assigned this May 16, 2019
@cpuhrsch cpuhrsch added high priority oncall: jit Add this issue/PR to JIT oncall triage queue labels May 16, 2019
@zhangguanheng66 (Contributor, Author)

Another issue, related to JIT when using the torch.nn.MultiheadAttention module: #20722

@zhangguanheng66 (Contributor, Author)

@suo We are currently experiencing the JIT issue while implementing the torch.nn.MultiheadAttention module. Although we have submitted a PR to work around it, it would be very helpful to resolve it, since our transformer model/PR depends heavily on the torch.nn.MultiheadAttention module. There is another issue related to TorchScript: #20722.

@ailzhang (Contributor)

@zhangguanheng66 is it easy to get a repro from master? (Do I have to revert a lot?)

@zhangguanheng66 (Contributor, Author)

@ailzhang if you are on the master branch, simply add the "@weak_script" decorators back to the internal functions (see multi_head_attention_forward in torch/nn/functional.py) and you should be able to reproduce the bug. We also noticed a numerical discrepancy in the JIT model and will ask for help in another PR.
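
For illustration only, a toy sketch of that edit, assuming the torch._jit_internal.weak_script API that existed at the pinned commit; the helper below is a hypothetical stand-in, not the real multi_head_attention_forward from torch/nn/functional.py.

import torch
from torch._jit_internal import weak_script  # assumption: import path as of the pinned commit

@weak_script  # re-adding this decorator to the real helper is what reproduces the leak
def scaled_scores(q, k):
    # Toy stand-in for the weak-scripted internals of multi_head_attention_forward.
    return torch.bmm(q, k.transpose(1, 2)) / (q.size(-1) ** 0.5)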

@ailzhang ailzhang removed their assignment Jun 28, 2019
@driazati (Contributor) commented Jul 3, 2019

#22212 deletes all the weak script decorators, so that should fix this issue
