Skip to content

Commit

Permalink
Merge branch 'main' of github.com:pytorch/vision into pyav_iobase
Browse files Browse the repository at this point in the history
  • Loading branch information
NicolasHug committed May 29, 2024
2 parents 477c49f + 45e053b commit 7359199
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 14 deletions.
4 changes: 2 additions & 2 deletions references/classification/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,9 @@ def get_mixup_cutmix(*, mixup_alpha, cutmix_alpha, num_classes, use_v2):
)
if cutmix_alpha > 0:
mixup_cutmix.append(
transforms_module.CutMix(alpha=mixup_alpha, num_classes=num_classes)
transforms_module.CutMix(alpha=cutmix_alpha, num_classes=num_classes)
if use_v2
else RandomCutMix(num_classes=num_classes, p=1.0, alpha=mixup_alpha)
else RandomCutMix(num_classes=num_classes, p=1.0, alpha=cutmix_alpha)
)
if not mixup_cutmix:
return None
Expand Down
5 changes: 5 additions & 0 deletions test/test_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from common_utils import assert_equal, cpu_and_cuda, cpu_and_cuda_and_mps, needs_cuda, needs_mps
from PIL import Image
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.autograd import gradcheck
from torch.nn.modules.utils import _pair
from torchvision import models, ops
Expand Down Expand Up @@ -529,6 +530,10 @@ def test_autocast_cpu(self, aligned, deterministic, x_dtype, rois_dtype):
def test_backward(self, seed, device, contiguous, deterministic):
    # The deterministic variant is only exercised on devices where it is a
    # distinct code path; every other combination is skipped up front.
    if deterministic:
        if device == "cpu":
            pytest.skip("cpu is always deterministic, don't retest")
        if device == "mps":
            pytest.skip("no deterministic implementation for mps")
        if not is_compile_supported(device):
            pytest.skip("deterministic implementation only if torch.compile supported")
    super().test_backward(seed, device, contiguous, deterministic)

def _make_rois(self, img_size, num_imgs, dtype, num_rois=1000):
Expand Down
36 changes: 28 additions & 8 deletions torchvision/ops/roi_align.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import functools
from typing import List, Union

import torch
import torch._dynamo
import torch.fx
from torch import nn, Tensor
from torch._dynamo.utils import is_compile_supported
from torch.jit.annotations import BroadcastingList2
from torch.nn.modules.utils import _pair
from torchvision.extension import _assert_has_ops, _has_ops
Expand All @@ -12,6 +14,24 @@
from ._utils import check_roi_boxes_shape, convert_boxes_to_roi_format


def lazy_compile(**compile_kwargs):
    """Return a decorator that defers ``torch.compile`` until the first call.

    The decorated function is replaced by a thin hook. When the hook is
    invoked it compiles the original function with ``compile_kwargs``,
    rebinds this module's global of the same name to the compiled function,
    and forwards the call. This avoids eagerly importing dynamo.
    """

    def decorate_fn(fn):
        def compile_hook(*args, **kwargs):
            optimized = torch.compile(fn, **compile_kwargs)
            # Swap the module-level name so later global lookups reach the
            # compiled function directly instead of going through this hook.
            globals()[fn.__name__] = functools.update_wrapper(optimized, fn)
            return optimized(*args, **kwargs)

        # Carry over the original function's name, docstring, etc.
        return functools.update_wrapper(compile_hook, fn)

    return decorate_fn


# NB: all inputs are tensors
def _bilinear_interpolate(
input, # [N, C, H, W]
Expand Down Expand Up @@ -86,15 +106,13 @@ def maybe_cast(tensor):
return tensor


# This is a slow but pure Python and differentiable implementation of
# roi_align. It potentially is a good basis for Inductor compilation
# (but I have not benchmarked it) but today it is solely used for the
# fact that its backwards can be implemented deterministically,
# which is needed for the PT2 benchmark suite.
#
# This is a pure Python and differentiable implementation of roi_align. When
# run in eager mode, it uses a lot of memory, but when compiled it has
# acceptable memory usage. The main point of this implementation is that
# its backward pass is deterministic.
# It is transcribed directly off of the roi_align CUDA kernel, see
# https://dev-discuss.pytorch.org/t/a-pure-python-implementation-of-roi-align-that-looks-just-like-its-cuda-kernel/1266
@torch._dynamo.allow_in_graph
@lazy_compile(dynamic=True)
def _roi_align(input, rois, spatial_scale, pooled_height, pooled_width, sampling_ratio, aligned):
orig_dtype = input.dtype

Expand Down Expand Up @@ -232,7 +250,9 @@ def roi_align(
if not isinstance(rois, torch.Tensor):
rois = convert_boxes_to_roi_format(rois)
if not torch.jit.is_scripting():
if not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps)):
if (
not _has_ops() or (torch.are_deterministic_algorithms_enabled() and (input.is_cuda or input.is_mps))
) and is_compile_supported(input.device.type):
return _roi_align(input, rois, spatial_scale, output_size[0], output_size[1], sampling_ratio, aligned)
_assert_has_ops()
return torch.ops.torchvision.roi_align(
Expand Down
8 changes: 4 additions & 4 deletions torchvision/transforms/_functional_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -722,10 +722,10 @@ def perspective(
return _apply_grid_transform(img, grid, interpolation, fill=fill)


def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
def _get_gaussian_kernel1d(kernel_size: int, sigma: float, dtype: torch.dtype, device: torch.device) -> Tensor:
ksize_half = (kernel_size - 1) * 0.5

x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, dtype=dtype, device=device)
pdf = torch.exp(-0.5 * (x / sigma).pow(2))
kernel1d = pdf / pdf.sum()

Expand All @@ -735,8 +735,8 @@ def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:
def _get_gaussian_kernel2d(
kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
) -> Tensor:
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0], dtype, device)
kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1], dtype, device)
kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])
return kernel2d

Expand Down

0 comments on commit 7359199

Please sign in to comment.