[ARM] Support NCHWc alter layout in the fallback mode (#10724)
* [ARM] Support NCHWc alter layout in the fallback mode

* remove fallback path

* add test

* fixed int32_lanes and added channel check

* fixed schedule dispatch bug

* add workaround fallback path for NHWC im2col-based GEMM schedule

* int32_lanes=4 by default

* typo

* update test
masahi authored Mar 25, 2022
1 parent 63461c0 commit e9091d6
Showing 6 changed files with 107 additions and 12 deletions.
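
For background, the "fallback mode" in the title is the path AutoTVM takes when no tuning log matches a workload: `dispatch_ctx.query` returns a fallback config. Before this commit, `_alter_conv2d_layout` on ARM bailed out for every fallback config; after it, only the NHWC im2col/GEMM branches do. A minimal sketch of that check, following the pattern visible in the `conv2d_alter_op.py` diff below (the helper name is illustrative, not part of the patch):

```python
# Sketch of the AutoTVM fallback check this commit is about; the helper name
# is illustrative and not part of the patch.
from tvm import autotvm


def query_or_skip(dispatch_ctx, target, workload):
    cfg = dispatch_ctx.query(target, workload)
    if cfg.is_fallback:
        # Previously done for every template; now only for the NHWC GEMM ones:
        # drop the cached fallback entry and leave the op layout unchanged.
        autotvm.task.clear_fallback_cache(target, workload)
        return None
    return cfg
```
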
15 changes: 12 additions & 3 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -93,13 +93,18 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_spatial_pack),
name="conv2d_nchw_spatial_pack.arm_cpu",
plevel=10,
)

if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
if (
topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype)
and kernel.shape[1] >= 64
):
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nchw_int8),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nchw_int8),
name="conv2d_nchw_int8.arm_cpu",
plevel=15,
)
else:
strategy.add_implementation(
@@ -383,12 +388,16 @@ def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_typ
if layout == "NHWC" and data.dtype in ["int8", "uint8"]:
strategy.add_implementation(
wrap_compute_conv2d_gemm(native_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_native),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_native_without_transform
),
name="conv2d_NHWC_quantized_native_without_transform.arm_cpu",
)
strategy.add_implementation(
wrap_compute_conv2d_gemm(interleaved_compute),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved),
wrap_topi_schedule(
topi.arm_cpu.schedule_conv2d_NHWC_quantized_interleaved_without_transform
),
name="conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu",
)
else:
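
The strategy changes above do two things: gate the NCHW int8 schedule on the input-channel count (`kernel.shape[1]` in OIHW layout) and give it a higher priority level, and point the `without_weight_transform` strategy at the new `_without_transform` schedules. A toy sketch of the plevel-based selection, using the numbers from the diff (the helper is illustrative only):

```python
# Illustrative only: how the strategy above resolves between implementations.
# When several implementations apply, TVM prefers the one with the highest plevel.
def pick_conv2d_nchw_impl(in_channels, int8_hw_supported):
    impls = [("conv2d_nchw_spatial_pack.arm_cpu", 10)]
    if int8_hw_supported and in_channels >= 64:
        impls.append(("conv2d_nchw_int8.arm_cpu", 15))
    return max(impls, key=lambda impl: impl[1])[0]


assert pick_conv2d_nchw_impl(64, True) == "conv2d_nchw_int8.arm_cpu"
assert pick_conv2d_nchw_impl(32, True) == "conv2d_nchw_spatial_pack.arm_cpu"
```
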
29 changes: 23 additions & 6 deletions python/tvm/topi/arm_cpu/conv2d_alter_op.py
@@ -27,6 +27,7 @@
from ..nn import conv2d_alter_layout, conv2d_legalize
from ..utils import get_const_tuple
from ..x86.conv2d import _get_default_config as _get_x86_default_config
from ..x86.conv2d_int8 import _get_default_config_int8
from .conv2d_int8 import is_int8_hw_support
from .arm_utils import get_tiling_B_interleaved_t
from ..generic.conv2d import conv2d_alter_int8_common
@@ -101,9 +102,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
# we then assume it's not necessary to alter this op.
return None
cfg = dispatch_ctx.query(target, workload)
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

topi_tmpl = workload[0]
new_attrs = {k: attrs[k] for k in attrs.keys()}
@@ -346,6 +344,11 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):

if topi_tmpl == "conv2d_NCHWc_int8.arm_cpu":
assert data_layout == "NCHW" and kernel_layout == "OIHW"
batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, _, kh, kw = get_const_tuple(kernel_tensor.shape)

n_elems = 8

if cfg.is_fallback:
_get_default_config_int8(
cfg,
@@ -357,12 +360,14 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
out_dtype,
False,
data_layout,
int32_lanes=4,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
out_channel, channel_multiplier, kh, kw = get_const_tuple(kernel_tensor.shape)
ic_bn, oc_bn = cfg["tile_ic"].size[-1], cfg["tile_oc"].size[-1]
n_elems = 8

if cfg.is_fallback:
# ic_bn needs to be divided by n_elems below
ic_bn = max(ic_bn, n_elems)

# update new attrs
new_attrs["channels"] = out_channel
@@ -395,6 +400,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.contrib_conv2d_nchwc(*inputs, **new_attrs)

if topi_tmpl == "conv2d_NHWC_quantized_interleaved.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu"
@@ -411,6 +422,12 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
inputs[0], new_kernel_expr, **new_attrs
)
if topi_tmpl == "conv2d_NHWC_quantized_native.arm_cpu":
# TODO(masahi): This schedule can easily result in a tensorization error
# if used in the fallback mode
if cfg.is_fallback: # if is fallback, clear query cache and return None
autotvm.task.clear_fallback_cache(target, workload)
return None

assert data_layout == "NHWC" and kernel_layout == "HWIO"
KH, KW, _, OC = get_const_tuple(kernel.shape)
new_workload_name = "conv2d_NHWC_quantized_native_without_transform.arm_cpu"
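
In the `conv2d_NCHWc_int8.arm_cpu` branch above, the fallback config is now filled in via `_get_default_config_int8` with `int32_lanes=4`, and `ic_bn` is rounded up to `n_elems` so the inner channel block can be split into groups of 8 int8 values. A hedged sketch of the resulting layout strings (assumption: this mirrors the usual TVM NCHWc-int8 layout convention derived from `ic_bn`/`oc_bn`; it is not code from this patch):

```python
# Sketch of the NCHWc layout strings implied by the blocking above (assumed to
# follow the usual TVM NCHWc-int8 convention, not copied from this patch).
def nchwc_int8_layouts(ic_bn, oc_bn, n_elems=8):
    # In fallback mode ic_bn is raised to n_elems so it divides evenly below.
    ic_bn = max(ic_bn, n_elems)
    data_layout = "NCHW%dc" % ic_bn
    kernel_layout = "OIHW{:n}i{:n}o{:n}i".format(ic_bn // n_elems, oc_bn, n_elems)
    return data_layout, kernel_layout


print(nchwc_int8_layouts(ic_bn=4, oc_bn=4))  # ('NCHW8c', 'OIHW1i4o8i')
```
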
12 changes: 12 additions & 0 deletions python/tvm/topi/arm_cpu/conv2d_int8.py
@@ -297,6 +297,12 @@ def schedule_conv2d_NHWC_quantized_interleaved(cfg, outs):
return _schedule_conv2d_NHWC_quantized(cfg, outs, True)


@autotvm.register_topi_schedule("conv2d_NHWC_quantized_interleaved_without_transform.arm_cpu")
def schedule_conv2d_NHWC_quantized_interleaved_without_transform(cfg, outs):
"""Interface for interleaved schedule_conv2d_NHWC_quantized_interleaved"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, True)


# Native schedules: these schedules won't interleave A (which is left in its native form).
# The weights are interleaved and transposed
@autotvm.register_topi_compute("conv2d_NHWC_quantized_native.arm_cpu")
@@ -330,3 +336,9 @@ def compute_conv2d_NHWC_quantized_native_without_transform(
def schedule_conv2d_NHWC_quantized_native(cfg, outs):
"""Interface for native schedule_conv2d_NHWC_quantized"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, False)


@autotvm.register_topi_schedule("conv2d_NHWC_quantized_native_without_transform.arm_cpu")
def schedule_conv2d_NHWC_quantized_native_without_transform(cfg, outs):
"""Interface for native schedule_conv2d_NHWC_quantized"""
return _schedule_conv2d_NHWC_quantized(cfg, outs, False)
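
The two new registrations above share the existing schedule bodies; they only add distinct workload names so the `_without_transform` strategy entries added in `arm_cpu.py` can dispatch to them. A minimal sketch of the same registration pattern, with hypothetical names and a placeholder body:

```python
# Hypothetical example of registering two workload names over one shared
# schedule body, mirroring the pattern added above.
from tvm import autotvm


def _schedule_shared(cfg, outs, interleave_A):
    raise NotImplementedError("placeholder for the shared schedule body")


@autotvm.register_topi_schedule("example_quantized.arm_cpu")
def schedule_example_quantized(cfg, outs):
    return _schedule_shared(cfg, outs, True)


@autotvm.register_topi_schedule("example_quantized_without_transform.arm_cpu")
def schedule_example_quantized_without_transform(cfg, outs):
    return _schedule_shared(cfg, outs, True)
```
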
1 change: 1 addition & 0 deletions python/tvm/topi/x86/conv2d_alter_op.py
@@ -159,6 +159,7 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
out_dtype,
False,
data_layout,
int32_lanes=16,
)

batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
16 changes: 13 additions & 3 deletions python/tvm/topi/x86/conv2d_int8.py
@@ -34,7 +34,16 @@


def _get_default_config_int8(
cfg, data, kernel, strides, padding, dilation, out_dtype, is_depthwise=False, layout="NCHW"
cfg,
data,
kernel,
strides,
padding,
dilation,
out_dtype,
is_depthwise=False,
layout="NCHW",
int32_lanes=4,
):
"""
Get default schedule config for the workload
@@ -50,11 +59,11 @@ def _get_default_config_int8(
is_kernel_1x1 = wkl.kernel_h == 1 and wkl.kernel_w == 1
if is_kernel_1x1:
conv2d_generic.fallback_schedule_cpu_1x1_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4
cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
)
else:
conv2d_generic.fallback_schedule_cpu_common_int8(
cfg, wkl, int32_lanes=16, num_int8_elements=4
cfg, wkl, int32_lanes=int32_lanes, num_int8_elements=4
)


@@ -163,6 +172,7 @@ def conv2d_NCHWc_int8(cfg, data, kernel, strides, padding, dilation, layout, out
padding,
dilation,
out_dtype,
int32_lanes=16,
)

# Pack data if raw 4-D data is provided.
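
`int32_lanes` in `_get_default_config_int8` is the number of int32 accumulator lanes in one vector register, which drives the output-channel blocking in the fallback schedule; the x86 caller keeps 16 (AVX-512) while ARM now passes 4 (128-bit NEON). A quick arithmetic sketch (the 128-/512-bit register widths are standard NEON/AVX-512 figures, not stated in this diff):

```python
# int32 lanes per vector register = register width in bits / 32.
def int32_lanes(vector_bits):
    return vector_bits // 32


assert int32_lanes(128) == 4   # AArch64 NEON -> the new ARM default above
assert int32_lanes(512) == 16  # AVX-512      -> the value passed by the x86 caller
```
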
46 changes: 46 additions & 0 deletions tests/python/unittest/test_meta_schedule_integration.py
@@ -16,9 +16,11 @@
# under the License.
import sys
from typing import List
import numpy as np

import pytest
import tvm
from tvm import relay
from tvm import meta_schedule as ms
from tvm.ir.module import IRModule
from tvm.meta_schedule.database import PyDatabase, TuningRecord, Workload
@@ -149,5 +151,49 @@ def extract_task_qbert():
assert "vnni" in annotations["schedule_rule"]


def extract_task_arm_conv2d_nchwc():
data_shape = (1, 64, 128, 128)
weight_shape = (32, 64, 1, 1)
bias_shape = (weight_shape[0],)
padding = (1, 1)

data = relay.var("data", shape=data_shape, dtype="int8")
weight = relay.var("weight", shape=weight_shape, dtype="int8")
bias = relay.var("bias", shape=bias_shape, dtype="int32")
conv2d = relay.nn.conv2d(
data=data,
weight=weight,
kernel_size=weight_shape[2:],
channels=weight_shape[0],
padding=padding,
strides=(1, 1),
out_dtype="int32",
)
bias_add = relay.nn.bias_add(conv2d, bias)
relay_mod = tvm.IRModule.from_expr(bias_add)

weight_np = np.random.uniform(1, 10, size=weight_shape).astype("int8")
bias_np = np.random.uniform(1, 10, size=bias_shape).astype("int32")

params = {"weight": weight_np, "bias": bias_np}

target = "llvm -device arm_cpu -mtriple aarch64-linux-gnu -mattr=+neon"
extracted_tasks = extract_task_from_relay(relay_mod, target, params)
tune_tasks = list(
filter(
lambda task: "conv2d" in task.task_name,
extracted_tasks,
)
)

assert len(tune_tasks) == 1

relay_func = list(tune_tasks[0].mod.functions.values())[0]
out_type = relay_func.body.checked_type

# Check that the output is in NCHWc layout
assert list(out_type.shape) == [1, 8, 130, 130, 4]
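
For reference, the asserted shape follows from the usual conv2d output arithmetic plus the NCHW4c channel blocking chosen by the fallback config (`oc_bn = 4`); a quick sanity check:

```python
# Sanity check of the expected NCHW4c output shape asserted above.
def conv2d_nchwc_out_shape(n, oc, h, w, kh, kw, pad, stride, oc_bn):
    out_h = (h + 2 * pad - kh) // stride + 1
    out_w = (w + 2 * pad - kw) // stride + 1
    return [n, oc // oc_bn, out_h, out_w, oc_bn]


# 1x1 kernel, padding (1, 1), stride (1, 1), 128x128 input, 32 output channels.
assert conv2d_nchwc_out_shape(1, 32, 128, 128, 1, 1, 1, 1, oc_bn=4) == [1, 8, 130, 130, 4]
```
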


if __name__ == "__main__":
sys.exit(pytest.main([__file__] + sys.argv[1:]))
