New Feature: add DropBlock layer (#5416)
* Create dropblock.py

* add dropblock2d

* fix pylint

* refactor dropblock

* add dropblock

* Rename dropblock.py to drop_block.py

* fix pylint

* add dropblock

* add dropblock3d

* add drop_block3d

* add dropblock

* Update drop_block.py

* Update torchvision/ops/drop_block.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/ops/drop_block.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/ops/drop_block.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update torchvision/ops/drop_block.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* Update drop_block.py

* Update drop_block.py

* import torch.fx

* fix lint

* fix lint

* Update drop_block.py

* improve dropblock

* add dropblock

* refactor dropblock

* fix doc

* remove the limitation of block_size

* Update torchvision/ops/drop_block.py

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>

* fix lint

* fix lint

* add dropblock

* Fix linter

* add dropblock random check

* reduce test time

* Update test_ops.py

* speed the dropblock test

* fix lint

Co-authored-by: Vasilis Vryniotis <datumbox@users.noreply.github.com>
xiaohu2015 and datumbox authored Feb 27, 2022
1 parent 1fc53b2 commit 5568744
Showing 4 changed files with 257 additions and 0 deletions.
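For context, here is a minimal usage sketch of the API this commit introduces; the input shapes and parameter values are illustrative, not taken from the diff:

    import torch
    from torchvision import ops

    x = torch.randn(2, 3, 32, 32)                 # N, C, H, W
    layer = ops.DropBlock2d(p=0.2, block_size=5)
    layer.train()                                 # blocks are only dropped in training mode
    y = layer(x)                                  # same shape; surviving cells are rescaled

    v = torch.randn(2, 3, 8, 32, 32)              # N, C, D, H, W
    out = ops.drop_block3d(v, p=0.2, block_size=3, training=True)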
4 changes: 4 additions & 0 deletions docs/source/ops.rst
@@ -20,6 +20,8 @@ Operators
box_iou
clip_boxes_to_image
deform_conv2d
drop_block2d
drop_block3d
generalized_box_iou
generalized_box_iou_loss
masks_to_boxes
@@ -48,3 +50,5 @@ Operators
Conv2dNormActivation
Conv3dNormActivation
SqueezeExcitation
DropBlock2d
DropBlock3d
93 changes: 93 additions & 0 deletions test/test_ops.py
@@ -2,6 +2,7 @@
import os
from abc import ABC, abstractmethod
from functools import lru_cache
from itertools import product
from typing import Callable, List, Tuple

import numpy as np
@@ -57,6 +58,16 @@ def forward(self, a):
self.layer(a)


class DropBlockWrapper(nn.Module):
def __init__(self, obj):
super().__init__()
self.layer = obj
self.n_inputs = 1

def forward(self, a):
        # return value intentionally discarded: the wrapper exists only so
        # test_is_leaf_node can trace the op with torch.fx
        self.layer(a)


class RoIOpTester(ABC):
dtype = torch.float64

@@ -1357,5 +1368,87 @@ def test_split_normalization_params(self, norm_layer):
assert len(params[1]) == 82


class TestDropBlock:
@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("dim", [2, 3])
@pytest.mark.parametrize("p", [0, 0.5])
@pytest.mark.parametrize("block_size", [5, 11])
@pytest.mark.parametrize("inplace", [True, False])
def test_drop_block(self, seed, dim, p, block_size, inplace):
torch.manual_seed(seed)
batch_size = 5
channels = 3
height = 11
width = height
depth = height
if dim == 2:
x = torch.ones(size=(batch_size, channels, height, width))
layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
feature_size = height * width
elif dim == 3:
x = torch.ones(size=(batch_size, channels, depth, height, width))
layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)
feature_size = depth * height * width
        layer.__repr__()  # smoke-test the custom __repr__

out = layer(x)
if p == 0:
assert out.equal(x)
        if block_size == height:
            # a block as large as the feature map either wipes a whole
            # channel or leaves it untouched
            for b, c in product(range(batch_size), range(channels)):
                assert out[b, c].count_nonzero() in (0, feature_size)

@pytest.mark.parametrize("seed", range(10))
@pytest.mark.parametrize("dim", [2, 3])
@pytest.mark.parametrize("p", [0.1, 0.2])
@pytest.mark.parametrize("block_size", [3])
@pytest.mark.parametrize("inplace", [False])
def test_drop_block_random(self, seed, dim, p, block_size, inplace):
torch.manual_seed(seed)
batch_size = 5
channels = 3
height = 11
width = height
depth = height
if dim == 2:
x = torch.ones(size=(batch_size, channels, height, width))
layer = ops.DropBlock2d(p=p, block_size=block_size, inplace=inplace)
elif dim == 3:
x = torch.ones(size=(batch_size, channels, depth, height, width))
layer = ops.DropBlock3d(p=p, block_size=block_size, inplace=inplace)

trials = 250
num_samples = 0
counts = 0
cell_numel = torch.tensor(x.shape).prod()
for _ in range(trials):
with torch.no_grad():
out = layer(x)
non_zero_count = out.nonzero().size(0)
counts += cell_numel - non_zero_count
num_samples += cell_numel

assert abs(p - counts / num_samples) / p < 0.15

def make_obj(self, dim, p, block_size, inplace, wrap=False):
if dim == 2:
obj = ops.DropBlock2d(p, block_size, inplace)
elif dim == 3:
obj = ops.DropBlock3d(p, block_size, inplace)
return DropBlockWrapper(obj) if wrap else obj

@pytest.mark.parametrize("dim", (2, 3))
@pytest.mark.parametrize("p", [0, 1])
@pytest.mark.parametrize("block_size", [5, 7])
@pytest.mark.parametrize("inplace", [True, False])
def test_is_leaf_node(self, dim, p, block_size, inplace):
op_obj = self.make_obj(dim, p, block_size, inplace, wrap=True)
graph_node_names = get_graph_node_names(op_obj)

assert len(graph_node_names) == 2
assert len(graph_node_names[0]) == len(graph_node_names[1])
assert len(graph_node_names[0]) == 1 + op_obj.n_inputs


if __name__ == "__main__":
pytest.main([__file__])
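The 15% relative-error bound in test_drop_block_random works because gamma is calibrated so that the expected fraction of zeroed cells is approximately p; averaged over many trials, the empirical drop rate should therefore land near p. A standalone sketch of the same estimate, with assumed values:

    import torch
    from torchvision import ops

    torch.manual_seed(0)
    p = 0.1
    layer = ops.DropBlock2d(p=p, block_size=3)  # a fresh module is in training mode
    x = torch.ones(5, 3, 11, 11)

    zeroed = total = 0
    for _ in range(250):
        with torch.no_grad():
            out = layer(x)
        zeroed += (out == 0).sum().item()       # dropped cells are exactly zero
        total += out.numel()

    print(abs(p - zeroed / total) / p)          # the test asserts this is < 0.15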
5 changes: 5 additions & 0 deletions torchvision/ops/__init__.py
@@ -11,6 +11,7 @@
)
from .boxes import box_convert
from .deform_conv import deform_conv2d, DeformConv2d
from .drop_block import drop_block2d, DropBlock2d, drop_block3d, DropBlock3d
from .feature_pyramid_network import FeaturePyramidNetwork
from .focal_loss import sigmoid_focal_loss
from .giou_loss import generalized_box_iou_loss
@@ -55,4 +56,8 @@
"Conv3dNormActivation",
"SqueezeExcitation",
"generalized_box_iou_loss",
"drop_block2d",
"DropBlock2d",
"drop_block3d",
"DropBlock3d",
]
155 changes: 155 additions & 0 deletions torchvision/ops/drop_block.py
@@ -0,0 +1,155 @@
import torch
import torch.fx
import torch.nn.functional as F
from torch import nn, Tensor

from ..utils import _log_api_usage_once


def drop_block2d(
input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
"""
    Implements DropBlock2d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`_.
Args:
        input (Tensor[N, C, H, W]): The input tensor of 4 dimensions with the first one
            being its batch, i.e. a batch with ``N`` rows.
p (float): Probability of an element to be dropped.
block_size (int): Size of the block to drop.
inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if ``True``. Default: ``True``.
Returns:
Tensor[N, C, H, W]: The randomly zeroed tensor after dropblock.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(drop_block2d)
if p < 0.0 or p > 1.0:
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
if input.ndim != 4:
raise ValueError(f"input should be 4 dimensional. Got {input.ndim} dimensions.")
if not training or p == 0.0:
return input

N, C, H, W = input.size()
block_size = min(block_size, W, H)
# compute the gamma of Bernoulli distribution
gamma = (p * H * W) / ((block_size ** 2) * ((H - block_size + 1) * (W - block_size + 1)))
noise = torch.empty((N, C, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device)
noise.bernoulli_(gamma)

noise = F.pad(noise, [block_size // 2] * 4, value=0)
noise = F.max_pool2d(noise, stride=(1, 1), kernel_size=(block_size, block_size), padding=block_size // 2)
noise = 1 - noise
normalize_scale = noise.numel() / (eps + noise.sum())
if inplace:
input.mul_(noise).mul_(normalize_scale)
else:
input = input * noise * normalize_scale
return input


def drop_block3d(
input: Tensor, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06, training: bool = True
) -> Tensor:
"""
    Implements DropBlock3d from `"DropBlock: A regularization method for convolutional networks"
    <https://arxiv.org/abs/1810.12890>`_.
Args:
        input (Tensor[N, C, D, H, W]): The input tensor of 5 dimensions with the first one
            being its batch, i.e. a batch with ``N`` rows.
p (float): Probability of an element to be dropped.
block_size (int): Size of the block to drop.
inplace (bool): If set to ``True``, will do this operation in-place. Default: ``False``.
eps (float): A value added to the denominator for numerical stability. Default: 1e-6.
        training (bool): apply dropblock if ``True``. Default: ``True``.
Returns:
Tensor[N, C, D, H, W]: The randomly zeroed tensor after dropblock.
"""
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
_log_api_usage_once(drop_block3d)
if p < 0.0 or p > 1.0:
raise ValueError(f"drop probability has to be between 0 and 1, but got {p}.")
if input.ndim != 5:
raise ValueError(f"input should be 5 dimensional. Got {input.ndim} dimensions.")
if not training or p == 0.0:
return input

N, C, D, H, W = input.size()
block_size = min(block_size, D, H, W)
# compute the gamma of Bernoulli distribution
gamma = (p * D * H * W) / ((block_size ** 3) * ((D - block_size + 1) * (H - block_size + 1) * (W - block_size + 1)))
noise = torch.empty(
(N, C, D - block_size + 1, H - block_size + 1, W - block_size + 1), dtype=input.dtype, device=input.device
)
noise.bernoulli_(gamma)

noise = F.pad(noise, [block_size // 2] * 6, value=0)
noise = F.max_pool3d(
noise, stride=(1, 1, 1), kernel_size=(block_size, block_size, block_size), padding=block_size // 2
)
noise = 1 - noise
normalize_scale = noise.numel() / (eps + noise.sum())
if inplace:
input.mul_(noise).mul_(normalize_scale)
else:
input = input * noise * normalize_scale
return input


torch.fx.wrap("drop_block2d")


class DropBlock2d(nn.Module):
"""
See :func:`drop_block2d`.
"""

def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
super().__init__()

self.p = p
self.block_size = block_size
self.inplace = inplace
self.eps = eps

def forward(self, input: Tensor) -> Tensor:
"""
Args:
input (Tensor): Input feature map on which some areas will be randomly
dropped.
Returns:
            Tensor: The tensor after the DropBlock layer.
"""
return drop_block2d(input, self.p, self.block_size, self.inplace, self.eps, self.training)

def __repr__(self) -> str:
s = f"{self.__class__.__name__}(p={self.p}, block_size={self.block_size}, inplace={self.inplace})"
return s


torch.fx.wrap("drop_block3d")


class DropBlock3d(DropBlock2d):
"""
See :func:`drop_block3d`.
"""

def __init__(self, p: float, block_size: int, inplace: bool = False, eps: float = 1e-06) -> None:
super().__init__(p, block_size, inplace, eps)

def forward(self, input: Tensor) -> Tensor:
"""
Args:
input (Tensor): Input feature map on which some areas will be randomly
dropped.
Returns:
            Tensor: The tensor after the DropBlock layer.
"""
return drop_block3d(input, self.p, self.block_size, self.inplace, self.eps, self.training)
