Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixes crash in deformable convolutions (2598) #2604

Merged
merged 3 commits into from
Aug 24, 2020
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 29 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -553,6 +553,35 @@ def script_func(x_, offset_, weight_, bias_, stride_, pad_, dilation_):
gradcheck(lambda z, off, wei, bi: script_func(z, off, wei, bi, stride, padding, dilation),
(x, offset, weight, bias), nondet_tol=1e-5)

# Test from https://github.com/pytorch/vision/issues/2598
# Run on CUDA only
if "cuda" in device.type:
# compare grads computed on CUDA with grads computed on CPU
true_cpu_grads = None

init_weight = torch.randn(9, 9, 3, 3, requires_grad=True)
img = torch.randn(8, 9, 1000, 110)
offset = torch.rand(8, 2 * 3 * 3, 1000, 110)

if not contiguous:
img = img.permute(0, 1, 3, 2).contiguous().permute(0, 1, 3, 2)
offset = offset.permute(1, 3, 0, 2).contiguous().permute(2, 0, 3, 1)
weight = init_weight.permute(3, 2, 0, 1).contiguous().permute(2, 3, 1, 0)
else:
weight = init_weight

for d in ["cpu", "cuda"]:

out = ops.deform_conv2d(img.to(d), offset.to(d), weight.to(d), padding=1)
out.mean().backward()
if true_cpu_grads is None:
true_cpu_grads = init_weight.grad
self.assertTrue(true_cpu_grads is not None)
else:
self.assertTrue(init_weight.grad is not None)
res_grads = init_weight.grad.to("cpu")
self.assertTrue(true_cpu_grads.allclose(res_grads))


class FrozenBNTester(unittest.TestCase):
def test_frozenbatchnorm2d_repr(self):
Expand Down
7 changes: 3 additions & 4 deletions torchvision/csrc/cuda/DeformConv_cuda.cu
Original file line number Diff line number Diff line change
Expand Up @@ -80,12 +80,11 @@

using namespace at;

const int CUDA_NUM_THREADS = 1024;
const int kMaxGridNum = 65535;

const unsigned int CUDA_NUM_THREADS = 1024;
const unsigned int kMaxGridNum = at::cuda::getCurrentDeviceProperties()->maxGridSize[0];
vfdev-5 marked this conversation as resolved.
Show resolved Hide resolved
const int kMaxParallelImgs = 32;

inline int GET_BLOCKS(const int N) {
inline unsigned int GET_BLOCKS(const unsigned int N) {
return std::min(kMaxGridNum, (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS);
}

Expand Down