add cuda conv2d strategy
icemelon committed Jan 27, 2020
1 parent 43fd6f1 commit 02cd271
Showing 19 changed files with 762 additions and 688 deletions.
4 changes: 3 additions & 1 deletion python/tvm/relay/op/_transform.py
@@ -85,7 +85,9 @@ def compute_argwhere(attrs, inputs, output_type):

_reg.register_schedule("argwhere", strategy.schedule_argwhere)

############################### shape func #################################
#####################
# Shape functions #
#####################

@script
def _arange_shape_func(start, stop, step):
171 changes: 11 additions & 160 deletions python/tvm/relay/op/nn/_nn.py
@@ -87,100 +87,8 @@ def compute_sparse_transpose(attrs, inputs, out_type):


# conv2d
def _find_conv2d_op(op):
"""Find the op with conv2d in its tag by traversing."""
if 'conv2d' in op.tag:
return op
for tensor in op.input_tensors:
op_ = _find_conv2d_op(tensor.op)
if op_ is not None:
return op_
return None

# @reg.register_compute("nn.conv2d")
# def compute_conv2d(attrs, inputs, out_type, target):
# """Compute definition of conv2d"""
# padding = get_const_tuple(attrs.padding)
# strides = get_const_tuple(attrs.strides)
# dilation = get_const_tuple(attrs.dilation)
# groups = attrs.groups
# layout = attrs.data_layout
# kernel_layout = attrs.kernel_layout
# out_dtype = attrs.out_dtype
# out_dtype = (inputs[0].dtype if out_dtype in ("same", "")
# else out_dtype)
#
# assert layout in ["NCHW", "NHWC", "NCHW4c", "HWCN"]
# (dilation_h, dilation_w) = dilation
# if dilation_h < 1 or dilation_w < 1:
# raise ValueError("dilation should be positive value")
#
# def _get_out_depth():
# weight_shape = get_const_tuple(inputs[1].shape)
# # NHWC layout
# if kernel_layout.startswith("HW"):
# return weight_shape[2] * weight_shape[3]
# # NCHW layout.
# # in ARM CPU contrib_spatial_pack schedule, we will prepack weight layout
# if len(weight_shape) == 4:
# return weight_shape[0] * weight_shape[1]
# else:
# assert len(weight_shape) == 5
# C, M, _, _, VC = weight_shape
# return C * VC * M
#
# if groups == 1:
# out = topi.nn.conv2d(
# inputs[0], inputs[1], strides, padding,
# dilation, layout, out_dtype)
# elif layout == "NCHW" and _get_out_depth() == groups:
# out = topi.nn.depthwise_conv2d_nchw(
# inputs[0], inputs[1], strides, padding, dilation, out_dtype)
# elif layout == "NHWC" and kernel_layout == "HWOI" and _get_out_depth() == groups:
# out = topi.nn.depthwise_conv2d_nhwc(
# inputs[0], inputs[1], strides, padding, dilation, out_dtype)
# elif layout in ['NCHW', 'NCHW4c']:
# out = topi.nn.group_conv2d_nchw(inputs[0], inputs[1], strides, padding, dilation, groups,
# out_dtype)
# else:
# raise ValueError("not support arbitrary group number for now")
# return [out]


# @reg.register_schedule("nn.conv2d")
# def schedule_conv2d(attrs, outs, target):
# """Schedule definition of conv2d"""
# groups = attrs.groups
# layout = attrs.data_layout
# kernel_layout = attrs.kernel_layout
#
# with target:
# if groups == 1 and layout == "NCHW":
# return topi.generic.schedule_conv2d_nchw(outs)
# elif groups == 1 and layout == "NCHW4c":
# return topi.generic.schedule_conv2d_nchw(outs)
# elif groups == 1 and layout == "NHWC":
# return topi.generic.schedule_conv2d_nhwc(outs)
# elif groups == 1 and layout == "HWCN":
# return topi.generic.schedule_conv2d_hwcn(outs)
# elif groups != 1:
# # collect in_channels to distinguish depthwise and group conv2d
# op = _find_conv2d_op(outs[0].op)
# assert op is not None
#
# is_depthwise = 'depthwise' in op.tag
# if is_depthwise:
# if layout == "NCHW":
# # TODO(leyuan, merrymercy, Huyuwei): fold depthwise topi into conv2d.
# return topi.generic.schedule_depthwise_conv2d_nchw(outs)
# if layout == "NHWC" and kernel_layout == "HWOI":
# return topi.generic.schedule_depthwise_conv2d_nhwc(outs)
# else:
# if layout in ["NCHW", "NCHW4c"]:
# return topi.generic.schedule_group_conv2d_nchw(outs)
# raise ValueError("No compatible schedule")

reg.register_strategy("nn.conv2d", strategy.conv2d_strategy)
reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)

@reg.register_alter_op_layout("nn.conv2d")
def alter_op_layout_conv2d(attrs, inputs, tinfos, out_type):
@@ -207,7 +115,6 @@ def legalize_conv2d(attrs, inputs, types):
"""
return topi.nn.conv2d_legalize(attrs, inputs, types)


@reg.register_convert_op_layout("nn.conv2d")
def convert_conv2d(attrs, inputs, tinfos, desired_layout):
"""Convert Layout pass registration for conv2d op.
@@ -248,8 +155,6 @@ def convert_conv2d(attrs, inputs, tinfos, desired_layout):
return relay.nn.conv2d(data, weight, **new_attrs)
return None

reg.register_pattern("nn.conv2d", OpPattern.OUT_ELEMWISE_FUSABLE)


# conv2d_transpose
reg.register_strategy("nn.conv2d_transpose", strategy.conv2d_transpose_strategy)
@@ -421,36 +326,9 @@ def compute_mirror_pad(attrs, inputs, out_dtype, target):
reg.register_strategy_broadcast("nn.mirror_pad")


# winograd related operators
@reg.register_compute("nn.contrib_conv2d_winograd_without_weight_transform")
def compute_contrib_conv2d_winograd_without_weight_transform(attrs, inputs, out_dtype):
"""Compute definition of conv2d_winograd_without_weight_transform"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.get_str("out_dtype")
tile_size = attrs.get_int("tile_size")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
assert dilation == (1, 1), "Do not support dilate now"
assert groups == 1, "Do not supoort arbitrary group number"

out = topi.nn.conv2d_winograd_without_weight_transform(
inputs[0], inputs[1], strides, padding, dilation, data_layout,
out_dtype, tile_size)

return [out]


# @reg.register_schedule("nn.contrib_conv2d_winograd_without_weight_transform")
# def schedule_contrib_conv2d_winograd_without_weight_transform(attrs, outs, target):
# """Schedule definition of conv2d_winograd_without_weight_transform"""
# with target:
# return topi.generic.schedule_conv2d_winograd_without_weight_transform(outs)


# conv2d_winograd related operators
reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
strategy.conv2d_winograd_without_weight_transfrom_strategy)
reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -462,14 +340,8 @@ def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
inputs[0], attrs.get_int('tile_size'))
return [out]


# @reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform")
# def schedule_contrib_conv2d_winograd_weight_transform(attrs, outs, target):
# """Schedule definition of contrib_conv2d_winograd_weight_transform"""
# with target:
# return topi.generic.schedule_conv2d_winograd_weight_transform(outs)


reg.register_schedule("nn.contrib_conv2d_winograd_weight_transform",
strategy.schedule_conv2d_winograd_weight_transform)
reg.register_pattern("nn.contrib_conv2d_winograd_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -535,31 +407,8 @@ def compute_contrib_conv2d_winograd_nnpack_weight_transform(attrs, inputs, out_d
OpPattern.OUT_ELEMWISE_FUSABLE)

# depthwise_conv2d_NCHWc
@reg.register_compute("nn.contrib_depthwise_conv2d_NCHWc")
def compute_contrib_depthwise_conv2d_NCHWc(attrs, inputs, out_dtype, target):
"""Compute definition of depthwise conv2d NCHWc"""
# pylint: disable=assignment-from-no-return
padding = attrs.get_int_tuple("padding")
strides = attrs.get_int_tuple("strides")
dilation = attrs.get_int_tuple("dilation")
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.get_str("out_dtype")
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype

out = topi.nn.depthwise_conv2d_NCHWc(inputs[0], inputs[1], strides, padding, dilation,
data_layout, out_layout, out_dtype)
return [out]


# @reg.register_schedule("nn.contrib_depthwise_conv2d_NCHWc")
# def schedule_contrib_depthwise_conv2d_NCHWc(attrs, outs, target):
# """Schedule definition of contrib_conv2d_NCHWc"""
# with target:
# return topi.generic.schedule_depthwise_conv2d_NCHWc(outs)


reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc", strategy.depthwise_conv2d_NCHWc_strategy)
reg.register_strategy("nn.contrib_depthwise_conv2d_NCHWc",
strategy.depthwise_conv2d_NCHWc_strategy)
reg.register_pattern("nn.contrib_depthwise_conv2d_NCHWc",
OpPattern.OUT_ELEMWISE_FUSABLE)

@@ -658,7 +507,9 @@ def compute_space_to_depth(attrs, inputs, out_dtype):
reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)


############################### shape func #################################
#####################
# Shape functions #
#####################

@script
def _conv2d_NCHWc_shape_func(dshape, kshape, strides, padding, dilation, oc_bn):
98 changes: 97 additions & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -77,6 +77,102 @@ def schedule_l2_normalize_cuda(attrs, outs, target):
with target:
return topi.cuda.schedule_l2_normalize(outs)

@conv2d_strategy.register(["cuda", "gpu"])
def conv2d_strategy_cuda(attrs, inputs, out_type, target):
"""conv2d cuda strategy"""
strategy = _op.OpStrategy()
data, kernel = inputs
dilation_h, dilation_w = attrs.get_int_tuple("dilation")
groups = attrs.groups
layout = attrs.data_layout
stride_h, stride_w = attrs.get_int_tuple("strides")
kernel_layout = attrs.kernel_layout
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")

if groups == 1:
if layout == "NCHW":
# TODO(@vinx13, @icemelon9): Use conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw))
_, _, kh, kw = get_const_tuple(kernel.shape)
if kh <= 7 and kw <= 7 and kh == kw and stride_h == 1 and stride_w == 1:
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nchw_winograd),
15)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_hwcn),
wrap_topi_schedule(topi.cuda.schedule_conv2d_hwcn))
elif layout == "NHWC":
assert kernel_layout == ""
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_conv2d_nhwc))
elif layout == "NCHW4c" and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8))
else:
raise RuntimeError("Unsupported conv2d layout {} for CUDA".format(layout))
# add cudnn implementation
if target.target_name == "cuda" and "cudnn" in target.libs:
if layout in ["NCHW", "NHWC"]:
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_cudnn, True),
wrap_topi_schedule(topi.cuda.schedule_conv2d_cudnn), 5)
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nchw))
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implement(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc))
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == 'NCHW':
# TODO(@vinx13, @icemelon9): Use group_conv2d_NCHWc_int8 when dtype is int8/uint8.
assert kernel_layout == "OIHW"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_nchw))
elif layout == 'NCHW4c' and data.dtype in ["int8", "uint8"]:
assert kernel_layout == "OIHW4o4i"
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.group_conv2d_NCHWc_int8, True),
wrap_topi_schedule(topi.cuda.schedule_group_conv2d_NCHWc_int8))
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
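
A rough usage sketch for context (not part of this commit; the Relay build calls shown are assumed from the API of this era): compiling a plain NCHW conv2d for the "cuda" target now dispatches through conv2d_strategy_cuda above, which registers conv2d_nchw plus, for square kernels up to 7x7 with unit stride, a higher-priority winograd implementation.

from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")
conv = relay.nn.conv2d(data, weight, channels=64, kernel_size=(3, 3), padding=(1, 1))
func = relay.Function([data, weight], conv)

# Without tuning logs, AutoTVM falls back to the default schedule of the
# implementation the strategy selects.
with relay.build_config(opt_level=3):
    graph_json, lib, params = relay.build(func, target="cuda")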

@conv2d_winograd_without_weight_transfrom_strategy.register(["cuda", "gpu"])
def conv2d_winograd_without_weight_transfrom_strategy_cuda(attrs, inputs, out_type, target):
    """conv2d_winograd_without_weight_transfrom cuda strategy"""
dilation = attrs.get_int_tuple("dilation")
groups = attrs.get_int("groups")
layout = attrs.data_layout
assert dilation == (1, 1), "Do not support dilation now"
assert groups == 1, "Do not support arbitrary group number"
strategy = _op.OpStrategy()
if layout == "NCHW":
strategy.add_implement(
wrap_compute_conv2d(topi.cuda.conv2d_nchw_winograd_without_weight_transform),
wrap_topi_schedule(
topi.cuda.schedule_conv2d_nchw_winograd_without_weight_transform_cuda))
else:
raise RuntimeError("Unsupported conv2d_winograd_without_weight_transfrom layout {}".
format(layout))
return strategy
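
For reference, nn.contrib_conv2d_winograd_without_weight_transform is normally produced by the alter_op_layout pass rather than written by hand. A hedged sketch of the equivalent Relay expression follows; the tile_size of 4 and the shapes are chosen only for illustration.

from tvm import relay

data = relay.var("data", shape=(1, 64, 56, 56), dtype="float32")
weight = relay.var("weight", shape=(64, 64, 3, 3), dtype="float32")

# Transform the kernel once ahead of time, then run the winograd convolution
# without a per-call weight transform.
weight_t = relay.nn.contrib_conv2d_winograd_weight_transform(weight, tile_size=4)
conv = relay.nn.contrib_conv2d_winograd_without_weight_transform(
    data, weight_t, tile_size=4, channels=64, kernel_size=(3, 3), padding=(1, 1))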

@deformable_conv2d_strategy.register(["cuda", "gpu"])
def deformable_conv2d_strategy_cuda(attrs, inputs, out_type, target):
"""deformable_conv2d cuda strategy"""
@@ -108,7 +204,7 @@ def conv3d_strategy_cuda(attrs, inputs, out_type, target):
assert layout in ["NCDHW", "NDHWC"], "Not support this layout {} yet".format(layout)
if layout == "NCDHW":
strategy.add_implement(wrap_compute_conv3d(topi.cuda.conv3d_ncdhw),
_reg._wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
wrap_topi_schedule(topi.cuda.schedule_conv3d_ncdhw),
10)
else: # layout == "NDHWC":
strategy.add_implement(wrap_compute_conv3d(topi.cuda.conv3d_ndhwc),