Commit

fix
vinx13 committed Jul 13, 2022
1 parent d7411ae commit 249be49
Showing 8 changed files with 32 additions and 15 deletions.
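
Every hunk below makes the same mechanical change: call sites of wrap_compute_conv2d stop passing the layout flags positionally and spell them out by name, and the wrapper itself (in generic.py) gains a `*` so that only the keyword form is accepted. The recurring before/after pattern, copied from the diffs:

    # before
    wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True)
    # after
    wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True)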
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/arm_cpu.py
@@ -276,13 +276,15 @@ def conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.arm_cpu.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.arm_cpu.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.arm_cpu.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.arm_cpu",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -294,7 +296,9 @@ def depthwise_conv2d_NCHWc_strategy_arm_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d_NCHWc adopted from x86"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/cuda.py
@@ -318,7 +318,7 @@ def conv2d_strategy_cuda(attrs, inputs, out_type, target):
         ):
             assert kernel_layout == "OIHW4o4i"
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, True),
+                wrap_compute_conv2d(topi.cuda.conv2d_NCHWc_int8, need_data_layout=True),
                 wrap_topi_schedule(topi.cuda.schedule_conv2d_NCHWc_int8),
                 name="conv2d_NCHWc_int8.cuda",
             )
11 changes: 8 additions & 3 deletions python/tvm/relay/op/strategy/generic.py
@@ -223,6 +223,7 @@ def schedule_bitpack(attrs, outs, target):
 # conv2d
 def wrap_compute_conv2d(
     topi_compute,
+    *,
     need_data_layout=False,
     need_kernel_layout=False,
     need_out_layout=False,
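
The `*` added above makes everything after topi_compute keyword-only. That matters because, with the parameter order shown here (need_data_layout, then need_kernel_layout, then need_out_layout), a legacy positional call such as wrap_compute_conv2d(fn, True, True) binds its second True to need_kernel_layout, not need_out_layout. A minimal, self-contained sketch of the behaviour, using stand-in functions rather than the real TVM wrapper:

    # Stand-ins (not the actual wrap_compute_conv2d) with the same parameter order.
    def legacy_signature(
        topi_compute, need_data_layout=False, need_kernel_layout=False, need_out_layout=False
    ):
        return need_data_layout, need_kernel_layout, need_out_layout


    def keyword_only_signature(
        topi_compute, *, need_data_layout=False, need_kernel_layout=False, need_out_layout=False
    ):
        return need_data_layout, need_kernel_layout, need_out_layout


    # Old call style against the old signature: the second True lands on need_kernel_layout.
    print(legacy_signature("fn", True, True))  # (True, True, False)

    # The same positional call against the keyword-only signature is rejected outright...
    try:
        keyword_only_signature("fn", True, True)
    except TypeError as err:
        print("rejected:", err)

    # ...so the commit rewrites every call site with explicit keywords.
    print(keyword_only_signature("fn", need_data_layout=True, need_out_layout=True))  # (True, False, True)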
@@ -344,13 +345,15 @@ def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     strategy = _op.OpStrategy()
     if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.nn.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.generic",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.generic",
         )
@@ -364,7 +367,9 @@ def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
     logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.nn.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.generic",
     )
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/hls.py
@@ -137,7 +137,7 @@ def conv2d_NCHWc_strategy_hls(attrs, inputs, out_type, target):
     """conv2d_NCHWc hls strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
         wrap_topi_schedule(topi.hls.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.hls",
     )
8 changes: 6 additions & 2 deletions python/tvm/relay/op/strategy/intel_graphics.py
@@ -44,7 +44,9 @@ def conv2d_strategy_intel_graphics(attrs, inputs, out_type, target):
             # conv2d_NCHWc won't work without alter op layout pass
             # TODO(@Laurawly): fix this
             strategy.add_implementation(
-                wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+                wrap_compute_conv2d(
+                    topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+                ),
                 wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
                 name="conv2d_NCHWc.intel_graphics",
                 plevel=5,
@@ -71,7 +73,9 @@ def conv2d_NCHWc_strategy_intel_graphics(attrs, inputs, out_type, target):
     """conv2d_NCHWc intel_graphics strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.intel_graphics.conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.intel_graphics.conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.intel_graphics.schedule_conv2d_NCHWc),
         name="conv2d_NCHWc.intel_graphics",
     )
2 changes: 1 addition & 1 deletion python/tvm/relay/op/strategy/rocm.py
@@ -44,7 +44,7 @@ def conv2d_strategy_rocm(attrs, inputs, out_type, target):
         and padding[1] == padding[3]
     ):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, True),
+            wrap_compute_conv2d(topi.rocm.conv2d_nchw_miopen, need_data_layout=True),
             wrap_topi_schedule(topi.rocm.schedule_conv2d_nchw_miopen),
             name="conv2d_nchw_miopen.rocm",
             plevel=50,
10 changes: 7 additions & 3 deletions python/tvm/relay/op/strategy/x86.py
@@ -269,13 +269,15 @@ def conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     data, kernel = inputs
     if topi.x86.is_int8_hw_support(data.dtype, kernel.dtype):
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc_int8, True, True),
+            wrap_compute_conv2d(
+                topi.x86.conv2d_NCHWc_int8, need_data_layout=True, need_out_layout=True
+            ),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc_int8),
             name="conv2d_NCHWc_int8.x86",
         )
     else:
         strategy.add_implementation(
-            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, True, True),
+            wrap_compute_conv2d(topi.x86.conv2d_NCHWc, need_data_layout=True, need_out_layout=True),
             wrap_topi_schedule(topi.x86.schedule_conv2d_NCHWc),
             name="conv2d_NCHWc.x86",
         )
@@ -287,7 +289,9 @@ def depthwise_conv2d_NCHWc_strategy_cpu(attrs, inputs, out_type, target):
     """depthwise_conv2d x86 strategy"""
     strategy = _op.OpStrategy()
     strategy.add_implementation(
-        wrap_compute_conv2d(topi.x86.depthwise_conv2d_NCHWc, True, True),
+        wrap_compute_conv2d(
+            topi.x86.depthwise_conv2d_NCHWc, need_data_layout=True, need_out_layout=True
+        ),
         wrap_topi_schedule(topi.x86.schedule_depthwise_conv2d_NCHWc),
         name="depthwise_conv2d_NCHWc.x86",
     )
2 changes: 1 addition & 1 deletion vta/python/vta/top/op.py
@@ -214,7 +214,7 @@ def conv2d_strategy_vta(attrs, inputs, out_type, target):
             assert kernel.dtype == "int8"

             strategy.add_implementation(
-                _strategy.wrap_compute_conv2d(conv2d_packed, True),
+                _strategy.wrap_compute_conv2d(conv2d_packed, need_data_layout=True),
                 _strategy.wrap_topi_schedule(schedule_conv2d_packed),
                 name="conv2d_packed.vta",
             )