[Hexagon] Enable depthwise conv2d NHWC with an HWIO kernel layout #13414
Changes from 2 commits
@@ -19,6 +19,7 @@
 from __future__ import absolute_import as _abs
 from collections import namedtuple
 import tvm
+import numpy as np
 from tvm import te

 from .dilate import dilate

@@ -211,7 +212,7 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No
     return Output


-def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None):
+def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, kernel_layout, out_dtype=None):
     """Depthwise convolution nhwc forward operator.

     Parameters

@@ -252,8 +253,18 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
     dilation_h, dilation_w = dilation

     batch, in_height, in_width, in_channel = Input.shape
+
+    dim = len(Input.shape) - 2
+
     # shape of dilated kernel
-    filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
+    if kernel_layout == "HWOI":
+        filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape
+        kernel_permutation_to = [0, 1] + list(range(2, dim + 2))
+    elif kernel_layout == "HWIO":
+        filter_height, filter_width, channel_multiplier, filter_channel = Filter.shape
+        kernel_permutation_to = [dim + 1, dim] + list(range(dim))
+
+    kernel_permutation_from = np.argsort(kernel_permutation_to)

Review comment: Is there a benefit to defining in terms of

Reply: I tried to follow this as much as possible so that we can add more features in the future if needed. If you think this is not the best way, I can change it.

Review comment: Ah, I see. It looks like that implementation is trying to be more clever and to identify the permutation by inspecting the string. In this case, since we're only supporting two explicitly enabled shapes, I'd lean toward defining a

    if kernel_layout == 'HWOI':
        kernel_permutation = [0, 1, 2, 3]
    elif kernel_layout == 'HWIO':
        kernel_permutation = [0, 1, 3, 2]
    else:
        raise ValueError(f'Unsupported kernel layout: {kernel_layout}')

Review comment: That said, if there are many locations where the same kernel permutation definitions are used, it may be useful to pull them out into a common utility.

Reply: Changed the permutation as you recommended. @Lunderberg


     dilated_kernel_h = (filter_height - 1) * dilation_h + 1
     dilated_kernel_w = (filter_width - 1) * dilation_w + 1

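The explicit-permutation approach suggested in the review thread on this hunk can be sketched without TVM as follows. This is a minimal illustration, not the code that was merged: the helper names `kernel_permutation_for` and `permute_filter_indices` are hypothetical, and plain strings stand in for the TVM index expressions.

```python
def kernel_permutation_for(kernel_layout):
    """Map a supported kernel layout to a fixed index permutation.

    Hypothetical helper: only the two layouts discussed in the review
    are accepted, and anything else fails loudly.
    """
    if kernel_layout == "HWOI":
        return [0, 1, 2, 3]
    if kernel_layout == "HWIO":
        return [0, 1, 3, 2]
    raise ValueError(f"Unsupported kernel layout: {kernel_layout}")


def permute_filter_indices(indices, kernel_layout):
    """Reorder the logical filter indices into the order used by the layout."""
    permutation = kernel_permutation_for(kernel_layout)
    return tuple(indices[p] for p in permutation)


# Logical index order used in the diff; strings stand in for TVM expressions.
logical = ("di", "dj", "c // channel_multiplier", "c % channel_multiplier")

hwoi_indices = permute_filter_indices(logical, "HWOI")
hwio_indices = permute_filter_indices(logical, "HWIO")
print(hwoi_indices)
print(hwio_indices)
```

With a fixed table like this, an unsupported layout raises immediately instead of leaving the permutation undefined, which is the case the original if/elif chain in the diff does not handle.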
@@ -284,9 +295,13 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No
                     j * stride_w + dj * dilation_w,
                     idxdiv(c, channel_multiplier),
                 ].astype(out_dtype)
-                * Filter[
-                    di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)
-                ].astype(out_dtype)
+                * Filter.__getitem__(

Review comment: The explicit

Reply: Oh I see, I have changed this per your recommendation.

+                    tuple(
+                        np.array(
+                            [di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)]
+                        )[kernel_permutation_from]
+                    )
+                ).astype(out_dtype)
             ),
             axis=[di, dj],
         ),

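To make the indexing in this hunk concrete, here is a small pure-Python reference for a depthwise NHWC convolution that reads the filter in either layout. It is a sketch under simplifying assumptions (stride 1, no padding, no dilation, nested lists instead of tensors), and `depthwise_conv2d_nhwc_ref` is a hypothetical name rather than a TVM function. The axis order for each layout follows the shape unpacking shown in the diff.

```python
def depthwise_conv2d_nhwc_ref(data, kernel, kernel_layout):
    """Reference depthwise conv on nested lists: stride 1, no padding or dilation."""
    if kernel_layout == "HWOI":
        # (filter_height, filter_width, filter_channel, channel_multiplier)
        channel_multiplier = len(kernel[0][0][0])

        def weight(di, dj, ch, m):
            return kernel[di][dj][ch][m]

    elif kernel_layout == "HWIO":
        # (filter_height, filter_width, channel_multiplier, filter_channel)
        channel_multiplier = len(kernel[0][0])

        def weight(di, dj, ch, m):
            return kernel[di][dj][m][ch]

    else:
        raise ValueError(f"Unsupported kernel layout: {kernel_layout}")

    fh, fw = len(kernel), len(kernel[0])
    batch, in_h, in_w = len(data), len(data[0]), len(data[0][0])
    in_c = len(data[0][0][0])
    out_h, out_w = in_h - fh + 1, in_w - fw + 1
    out_c = in_c * channel_multiplier

    out = [[[[0] * out_c for _ in range(out_w)] for _ in range(out_h)]
           for _ in range(batch)]
    for b in range(batch):
        for i in range(out_h):
            for j in range(out_w):
                for c in range(out_c):
                    # Same split as idxdiv / idxmod in the diff.
                    ch, m = divmod(c, channel_multiplier)
                    out[b][i][j][c] = sum(
                        data[b][i + di][j + dj][ch] * weight(di, dj, ch, m)
                        for di in range(fh)
                        for dj in range(fw)
                    )
    return out


# 1x3x3x2 input and a 2x2 filter with 2 channels and multiplier 2.
data = [[[[h * 6 + w * 2 + c for c in range(2)] for w in range(3)]
         for h in range(3)]]
hwoi = [[[[di * 8 + dj * 4 + ch * 2 + m + 1 for m in range(2)]
          for ch in range(2)] for dj in range(2)] for di in range(2)]
# The same weights with the last two axes swapped.
hwio = [[[[hwoi[di][dj][ch][m] for ch in range(2)] for m in range(2)]
         for dj in range(2)] for di in range(2)]

out_hwoi = depthwise_conv2d_nhwc_ref(data, hwoi, "HWOI")
out_hwio = depthwise_conv2d_nhwc_ref(data, hwio, "HWIO")
print(out_hwoi == out_hwio)
```

Both calls should agree, since the only difference is which of the last two filter axes holds the channel multiplier. Reading the HWIO tensor as if it were HWOI gives a different result, which is the kind of silent error a layout check is meant to prevent.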
Review comment: I'm not seeing how this permutation generates HWIO. This defines kernel_permutation_to as [3, 2, 0, 1], so kernel_permutation_from is [2, 3, 1, 0]. With the usage below, that would permute from [di, dj, c//channel_multiplier, c%channel_multiplier] to [c//channel_multiplier, c%channel_multiplier, dj, di], which would be OIWH.

Should this be list(range(dim)) + [dim + 1, dim] instead?

Reply: @Lunderberg You are right, my bad. Not sure why I thought this would result in HWIO. Thanks so much for catching this bug.
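
The reviewer's arithmetic can be checked with a short script. This sketch avoids external dependencies, so the local `argsort` helper is a plain-Python stand-in for `np.argsort` (equivalent for permutations), and strings stand in for the TVM index expressions; `apply_layout` is a hypothetical name.

```python
def argsort(values):
    """Indices that would sort `values`; stand-in for np.argsort."""
    return sorted(range(len(values)), key=values.__getitem__)


dim = 2  # len(Input.shape) - 2 for a 4-D NHWC input
logical = ["di", "dj", "c // channel_multiplier", "c % channel_multiplier"]


def apply_layout(kernel_permutation_to):
    """Return kernel_permutation_from and the reordered logical indices."""
    kernel_permutation_from = argsort(kernel_permutation_to)
    return kernel_permutation_from, [logical[p] for p in kernel_permutation_from]


# Permutation as written in the diff for the HWIO branch.
original_from, original_order = apply_layout([dim + 1, dim] + list(range(dim)))
# Permutation proposed in the review.
proposed_from, proposed_order = apply_layout(list(range(dim)) + [dim + 1, dim])

print(original_from, original_order)
print(proposed_from, proposed_order)
```

The first permutation moves the spatial indices to the end and reverses them, matching the OIWH order described in the review. The proposed one only swaps the last two axes, which matches the HWIO shape unpacking in the diff.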