Merge pull request #3 from angel-ayala/master
Includes dilation and groups params for nn.Conv2d
sjaek authored Jan 8, 2020
2 parents 733b73c + 3b80091 commit 102f94c
Showing 1 changed file with 46 additions and 2 deletions.
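The change extends the OctConv2d constructor with dilation and groups arguments that are forwarded to every internal nn.Conv2d. A minimal usage sketch of the new options, assuming the package re-exports OctConv2d at the top level (otherwise import it from octconv.octconv) and that alpha is the low-frequency channel fraction as in the OctConv paper:

    from octconv import OctConv2d  # assumption: OctConv2d is exposed at the package top level

    # Hypothetical example of the arguments added in this commit.
    conv = OctConv2d(in_channels=64,
                     out_channels=64,
                     kernel_size=3,
                     padding=2,      # with dilation=2 and a 3x3 kernel, padding=2 preserves spatial size
                     alpha=0.5,      # half of the input/output channels go to the low-frequency octave
                     dilation=2,     # passed through to every internal nn.Conv2d
                     groups=True,    # each internal conv becomes group-wise (groups = its input channels)
                     bias=False)

    # Per-branch group counts derived in __init__; expected {'high': 32, 'low': 32} for this split.
    print(conv.groups)
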
48 changes: 46 additions & 2 deletions octconv/octconv.py
@@ -11,8 +11,28 @@ def __init__(self,
stride=1,
padding=0,
alpha=0.5,
dilation=1,
groups=False,
bias=False):

"""
Octave convolution from the 2019 article
Drop an Octave: Reducing Spatial Redundancy in Convolutional Neural Networks with Octave Convolution
Extends the 2D convolution with the octave reduction.
Args:
in_channels (int): Number of channels in the input image
out_channels (int): Number of channels produced by the convolution
kernel_size (int or tuple): Size of the convolving kernel
stride (int or tuple, optional): Stride of the convolution. Default: 1
padding (int or tuple, optional): Zero-padding added to both sides of the input. Default: 0
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
alpha (float or tuple, optional): Fraction of the (input, output) channels assigned to the low-frequency octave.
    Default: 0.5
groups (bool, optional): If ``True``, makes each internal convolution group-wise, using the
    input channel count of that branch as ``groups``. Default: ``False``
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``False``
"""

super(OctConv2d, self).__init__()

assert isinstance(in_channels, int) and in_channels > 0
@@ -43,9 +63,23 @@ def __init__(self,
'low': out_channels - out_ch_hf
}

# groups
self.groups = {
'high': 1,
'low': 1
}

if isinstance(groups, bool) and groups:
if self.alpha_out > 0 and self.in_channels['high'] <= self.out_channels['high']:
self.groups['high'] = in_ch_hf

if self.alpha_in > 0 and self.in_channels['low'] <= self.out_channels['low']:
self.groups['low'] = in_channels - in_ch_hf

self.kernel_size = kernel_size
self.stride = stride
self.padding = padding
self.dilation = dilation
self.bias = bias

self.pool = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
@@ -54,27 +88,35 @@ def __init__(self,
out_channels=self.out_channels['high'],
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
groups=self.groups['high'],
bias=bias) \
if not (self.alpha_in == 1 or self.alpha_out == 1) else None

self.conv_h2l = nn.Conv2d(in_channels=self.in_channels['high'],
out_channels=self.out_channels['low'],
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
groups=self.groups['high'],
bias=bias) \
if not (self.alpha_in == 1 or self.alpha_out == 0) else None

self.conv_l2h = nn.Conv2d(in_channels=self.in_channels['low'],
out_channels=self.out_channels['high'],
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
groups=self.groups['low'],
bias=bias) \
if not (self.alpha_in == 0 or self.alpha_out == 1) else None

self.conv_l2l = nn.Conv2d(in_channels=self.in_channels['low'],
out_channels=self.out_channels['low'],
kernel_size=kernel_size,
padding=padding,
dilation=dilation,
groups=self.groups['low'],
bias=bias) \
if not (self.alpha_in == 0 or self.alpha_out == 0) else None

@@ -127,10 +169,12 @@ def _check_inputs(self, x_h, x_l):
def __repr__(self):
s = """{}(in_channels=(low: {}, high: {}), out_channels=(low: {}, high: {}),
kernel_size=({kernel}, {kernel}), stride=({stride}, {stride}),
- padding={}, alphas=({}, {}), bias={})""".format(
+ padding={}, alphas=({}, {}), dilation={dilation}, groups=(low: {groupsl}, high: {groupsh}),
+ bias={})""".format(
  self._get_name(), self.in_channels['low'], self.in_channels['high'],
  self.out_channels['low'], self.out_channels['high'],
  self.padding, self.alpha_in, self.alpha_out, self.bias,
- kernel=self.kernel_size, stride=self.stride)
+ kernel=self.kernel_size, stride=self.stride, dilation=self.dilation,
+ groupsl=self.groups['low'], groupsh=self.groups['high'])

return s
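
For reference, the per-branch group computation added above can be reproduced outside the class. This is a stand-alone sketch, not part of the commit, and it assumes the usual channel split in_ch_hf = int((1 - alpha_in) * in_channels), which lies outside the shown hunks:

    def octconv_groups(in_channels, out_channels, alpha_in, alpha_out, groups):
        """Hypothetical helper mirroring the groups logic added in this commit."""
        in_hf = int((1 - alpha_in) * in_channels)     # assumed split; not shown in the diff
        out_hf = int((1 - alpha_out) * out_channels)
        in_ch = {'high': in_hf, 'low': in_channels - in_hf}
        out_ch = {'high': out_hf, 'low': out_channels - out_hf}

        branch_groups = {'high': 1, 'low': 1}
        if isinstance(groups, bool) and groups:
            # Mirror of the conditions in __init__ above: a branch becomes group-wise
            # (groups = its input channels) only when its output part is at least as
            # wide as its input part.
            if alpha_out > 0 and in_ch['high'] <= out_ch['high']:
                branch_groups['high'] = in_ch['high']
            if alpha_in > 0 and in_ch['low'] <= out_ch['low']:
                branch_groups['low'] = in_ch['low']
        return branch_groups

    print(octconv_groups(64, 128, 0.5, 0.5, groups=True))  # {'high': 32, 'low': 32}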
