add tests
vinx13 committed Jul 13, 2022
1 parent b444f70 commit c8daa85
Showing 5 changed files with 77 additions and 25 deletions.
35 changes: 24 additions & 11 deletions python/tvm/topi/nn/conv1d.py
@@ -19,18 +19,27 @@
from .conv2d import conv


-def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", out_dtype=None):
+def conv1d(
+    data,
+    kernel,
+    strides=1,
+    padding="VALID",
+    dilation=1,
+    data_layout="NCW",
+    kernel_layout="",
+    out_dtype=None,
+):
"""1D convolution forward operator.
Parameters
----------
data : tvm.te.Tensor
-        3-D input shape [batch, in_channel, in_width] for layout == 'NCW'
-        and [batch, in_width, in_channel] for layout == 'NWC'
+        3-D input shape [batch, in_channel, in_width] for data_layout == 'NCW'
+        and [batch, in_width, in_channel] for data_layout == 'NWC'
kernel : tvm.te.Tensor
-        3-D kernel with shape [num_filter, in_channel, filter_size] for layout == 'NCW'
-        and [filter_size, in_channel, num_filter] for layout == 'NWC'
+        3-D kernel with shape [num_filter, in_channel, filter_size] for kernel_layout == 'OIW'
+        and [filter_size, in_channel, num_filter] for kernel_layout == 'WIO'
strides : int or tuple
The spatial stride along width
@@ -41,23 +50,27 @@ def conv1d(data, kernel, strides=1, padding="VALID", dilation=1, layout="NCW", o
dilation : int or tuple
Dilation rate if convolution should be dilated.
-    layout : str
+    data_layout : str
How input data is laid out, must be one of ['NCW', 'NWC']
+    kernel_layout : Optional[str]
+        The layout of the kernel. If unspecified, use the default layout: "OIW" if
+        data_layout == "NCW", "WIO" if data_layout == "NWC".
out_dtype : str
The output data type. If None then output is same type as input.
"""
-    return conv(data, kernel, strides, padding, dilation, 1, layout, "", out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)


def conv1d_nwc(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NWC layout. See :py:func:`conv` for details on parameters"""
-    return conv(data, kernel, strides, padding, dilation, 1, "NWC", "", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, "NWC", "WIO", out_dtype=out_dtype)


def conv1d_ncw(data, kernel, strides=1, padding="VALID", dilation=1, out_dtype=None):
"""1D convolution in NCW layout. See :py:func:`conv` for details on parameters"""
-    return conv(data, kernel, strides, padding, dilation, 1, "NCW", "", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, 1, "NCW", "OIW", out_dtype=out_dtype)


def group_conv1d_nwc(
@@ -89,7 +102,7 @@ def group_conv1d_nwc(
out_dtype : str
The output data type. If None then output is same type as input.
"""
-    return conv(data, kernel, strides, padding, dilation, groups, "NWC", "", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, groups, "NWC", "WIO", out_dtype=out_dtype)


def group_conv1d_ncw(
@@ -121,4 +134,4 @@ def group_conv1d_ncw(
out_dtype : str
The output data type. If None then output is same type as input.
"""
-    return conv(data, kernel, strides, padding, dilation, groups, "NCW", "", out_dtype=out_dtype)
+    return conv(data, kernel, strides, padding, dilation, groups, "NCW", "OIW", out_dtype=out_dtype)
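For orientation, a minimal sketch (not part of the commit; the shapes here are illustrative assumptions) of calling the updated conv1d with an explicit kernel layout:

import tvm
from tvm import te, topi

# NWC data: [batch, in_width, in_channel]; WIO kernel: [filter_size, in_channel, num_filter]
data = te.placeholder((1, 32, 16), name="data", dtype="float32")
kernel = te.placeholder((3, 16, 8), name="kernel", dtype="float32")

# Passing kernel_layout explicitly; leaving it as "" would infer "WIO" from data_layout="NWC".
out = topi.nn.conv1d(
    data,
    kernel,
    strides=1,
    padding="SAME",
    dilation=1,
    data_layout="NWC",
    kernel_layout="WIO",
)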
30 changes: 20 additions & 10 deletions python/tvm/topi/nn/conv2d.py
@@ -57,16 +57,18 @@
)


-def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=None):
+def conv2d(
+    input, filter, strides, padding, dilation, data_layout="NCHW", kernel_layout="", out_dtype=None
+):
"""Conv2D operator.
Parameters
----------
input : tvm.te.Tensor
-        4-D with shape [batch, in_channel, in_height, in_width]
+        4-D with shape [batch, in_channel, in_height, in_width] in data_layout
filter : tvm.te.Tensor
-        4-D with shape [num_filter, in_channel, filter_height, filter_width]
+        4-D with shape [num_filter, in_channel, filter_height, filter_width] in kernel_layout
strides : int or a list/tuple of two ints
stride size, or [stride_height, stride_width]
@@ -79,17 +81,21 @@ def conv2d(input, filter, strides, padding, dilation, layout="NCHW", out_dtype=N
dilation: int or a list/tuple of two ints
dilation size, or [dilation_height, dilation_width]
-    layout : str
+    data_layout : str
layout of data
+    kernel_layout : Optional[str]
+        layout of kernel. If unspecified, use default layout inferred from data_layout. "OIHW" if
+        data_layout == "NCHW", "HWIO" if data_layout == "NHWC".
Returns
-------
output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
# search platform specific declaration first
# default declaration
-    return conv(input, filter, strides, padding, dilation, 1, layout, "", out_dtype)
+    return conv(input, filter, strides, padding, dilation, 1, data_layout, kernel_layout, out_dtype)


@tvm.target.generic_func
@@ -239,7 +245,7 @@ def conv2d_nchw(Input, Filter, stride, padding, dilation, out_dtype=None):
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
-    return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", "", out_dtype=out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, 1, "NCHW", "OIHW", out_dtype=out_dtype)


def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
@@ -269,7 +275,7 @@ def conv2d_hwcn(Input, Filter, stride, padding, dilation, out_dtype=None):
output : tvm.te.Tensor
4-D with shape [out_height, out_width, out_channel, batch]
"""
-    return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", "", out_dtype=out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, 1, "HWCN", "HWIO", out_dtype=out_dtype)


def conv2d_nhwc(
@@ -325,7 +331,7 @@ def conv2d_nhwc(
dilation,
1,
"NHWC",
"",
"HWIO",
out_dtype,
auto_scheduler_rewritten_layout,
meta_schedule_original_shape,
@@ -709,7 +715,9 @@ def group_conv2d_nchw(Input, Filter, stride, padding, dilation, groups, out_dtyp
Output : tvm.te.Tensor
4-D with shape [batch, out_channel, out_height, out_width]
"""
-    return conv(Input, Filter, stride, padding, dilation, groups, "NCHW", "", out_dtype=out_dtype)
+    return conv(
+        Input, Filter, stride, padding, dilation, groups, "NCHW", "OIHW", out_dtype=out_dtype
+    )


def conv(
@@ -943,7 +951,9 @@ def group_conv2d_nhwc(Input, Filter, stride, padding, dilation, groups, out_dtyp
Output : tvm.te.Tensor
4-D with shape [batch, out_height, out_width, out_channel]
"""
-    return conv(Input, Filter, stride, padding, dilation, groups, "NHWC", "", out_dtype=out_dtype)
+    return conv(
+        Input, Filter, stride, padding, dilation, groups, "NHWC", "HWIO", out_dtype=out_dtype
+    )


def unpack_NCHWc_to_nchw(packed_out, out_dtype):
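For illustration (again not part of the diff, with made-up shapes), the generic conv2d entry point can now take a non-default kernel layout; this NHWC/OHWI pairing is the one exercised by the new test below:

import tvm
from tvm import te, topi

# NHWC data: [batch, height, width, in_channel]
data = te.placeholder((1, 14, 14, 16), name="data", dtype="float32")
# OHWI kernel: [num_filter, height, width, in_channel]
kernel = te.placeholder((8, 3, 3, 16), name="kernel", dtype="float32")

out = topi.nn.conv2d(
    data,
    kernel,
    strides=1,
    padding=1,
    dilation=1,
    data_layout="NHWC",
    kernel_layout="OHWI",
    out_dtype="float32",
)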
4 changes: 2 additions & 2 deletions python/tvm/topi/nn/conv3d.py
@@ -53,7 +53,7 @@ def conv3d_ncdhw(Input, Filter, stride, padding, dilation, groups, out_dtype=Non
Output : tvm.te.Tensor
5-D with shape [batch, out_channel, out_depth, out_height, out_width]
"""
-    return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", "", out_dtype)
+    return conv(Input, Filter, stride, padding, dilation, groups, "NCDHW", "OIDHW", out_dtype)


def conv3d_ndhwc(
@@ -111,7 +111,7 @@ def conv3d_ndhwc(
dilation,
groups,
"NDHWC",
"",
"DHWIO",
out_dtype,
auto_scheduler_rewritten_layout,
meta_schedule_origin_shape,
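The 3-D wrappers follow the same pattern; a hedged sketch with assumed shapes, using the "OIDHW" layout that conv3d_ncdhw now passes explicitly:

import tvm
from tvm import te, topi

# NCDHW data: [batch, in_channel, depth, height, width]
data = te.placeholder((1, 16, 8, 14, 14), name="data", dtype="float32")
# OIDHW kernel: [num_filter, in_channel, depth, height, width]
kernel = te.placeholder((8, 16, 3, 3, 3), name="kernel", dtype="float32")

out = topi.nn.conv3d_ncdhw(data, kernel, stride=1, padding=1, dilation=1, groups=1)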
2 changes: 1 addition & 1 deletion python/tvm/topi/testing/common.py
@@ -21,7 +21,7 @@
import scipy.signal

import tvm
-from tvm import topi
+from tvm import te, topi
from tvm.testing import assert_allclose

_injective_schedule = {
31 changes: 30 additions & 1 deletion tests/python/topi/python/test_topi_conv2d_nhwc.py
@@ -77,7 +77,7 @@ def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padd
return a_np, w_np, b_np


-def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation):
+def test_conv2d_nhwc_hwio(target, dev, ref_data, dtype, stride, padding, dilation):
a_np, w_np, b_np = ref_data

A = te.placeholder(a_np.shape, name="A", dtype=dtype)
@@ -95,5 +95,34 @@ def test_conv2d_nhwc(target, dev, ref_data, dtype, stride, padding, dilation):
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)


+def test_conv2d_nhwc_ohwi(ref_data, dtype, stride, padding, dilation):
+    # only test on CPU target because topi doesn't have schedules for this layout
+    target = "llvm"
+    dev = tvm.device(target, 0)
+    a_np, w_np_hwio, b_np = ref_data
+    w_np_ohwi = w_np_hwio.transpose(3, 0, 1, 2)  # HWIO -> OHWI
+
+    A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+    W = te.placeholder(w_np_ohwi.shape, name="W", dtype=dtype)
+
+    B = topi.nn.conv2d(
+        A,
+        W,
+        stride,
+        padding,
+        dilation,
+        data_layout="NHWC",
+        kernel_layout="OHWI",
+        out_dtype="float32",
+    )
+    s = tvm.te.create_schedule(B.op)
+    a = tvm.nd.array(a_np, dev)
+    w = tvm.nd.array(w_np_ohwi, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    func = tvm.build(s, [A, W, B], target)
+    func(a, w, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)


if __name__ == "__main__":
tvm.testing.main()
