[TOPI] Allow conv definition to have custom kernel layout #11936

Merged: 4 commits, Jul 14, 2022
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -276,13 +276,15 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
data, kernel = inputs
if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.arm_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.arm_cpu",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86",
)
@@ -294,7 +296,9 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
"""depthwise_conv2d_NCHWc adopted from x86"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.x86",
)
11 changes: 10 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -316,10 +316,19 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
):
assert kernel_layout == "OIHW4o4i"
strategy.add_implementation(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, need_data_layout=True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.cuda",
)
elif is_auto_scheduler_enabled():
strategy.add_implementation(
wrap_compute_conv2d(
topi.nn.conv, need_data_layout=True, need_kernel_layout=True, has_groups=True
),
naive_schedule,
name="conv2d.cuda",
plevel=15,
)
elif target.kind.name == "cuda" and "cudnn" not in target.libs:
# No TVM native kernel applicable
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
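As context for the new auto-scheduler branch above: it registers the layout-generic topi.nn.conv and forwards both the data and kernel layout attributes to it. Below is a minimal standalone sketch of the equivalent direct call; the shapes and the NHWC/HWIO layouts are illustrative assumptions, and the positional argument order follows the conv(...) calls added in conv1d.py later in this diff.

# Editor's illustrative sketch (not part of this PR): calling the generic
# topi.nn.conv with an explicit kernel layout, mirroring what the
# auto-scheduler strategy above registers. Shapes/layouts are assumptions.
from tvm import te, topi

data = te.placeholder((1, 56, 56, 64), name="data", dtype="float32")      # NHWC
weight = te.placeholder((3, 3, 64, 128), name="weight", dtype="float32")  # HWIO

# Positional order: data, kernel, stride, padding, dilation, groups,
# data_layout, kernel_layout (and optionally out_dtype).
out = topi.nn.conv(data, weight, 1, 1, 1, 1, "NHWC", "HWIO")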
15 changes: 12 additions & 3 deletions python/tvm/relay/op/strategy/generic.py
@@ -223,7 +223,9 @@ def schedule_bitpack(attrs, outs, target):
# conv2d
def wrap_compute_conv2d(
topi_compute,
*,
need_data_layout=False,
need_kernel_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
@@ -236,6 +238,7 @@ def _compute_conv2d(attrs, inputs, out_type):
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
data_layout = attrs.get_str("data_layout")
kernel_layout = attrs.get_str("kernel_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
@@ -244,6 +247,8 @@ def _compute_conv2d(attrs, inputs, out_type):
args.append(attrs.groups)
if need_data_layout:
args.append(data_layout)
if need_kernel_layout:
args.append(kernel_layout)
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
@@ -340,13 +345,15 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
strategy = _op.OpStrategy()
if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.nn.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.generic",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.generic",
)
@@ -360,7 +367,9 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.nn.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.generic",
)
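A brief note on the wrapper semantics, with a hedged usage sketch: after the change above the layout flags of wrap_compute_conv2d are keyword-only, and _compute_conv2d appends data_layout, kernel_layout, and out_layout (in that order, each only when its flag is set) between the convolution attributes and out_dtype. The compute function and its name below are hypothetical, purely to illustrate the resulting argument order.

# Editor's sketch: a hypothetical TOPI compute wrapped with all layout flags
# enabled. The parameter order matches what _compute_conv2d builds above.
from tvm import topi
from tvm.relay.op.strategy.generic import wrap_compute_conv2d

def my_conv2d_compute(
    data, kernel, strides, padding, dilation,
    data_layout, kernel_layout, out_layout, out_dtype,
):
    # Delegate to the generic conv; groups is fixed at 1 here because
    # has_groups was not requested from wrap_compute_conv2d.
    return topi.nn.conv(
        data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype
    )

compute = wrap_compute_conv2d(
    my_conv2d_compute,
    need_data_layout=True,
    need_kernel_layout=True,
    need_out_layout=True,
)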
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/hls.py
@@ -137,7 +137,7 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
"""conv2d_NCHWc hls strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.hls",
)
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/intel_graphics.py
@@ -44,7 +44,9 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
# conv2d_NCHWc won't work without alter op layout pass
# TODO(@Laurawly): fix this
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
plevel=5,
@@ -71,7 +73,9 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
"""conv2d_NCHWc intel_graphics strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.intel_graphics",
)
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
@@ -44,7 +44,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
and padding[1] == padding[3]
):
strategy.add_implementation(
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, need_data_layout=True),
wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
name="conv2d_nchw_miopen.rocm",
plevel=50,
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
@@ -269,13 +269,15 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
data, kernel = inputs
if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
wrap_compute_conv2d(
topi.x86.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.x86",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.x86",
)
@@ -287,7 +289,9 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
"""depthwise_conv2d x86 strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
wrap_compute_conv2d(
topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
),
wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.x86",
)
35 changes: 24 additions & 11 deletions python/tvm/topi/nn/conv1d.py
@@ -19,18 +19,27 @@
from .conv2d import conv


def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
def conv1d(
data,
kernel,
strides=1,
padding="VALID",
dilation=1,
data_layout="NCW",
kernel_layout="",
out_dtype=None,
):
"""1D convolution forward operator.
Parameters
----------
data : tvm.te.Tensor
3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
and [batch, in_width, in_channel] for layout == 'NWC'
3-D input shape [batch, in_channel, in_width] for data_layout == 'NCW'
and [batch, in_width, in_channel] for data_layout == 'NWC'
kernel : tvm.te.Tensor
3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
and [filter_size, in_channel, num_filter] for layout == 'NWC'
3-D kernel with shape [num_filter, in_channel, filter_size] for kernel_layout == 'OIW'
and [filter_size, in_channel, num_filter] for kernel_layout == 'WIO'
strides : int or tuple
The spatial stride along width
@@ -41,23 +50,27 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
dilation : int or tuple
Dilation rate if convolution should be dilated.
layout : str
data_layout : str
How input data is laid out, must be one of ['NCW', 'NWC']
kernel_layout : Optional[str]
The layout of the kernel. If unspecified, a default is used: "OIW" if data_layout == "NCW",
"WIO" if data_layout == "NWC".
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, 1, layout, out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)


def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NWC layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NWC", "WIO", out_dtype=out_dtype)


def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NCW layout. See :py:func:`conv` for details on parameters"""
return conv(data, kernel, strides, padding, dilation, 1, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, 1, "NCW", "OIW", out_dtype=out_dtype)


def group_conv1d_nwc(
@@ -89,7 +102,7 @@ def group_conv1d_nwc(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NWC", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NWC", "WIO", out_dtype=out_dtype)


def group_conv1d_ncw(
@@ -121,4 +134,4 @@ def group_conv1d_ncw(
out_dtype : str
The output data type. If None then output is same type as input.
"""
return conv(data, kernel, strides, padding, dilation, groups, "NCW", out_dtype=out_dtype)
return conv(data, kernel, strides, padding, dilation, groups, "NCW", "OIW", out_dtype=out_dtype)