[TOPI] Allow conv definition to have custom kernel layout (#11936)
* [TOPI] Allow conv definition to have custom kernel layout

* add tests

* fix

* fix
vinx13 authored Jul 14, 2022
1 parent a9c610f commit a571bfb
Showing 13 changed files with 175 additions and 73 deletions.
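
The core of the change is in wrap_compute_conv2d (python/tvm/relay/op/strategy/generic.py): its layout flags become keyword-only and a new need_kernel_layout flag forwards the operator's kernel_layout attribute to TOPI computes that accept a custom kernel layout. As a rough sketch of a strategy registration written against the new signature (it mirrors the cuda.py hunk in this commit rather than introducing any new API):

    strategy.add_implementation(
        wrap_compute_conv2d(
            topi.nn.conv,              # generic conv compute that takes both layouts
            need_data_layout=True,     # forward attrs.data_layout
            need_kernel_layout=True,   # forward attrs.kernel_layout (new in this commit)
            has_groups=True,           # also forward attrs.groups
        ),
        naive_schedule,
        name="conv2d.cuda",
        plevel=15,
    )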
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -276,13 +276,15 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
data, kernel = inputs
if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.arm_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.arm_cpu",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86",
)
@@ -294,7 +296,9 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
"""depthwise_conv2d_NCHWc adopted from x86"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.x86",
)
11 changes: 10 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -316,10 +316,19 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
):
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, need_data_layout=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
elif is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(
topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True
),
naive_schedule,
name="conv2d.cuda",
plevel=15,
)
elif target.kind.name == "cuda" and "cudnn" not in target.libs:
# No TVM native kernel applicable
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
15 changes: 12 additions & 3 deletions python/tvm/relay/op/strategy/generic.py
@@ -223,7 +223,9 @@ def schedule_bitpack(attrs, outs, target):
# conv2d
def wrap_compute_conv2d(
topi_compute,
*,
need_data_layout=False,
need_kernel_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
@@ -236,6 +238,7 @@ def _compute_conv2d(attrs, inputs, out_type):
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
data_layout = attrs.get_str("data_layout")
kernel_layout = attrs.get_str("kernel_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
@@ -244,6 +247,8 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(attrs.groups)
if need_data_layout:
args.append(data_layout)
if need_kernel_layout:
args.append(kernel_layout)
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
@@ -340,13 +345,15 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.nn.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.generic",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.generic",
)
@@ -360,7 +367,9 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.nn.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.generic",
)
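
Inside _compute_conv2d, each enabled flag appends one more positional argument before out_dtype, so the wrapped TOPI compute receives the layouts in a fixed order. A hedged illustration of the call produced for the auto-scheduler registration shown earlier (argument names are descriptive only):

    # Roughly what _compute_conv2d builds for
    # wrap_compute_conv2d(topi.nn.conv, need_data_layout=True,
    #                     need_kernel_layout=True, has_groups=True)
    out = topi.nn.conv(
        inputs[0],        # data
        inputs[1],        # kernel
        strides,
        padding,
        dilation,
        attrs.groups,     # appended because has_groups=True
        data_layout,      # appended because need_data_layout=True
        kernel_layout,    # appended because need_kernel_layout=True
        out_dtype,
    )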
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/hls.py
@@ -137,7 +137,7 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_NCHWc hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.hls",
)
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/intel_graphics.py
@@ -44,7 +44,9 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
# conv2d_NCHWc won't work without alter op layout pass
# TODO(@Laurawly): fix this
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
plevel=5,
@@ -71,7 +73,9 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d_NCHWc intel_graphics strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
@@ -44,7 +44,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, need_data_layout=True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=50,
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
@@ -269,13 +269,15 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
data, kernel = inputs
if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.x86.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86",
)
@@ -287,7 +289,9 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
"""depthwise_conv2d x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.x86",
)
35 changes: 24 additions & 11 deletions python/tvm/topi/nn/conv1d.py
@@ -19,18 +19,27 @@
from .conv2d import conv


def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
def conv1d(
data,
kernel,
strides=1,
padding="VALID",
dilation=1,
data_layout="NCW",
kernel_layout="",
out_dtype=None,
):
"""1D convolution forward operator.
Parameters
----------
data : tvm.te.Tensor
3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
and [batch, in_width, in_channel] for layout == 'NWC'
3-D input shape [batch, in_channel, in_width] for data_layout == 'NCW'
and [batch, in_width, in_channel] for data_layout == 'NWC'
kernel : tvm.te.Tensor
3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
and [filter_size, in_channel, num_filter] for layout == 'NWC'
3-D kernel with shape [num_filter, in_channel, filter_size] for kernel_layout == 'OIW'
and [filter_size, in_channel, num_filter] for kernel_layout == 'WIO'
strides : int or tuple
The spatial stride along width
@@ -41,23 +50,27 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", o
dilation : int or tuple
Dilation rate if convolution should be dilated.
layout : str
data_layout : str
How input data is laid out, must be one of ['NCW', 'NWC']
kernel_layout : Optional[str]
The layout of the kernel. If unspecified, the default layout is used: "OIW" if
data_layout == "NCW", "WIO" if data_layout == "NWC".
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, 1, layout, out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)


def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NWC layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NWC", "WIO", out_dtype=out_dtype)


def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NCW layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NCW", "OIW", out_dtype=out_dtype)


def group_conv1d_nwc(
@@ -89,7 +102,7 @@ def group_conv1d_nwc(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NWC", "WIO", out_dtype=out_dtype)


def group_conv1d_ncw(
@@ -121,4 +134,4 @@ def group_conv1d_ncw(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NCW", "OIW", out_dtype=out_dtype)
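
With kernel_layout exposed on conv1d, a caller can name the kernel layout explicitly rather than rely on the default implied by the data layout. A minimal usage sketch (shapes and dtypes chosen purely for illustration):

    from tvm import te, topi

    # NWC data with a WIO kernel; these match the defaults conv1d would pick
    # for data_layout="NWC", but are now spelled out via kernel_layout.
    data = te.placeholder((1, 32, 16), name="data", dtype="float32")      # [batch, in_width, in_channel]
    kernel = te.placeholder((3, 16, 64), name="kernel", dtype="float32")  # [filter_size, in_channel, num_filter]
    out = topi.nn.conv1d(
        data,
        kernel,
        strides=1,
        padding="VALID",
        dilation=1,
        data_layout="NWC",
        kernel_layout="WIO",
        out_dtype="float32",
    )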