diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 4c5af610d7091..7c48b09ff00dd 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -276,13 +276,15 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.arm_cpu.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.arm_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.arm_cpu",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -294,7 +296,9 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d_NCHWc adopted from x86"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 47b33722b115f..325ca260a66f5 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -318,7 +318,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ):
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, need_data_layout=True),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                 name="conv2d_NCHWc_int8.cuda",
             )
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index 15bd35c809f8f..6074b0a69cc30 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -223,6 +223,7 @@ def schedule_bitpack(attrs, outs, target):
 # conv2d
 def wrap_compute_conv2d(
     topi_compute,
+    *,
     need_data_layout=False,
     need_kernel_layout=False,
     need_out_layout=False,
@@ -344,13 +345,15 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.nn.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.generic",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.generic",
         )
@@ -364,7 +367,9 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.nn.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.generic",
     )
diff --git a/python/tvm/relay/op/strategy/hls.py b/python/tvm/relay/op/strategy/hls.py
index 1eebbd36b8472..4a682066ca2e4 100644
--- a/python/tvm/relay/op/strategy/hls.py
+++ b/python/tvm/relay/op/strategy/hls.py
@@ -137,7 +137,7 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
     """conv2d_NCHWc hls strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
         wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.hls",
     )
diff --git a/python/tvm/relay/op/strategy/intel_graphics.py b/python/tvm/relay/op/strategy/intel_graphics.py
index a2de49c5579e3..115a71114468c 100644
--- a/python/tvm/relay/op/strategy/intel_graphics.py
+++ b/python/tvm/relay/op/strategy/intel_graphics.py
@@ -44,7 +44,9 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
             # conv2d_NCHWc won't work without alter op layout pass
             # TODO(@Laurawly): fix this
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+                wrap_compute_conv2d(
+                    topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+                ),
                 wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
                 name="conv2d_NCHWc.intel_graphics",
                 plevel=5,
@@ -71,7 +73,9 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
     """conv2d_NCHWc intel_graphics strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.intel_graphics",
     )
diff --git a/python/tvm/relay/op/strategy/rocm.py b/python/tvm/relay/op/strategy/rocm.py
index 6e91101826c97..89cac0db4ab98 100644
--- a/python/tvm/relay/op/strategy/rocm.py
+++ b/python/tvm/relay/op/strategy/rocm.py
@@ -44,7 +44,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
         and padding[1] == padding[3]
     ):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, need_data_layout=True),
             wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
             name="conv2d_nchw_miopen.rocm",
             plevel=50,
diff --git a/python/tvm/relay/op/strategy/x86.py b/python/tvm/relay/op/strategy/x86.py
index abbc9d9a4c572..17474020eefea 100644
--- a/python/tvm/relay/op/strategy/x86.py
+++ b/python/tvm/relay/op/strategy/x86.py
@@ -269,13 +269,15 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.x86.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.x86",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -287,7 +289,9 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d x86 strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
diff --git a/vta/python/vta/top/op.py b/vta/python/vta/top/op.py
index 6b06d88096bf3..4fa5b6ff84387 100644
--- a/vta/python/vta/top/op.py
+++ b/vta/python/vta/top/op.py
@@ -214,7 +214,7 @@ def conv2d_strategy_vta(attrs, inputs, out_type, target):
             assert kernel.dtype == "int8"
 
             strategy.add_implementation(
-                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_compute_conv2d(conv2d_packed, need_data_layout=True),
                 _strategy.wrap_topi_schedule(schedule_conv2d_packed),
                 name="conv2d_packed.vta",
            )
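
The functional change sits in generic.py: the bare "*" added to wrap_compute_conv2d's parameter list makes every flag after it keyword-only, which is why all the call sites above switch from positional booleans to named arguments. Below is a minimal runnable sketch of that pattern; the names are hypothetical and it is not the actual TVM wrapper, only an illustration of why the old positional calls stop working.

    # Sketch of the keyword-only pattern this patch introduces (hypothetical names,
    # not TVM's real wrapper): everything after the bare "*" must be passed by keyword.
    def wrap_compute_conv2d_sketch(
        topi_compute,
        *,
        need_data_layout=False,
        need_kernel_layout=False,
        need_out_layout=False,
    ):
        # Return a trivial closure so the sketch runs on its own.
        def _compute(attrs, inputs, out_type):
            return topi_compute, need_data_layout, need_kernel_layout, need_out_layout

        return _compute


    # Keyword call sites, as rewritten throughout this diff, still work:
    wrap_compute_conv2d_sketch(lambda *a: None, need_data_layout=True, need_out_layout=True)

    # The old positional style now raises TypeError, which is the point of the change:
    #     wrap_compute_conv2d_sketch(lambda *a: None, True, True)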