diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py index 261b979dedaf..c8d51bc23c82 100644 --- a/python/tvm/relay/op/strategy/arm_cpu.py +++ b/python/tvm/relay/op/strategy/arm_cpu.py @@ -318,7 +318,7 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target): else: logger.warning("depthwise_conv2d with layout NHWC is not optimized for arm cpu.") strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(conv2d_generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) diff --git a/python/tvm/relay/op/strategy/hexagon.py b/python/tvm/relay/op/strategy/hexagon.py index c1d64f2fe143..f42503a1477c 100644 --- a/python/tvm/relay/op/strategy/hexagon.py +++ b/python/tvm/relay/op/strategy/hexagon.py @@ -86,9 +86,8 @@ def conv2d_strategy_hexagon(attrs, inputs, out_type, target): name="depthwise_conv2d_nchw.hexagon", ) elif layout == "NHWC": - assert kernel_layout == "HWOI" strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(topi.hexagon.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.hexagon", ) diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py index 3e59209f5822..7ff4dbc0ad1b 100644 --- a/python/tvm/relay/op/strategy/x86.py +++ b/python/tvm/relay/op/strategy/x86.py @@ -228,13 +228,12 @@ def conv2d_strategy_cpu(attrs, inputs, out_type, target): assert _OIHWio_matcher.match(kernel_layout) # check if kernel is OIHWio return depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target) elif layout == "NHWC": - assert kernel_layout == "HWOI" if (not need_auto_scheduler_layout) and (not need_meta_schedule_layout): logger.warning( "depthwise_conv2d NHWC layout is not optimized for x86 with autotvm." ) strategy.add_implementation( - wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc), + wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc, need_kernel_layout=True), wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc), name="depthwise_conv2d_nhwc.generic", ) diff --git a/python/tvm/topi/nn/depthwise_conv2d.py b/python/tvm/topi/nn/depthwise_conv2d.py index 48ffb8c6d9ff..7c446a23a813 100644 --- a/python/tvm/topi/nn/depthwise_conv2d.py +++ b/python/tvm/topi/nn/depthwise_conv2d.py @@ -19,6 +19,7 @@ from __future__ import absolute_import as _abs from collections import namedtuple import tvm +import numpy as np from tvm import te from .dilate import dilate @@ -211,7 +212,9 @@ def depthwise_conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=No return Output -def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=None): +def depthwise_conv2d_nhwc( + Input, Filter, stride, padding, dilation, kernel_layout="HWOI", out_dtype=None +): """Depthwise convolution nhwc forward operator. Parameters @@ -252,8 +255,14 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No dilation_h, dilation_w = dilation batch, in_height, in_width, in_channel = Input.shape + # shape of dilated kernel - filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape + if kernel_layout == "HWIO": + filter_height, filter_width, channel_multiplier, filter_channel = Filter.shape + kernel_permutation = [0, 1, 3, 2] + else: + filter_height, filter_width, filter_channel, channel_multiplier = Filter.shape + kernel_permutation = [0, 1, 2, 3] dilated_kernel_h = (filter_height - 1) * dilation_h + 1 dilated_kernel_w = (filter_width - 1) * dilation_w + 1 @@ -285,7 +294,11 @@ def depthwise_conv2d_nhwc(Input, Filter, stride, padding, dilation, out_dtype=No idxdiv(c, channel_multiplier), ].astype(out_dtype) * Filter[ - di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier) + tuple( + np.array( + [di, dj, idxdiv(c, channel_multiplier), idxmod(c, channel_multiplier)] + )[kernel_permutation] + ) ].astype(out_dtype) ), axis=[di, dj],