[1/2] Added backward pass on CPU for interpolation with anti-alias option (#4208)

* WIP on backward op interpolation with AA

* Removed cuda tests and reformat cpp code

* Fixed clang wrong formatting

* Added channels last test case

Co-authored-by: vfdev-5 <vfdev-5@gmail.com>
Co-authored-by: Francisco Massa <fvsmassa@gmail.com>
3 people authored Jul 28, 2021
1 parent 30fd10b commit e2dbadb
Showing 6 changed files with 407 additions and 89 deletions.
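
For context, the sketch below (not part of this commit) shows how the new CPU backward kernels are intended to be reached from user code: gradients flowing through an anti-aliased tensor resize. It assumes torchvision's public torchvision.transforms.functional.resize API with antialias=True, which dispatches to the _interpolate_*2d_aa ops for bilinear/bicubic modes; since this is part 1 of 2, the full autograd wiring may only be complete once the follow-up lands.

# Hedged sketch (assumption, not part of this commit): backprop through an
# anti-aliased resize on CPU.
import torch
import torchvision.transforms.functional as TF
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 64, 64, requires_grad=True)  # CPU tensor
y = TF.resize(x, size=[10, 7], interpolation=InterpolationMode.BILINEAR, antialias=True)
y.sum().backward()  # exercises the anti-aliased interpolation backward
print(x.grad.shape)  # torch.Size([1, 3, 64, 64])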
47 changes: 47 additions & 0 deletions test/test_functional_tensor.py
@@ -1,3 +1,4 @@
from functools import partial
import itertools
import os
import colorsys
@@ -578,6 +579,52 @@ def test_assert_resize_antialias(interpolation):
F.resize(tensor, size=(5, 5), interpolation=interpolation, antialias=True)


@pytest.mark.parametrize('dt', [torch.float32, torch.float64, torch.float16])
@pytest.mark.parametrize('size', [[10, 7], [10, 42], [42, 7]])
@pytest.mark.parametrize('interpolation', [BILINEAR, BICUBIC])
def test_interpolate_antialias_backward(dt, size, interpolation):

    # temporarily hard-code device as CPU, CUDA support will be done later
    device = "cpu"

    if dt == torch.float16 and device == "cpu":
        # skip float16 on CPU case
        return

    torch.manual_seed(12)
    if interpolation == BILINEAR:
        forward_op = torch.ops.torchvision._interpolate_bilinear2d_aa
        backward_op = torch.ops.torchvision._interpolate_bilinear2d_aa_backward
    elif interpolation == BICUBIC:
        forward_op = torch.ops.torchvision._interpolate_bicubic2d_aa
        backward_op = torch.ops.torchvision._interpolate_bicubic2d_aa_backward

    class F(torch.autograd.Function):

        @staticmethod
        def forward(ctx, i):
            result = forward_op(i, size, False)
            ctx.save_for_backward(i, result)
            return result

        @staticmethod
        def backward(ctx, grad_output):
            i, result = ctx.saved_tensors
            ishape = i.shape
            oshape = result.shape[2:]
            return backward_op(grad_output, oshape, ishape, False)

    x = (
        torch.rand(1, 32, 29, 3, dtype=torch.double, device=device).permute(0, 3, 1, 2).requires_grad_(True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)

    x = (
        torch.rand(1, 3, 32, 29, dtype=torch.double, device=device, requires_grad=True),
    )
    assert torch.autograd.gradcheck(F.apply, x, eps=1e-8, atol=1e-6, rtol=1e-6, fast_mode=False)


def check_functional_vs_PIL_vs_scripted(fn, fn_pil, fn_t, config, device, dtype, tol=2.0 + 1e-10, agg_method="max"):

script_fn = torch.jit.script(fn)
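
For reference, the new backward kernel can also be invoked directly, outside autograd. The sketch below is an assumption based only on the signatures exercised in the test added above: the backward op takes the output gradient, the spatial output size, the full input shape, and the align_corners flag, and returns a gradient shaped like the input.

# Sketch (assumption: internal op names/signatures as used in the test above).
import torch
import torchvision  # importing torchvision registers the torch.ops.torchvision ops

inp = torch.rand(1, 3, 32, 29, dtype=torch.float64)
out = torch.ops.torchvision._interpolate_bilinear2d_aa(inp, [10, 7], False)
grad_out = torch.ones_like(out)
grad_in = torch.ops.torchvision._interpolate_bilinear2d_aa_backward(
    grad_out, [10, 7], list(inp.shape), False
)
print(grad_in.shape)  # torch.Size([1, 3, 32, 29]), matching the input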