From 9640f1f0be3aee110f82d06add497cde7c7febdb Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Mar 2022 19:22:22 -0800 Subject: [PATCH 1/3] [MetaSchedule] Enable AutoTVM-style template-based search space --- .../testing/conv2d_winograd_cpu.py | 172 +++++++++++++ .../testing/conv2d_winograd_cuda.py | 172 +++++++++++++ python/tvm/topi/cuda/conv2d_nhwc_winograd.py | 27 +- python/tvm/topi/cuda/conv2d_winograd.py | 20 +- python/tvm/topi/nn/conv2d.py | 23 +- python/tvm/topi/nn/pad.py | 8 +- python/tvm/topi/utils.py | 19 +- .../schedule_rule/multi_level_tiling.cc | 13 +- src/meta_schedule/schedule_rule/winograd.cc | 96 +++++++ .../space_generator/post_order_apply.cc | 54 ++-- src/runtime/threading_backend.cc | 2 +- ..._meta_schedule_custom_rule_winograd_cpu.py | 209 +++++++++++++++ ...meta_schedule_custom_rule_winograd_cuda.py | 243 ++++++++++++++++++ .../test_meta_schedule_post_order_apply.py | 45 +++- 14 files changed, 1052 insertions(+), 51 deletions(-) create mode 100644 python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py create mode 100644 python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py create mode 100644 src/meta_schedule/schedule_rule/winograd.cc create mode 100644 tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py create mode 100644 tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py new file mode 100644 index 000000000000..01a32794f769 --- /dev/null +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument + + +@T.prim_func +def conv2d_winograd_cpu( + X: T.Buffer[(1, 14, 14, 128), "float32"], + W: T.Buffer[(6, 6, 128, 128), "float32"], + conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], +) -> None: + # body + data_pad = T.alloc_buffer([1, 16, 16, 128]) + input_tile = T.alloc_buffer([6, 6, 9, 128]) + B = T.alloc_buffer([6, 6]) + data_pack = T.alloc_buffer([6, 6, 9, 128]) + bgemm = T.alloc_buffer([6, 6, 9, 128]) + A = T.alloc_buffer([6, 4]) + inverse = T.alloc_buffer([4, 4, 9, 128]) + for i0, i1, i2, i3 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.block_attr({"schedule_rule": "None"}) + T.reads([X[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, + X[i0_1, i1_1, i2_1, i3_1], + T.float32(0), + dtype="float32", + ) + for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2]) + T.block_attr({"schedule_rule": "None"}) + T.reads( + data_pad[ + T.floordiv(p, 9), + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), + ((T.floormod(p, 3) * 4) + nu), + ci, + ] + ) + T.writes([input_tile[eps, nu, p, ci]]) + input_tile[eps, nu, p, ci] = data_pad[ + T.floordiv(p, 9), + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), + ((T.floormod(p, 3) * 4) + nu), + ci, + ] + for i0_3, i1_3 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0_3, i1_3]) + T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) + T.writes([B[i, j]]) + # fmt: off + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + # fmt: on + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap( + "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5] + ) + T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cpu"}) + T.reads( + [ + data_pack[eps_1, nu_1, p_1, ci_1], + input_tile[r_a, r_b, p_1, ci_1], + B[ + T.min(r_a, r_b) : ( + T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) + ), + T.min(eps_1, nu_1) : ( + T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) + ), + ], + ] + ) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + ( + (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1] + ) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.block_attr({"meta_schedule.write_cache_level": [2]}) + T.reads( + [ + bgemm[eps_2, nu_2, p_2, co], + data_pack[eps_2, nu_2, p_2, ci_2], + W[eps_2, nu_2, co, ci_2], + ] + ) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = ( + bgemm[eps_2, nu_2, p_2, co] + + data_pack[eps_2, nu_2, p_2, ci_2] * W[eps_2, nu_2, co, ci_2] + ) + for i0_6, i1_6 in T.grid(6, 4): + with T.block("A"): + i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6]) + T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) + T.writes([A[i_1, j_1]]) + # fmt: off + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + # fmt: on + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( + "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] + ) + T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"}) + T.reads( + [ + inverse[vh, vw, p_3, co_1], + bgemm[r_a_1, r_b_1, p_3, co_1], + A[ + T.min(r_a_1, r_b_1) : ( + T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) + ), + T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), + ], + ] + ) + T.writes([inverse[vh, vw, p_3, co_1]]) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + ( + (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw] + ) + for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6]) + T.reads( + [ + inverse[ + T.floormod(h, 4), + T.floormod(w, 4), + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + co_2, + ] + ] + ) + T.writes([conv2d_winograd[n, h, w, co_2]]) + conv2d_winograd[n, h, w, co_2] = inverse[ + T.floormod(h, 4), + T.floormod(w, 4), + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + co_2, + ] diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py new file mode 100644 index 000000000000..59ae737b1348 --- /dev/null +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py @@ -0,0 +1,172 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +from tvm.script import tir as T + +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument + + +@T.prim_func +def conv2d_winograd_cuda( + placeholder: T.Buffer[(1, 14, 14, 128), "float32"], + placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"], + conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], +) -> None: + data_pad = T.alloc_buffer([1, 16, 16, 128]) + input_tile = T.alloc_buffer([6, 6, 9, 128]) + B = T.alloc_buffer([6, 6]) + data_pack = T.alloc_buffer([6, 6, 9, 128]) + bgemm = T.alloc_buffer([6, 6, 9, 128]) + A = T.alloc_buffer([6, 4]) + inverse = T.alloc_buffer([4, 4, 9, 128]) + for i0, i1, i2, i3 in T.grid(1, 16, 16, 128): + with T.block("data_pad"): + i0_1, i1_1, i2_1, i3_1 = T.axis.remap("SSSS", [i0, i1, i2, i3]) + T.block_attr({"schedule_rule": "None"}) + T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) + T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) + data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( + 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, + placeholder[i0_1, i1_1, i2_1, i3_1], + T.float32(0), + dtype="float32", + ) + for i0_2, i1_2, i2_2, i3_2 in T.grid(6, 6, 9, 128): + with T.block("input_tile"): + eps, nu, p, ci = T.axis.remap("SSSS", [i0_2, i1_2, i2_2, i3_2]) + T.block_attr({"schedule_rule": "None"}) + T.reads( + [ + data_pad[ + T.floordiv(p, 9), + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), + ((T.floormod(p, 3) * 4) + nu), + ci, + ] + ] + ) + T.writes([input_tile[eps, nu, p, ci]]) + input_tile[eps, nu, p, ci] = data_pad[ + T.floordiv(p, 9), + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), + ((T.floormod(p, 3) * 4) + nu), + ci, + ] + for i0_3, i1_3 in T.grid(6, 6): + with T.block("B"): + i, j = T.axis.remap("SS", [i0_3, i1_3]) + T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) + T.writes([B[i, j]]) + # fmt: off + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + # fmt: on + for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): + with T.block("data_pack"): + eps_1, nu_1, p_1, ci_1, r_a, r_b = T.axis.remap( + "SSSSRR", [i0_4, i1_4, i2_3, i3_3, i4, i5] + ) + T.block_attr({"schedule_rule": "meta_schedule.winograd_data_pack.cuda"}) + T.reads( + [ + data_pack[eps_1, nu_1, p_1, ci_1], + input_tile[r_a, r_b, p_1, ci_1], + B[ + T.min(r_a, r_b) : ( + T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) + ), + T.min(eps_1, nu_1) : ( + T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) + ), + ], + ] + ) + T.writes([data_pack[eps_1, nu_1, p_1, ci_1]]) + with T.init(): + data_pack[eps_1, nu_1, p_1, ci_1] = T.float32(0) + data_pack[eps_1, nu_1, p_1, ci_1] = data_pack[eps_1, nu_1, p_1, ci_1] + ( + (input_tile[r_a, r_b, p_1, ci_1] * B[r_a, eps_1]) * B[r_b, nu_1] + ) + for i0_5, i1_5, i2_4, i3_4, i4_1 in T.grid(6, 6, 9, 128, 128): + with T.block("bgemm"): + eps_2, nu_2, p_2, co, ci_2 = T.axis.remap("SSSSR", [i0_5, i1_5, i2_4, i3_4, i4_1]) + T.block_attr({"meta_schedule.write_cache_level": [3]}) + T.reads( + [ + bgemm[eps_2, nu_2, p_2, co], + data_pack[eps_2, nu_2, p_2, ci_2], + placeholder_1[eps_2, nu_2, co, ci_2], + ] + ) + T.writes([bgemm[eps_2, nu_2, p_2, co]]) + with T.init(): + bgemm[eps_2, nu_2, p_2, co] = T.float32(0) + bgemm[eps_2, nu_2, p_2, co] = bgemm[eps_2, nu_2, p_2, co] + ( + data_pack[eps_2, nu_2, p_2, ci_2] * placeholder_1[eps_2, nu_2, co, ci_2] + ) + for i0_6, i1_6 in T.grid(6, 4): + with T.block("A"): + i_1, j_1 = T.axis.remap("SS", [i0_6, i1_6]) + T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) + T.writes([A[i_1, j_1]]) + # fmt: off + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + # fmt: on + for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): + with T.block("inverse"): + vh, vw, p_3, co_1, r_a_1, r_b_1 = T.axis.remap( + "SSSSRR", [i0_7, i1_7, i2_5, i3_5, i4_2, i5_1] + ) + T.block_attr({"schedule_rule": "meta_schedule.winograd_inverse"}) + T.reads( + [ + inverse[vh, vw, p_3, co_1], + bgemm[r_a_1, r_b_1, p_3, co_1], + A[ + T.min(r_a_1, r_b_1) : ( + T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) + ), + T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), + ], + ] + ) + T.writes([inverse[vh, vw, p_3, co_1]]) + with T.init(): + inverse[vh, vw, p_3, co_1] = T.float32(0) + inverse[vh, vw, p_3, co_1] = inverse[vh, vw, p_3, co_1] + ( + (bgemm[r_a_1, r_b_1, p_3, co_1] * A[r_a_1, vh]) * A[r_b_1, vw] + ) + for i0_8, i1_8, i2_6, i3_6 in T.grid(1, 12, 12, 128): + with T.block("conv2d_winograd"): + n, h, w, co_2 = T.axis.remap("SSSS", [i0_8, i1_8, i2_6, i3_6]) + T.reads( + [ + inverse[ + T.floormod(h, 4), + T.floormod(w, 4), + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + co_2, + ] + ] + ) + T.writes([conv2d_winograd[n, h, w, co_2]]) + conv2d_winograd[n, h, w, co_2] = inverse[ + T.floormod(h, 4), + T.floormod(w, 4), + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + co_2, + ] diff --git a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py index 698beeac6dc4..80745a90d9ff 100644 --- a/python/tvm/topi/cuda/conv2d_nhwc_winograd.py +++ b/python/tvm/topi/cuda/conv2d_nhwc_winograd.py @@ -20,15 +20,17 @@ """Winograd template for cuda backend""" import tvm -from tvm import te -from tvm import autotvm +from tvm import autotvm, te + from .. import nn -from ..utils import get_const_int, get_const_tuple, traverse_inline from ..nn.winograd_util import winograd_transform_matrices -from .tensor_intrin import intrin_wmma_load_matrix_A -from .tensor_intrin import intrin_wmma_load_matrix_W -from .tensor_intrin import intrin_wmma_store_matrix -from .tensor_intrin import intrin_wmma_gemm +from ..utils import get_const_int, get_const_tuple, traverse_inline +from .tensor_intrin import ( + intrin_wmma_gemm, + intrin_wmma_load_matrix_A, + intrin_wmma_load_matrix_W, + intrin_wmma_store_matrix, +) def _infer_tile_size(data, kernel): @@ -332,7 +334,13 @@ def nhwc_winograd_cuda( assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) - data_pad = nn.pad(data, (0, pt, pl, 0), (0, pb, pr, 0), name="data_pad") + data_pad = nn.pad( + data, + (0, pt, pl, 0), + (0, pb, pr, 0), + name="data_pad", + attrs={"schedule_rule": "None"}, + ) r = KW m = tile_size @@ -388,6 +396,7 @@ def nhwc_winograd_cuda( idxdiv(p, (nH * nW)), idxmod(idxdiv(p, nW), nH) * m + eps, idxmod(p, nW) * m + nu, c ], name="d", + attrs={"schedule_rule": "None"}, ) # Transform data @@ -399,6 +408,7 @@ def nhwc_winograd_cuda( input_tile[p][ci][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] ), name="data_pack", + attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"}, ) # Convert data type of input feature maps and weights for tensorcore @@ -430,6 +440,7 @@ def nhwc_winograd_cuda( bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", + attrs={"schedule_rule": "meta_schedule.winograd_inverse"}, ) # Output diff --git a/python/tvm/topi/cuda/conv2d_winograd.py b/python/tvm/topi/cuda/conv2d_winograd.py index 8a3f009c7ca5..4ff3f52b998f 100644 --- a/python/tvm/topi/cuda/conv2d_winograd.py +++ b/python/tvm/topi/cuda/conv2d_winograd.py @@ -18,15 +18,14 @@ """Winograd template for cuda backend""" import logging + import tvm -from tvm import te -from tvm import autotvm +from tvm import autotvm, te from .. import nn -from ..utils import get_const_int, get_const_tuple, traverse_inline +from ..nn.conv2d import _conv2d_winograd_nhwc_impl, conv2d_winograd_nhwc from ..nn.winograd_util import winograd_transform_matrices -from ..nn.conv2d import conv2d_winograd_nhwc, _conv2d_winograd_nhwc_impl - +from ..utils import get_const_int, get_const_tuple, traverse_inline logger = logging.getLogger("conv2d_winograd") @@ -78,7 +77,13 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ assert HSTR == 1 and WSTR == 1 and dilation_h == 1 and dilation_w == 1 pt, pl, pb, pr = nn.get_pad_tuple(padding, (KH, KW)) - data_pad = nn.pad(data, (0, 0, pt, pl), (0, 0, pb, pr), name="data_pad") + data_pad = nn.pad( + data, + (0, 0, pt, pl), + (0, 0, pb, pr), + name="data_pad", + attrs={"schedule_rule": "None"}, + ) r = KW m = tile_size @@ -113,6 +118,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ idxmod(idxdiv(p, nW), nH) * m + eps ][idxmod(p, nW) * m + nu], name="d", + attrs={"schedule_rule": "None"}, ) # transform data @@ -124,6 +130,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] ), name="data_pack", + attrs={"schedule_rule": "meta_schedule.winograd_data_pack.cuda"}, ) # do batch gemm @@ -145,6 +152,7 @@ def winograd_cuda(cfg, data, kernel, strides, padding, dilation, out_dtype, pre_ bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", + attrs={"schedule_rule": "meta_schedule.winograd_inverse"}, ) # output diff --git a/python/tvm/topi/nn/conv2d.py b/python/tvm/topi/nn/conv2d.py index 34357508122f..b1230c0398c2 100644 --- a/python/tvm/topi/nn/conv2d.py +++ b/python/tvm/topi/nn/conv2d.py @@ -19,11 +19,11 @@ """Conv2D operators""" from __future__ import absolute_import as _abs -from collections import namedtuple import re -from typing import Union, Sequence, Optional -import numpy as np +from collections import namedtuple +from typing import Optional, Sequence, Union +import numpy as np import tvm from tvm import auto_scheduler, te @@ -1019,7 +1019,11 @@ def _conv2d_winograd_nhwc_impl( pad_extra = (nW - 1) * m + alpha - (H + pad_t + pad_b) data_pad = pad( - data, (0, pad_t, pad_l, 0), (0, pad_b + pad_extra, pad_r + pad_extra, 0), name="data_pad" + data, + (0, pad_t, pad_l, 0), + (0, pad_b + pad_extra, pad_r + pad_extra, 0), + name="data_pad", + attrs={"schedule_rule": "None"}, ) if not pre_computed: @@ -1044,6 +1048,7 @@ def _conv2d_winograd_nhwc_impl( (p % nW) * m + nu ][ci], name="input_tile", + attrs={"schedule_rule": "None"}, ) # transform data @@ -1055,7 +1060,10 @@ def _conv2d_winograd_nhwc_impl( input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b] ), name="data_pack", - attrs={"auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"]}, + attrs={ + "auto_scheduler_simplify_const_tensor_indices": ["eps", "nu", "r_a", "r_b"], + "schedule_rule": "meta_schedule.winograd_data_pack.cpu", + }, # the attrs are necessary hints for the auto-scheduler ) @@ -1082,7 +1090,10 @@ def _conv2d_winograd_nhwc_impl( bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b] ), name="inverse", - attrs={"auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"]}, + attrs={ + "auto_scheduler_simplify_const_tensor_indices": ["vh", "vw", "r_a", "r_b"], + "schedule_rule": "meta_schedule.winograd_inverse", + }, # the attrs are necessary hints for the auto-scheduler ) diff --git a/python/tvm/topi/nn/pad.py b/python/tvm/topi/nn/pad.py index 78e41b5af92a..4e76104fb08b 100644 --- a/python/tvm/topi/nn/pad.py +++ b/python/tvm/topi/nn/pad.py @@ -16,14 +16,16 @@ # under the License. """Pad the data by constant value """ from __future__ import absolute_import as _abs + import tvm from tvm import te -from ..utils import equal_const_int + from .. import tag +from ..utils import equal_const_int @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") -def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput"): +def pad(data, pad_before, pad_after=None, pad_value=0.0, name="PadInput", attrs=None): """Pad Input with zeros. Parameters @@ -85,7 +87,7 @@ def _pad(*indices): return tvm.tir.if_then_else(not_zero, data(*index_tuple), pad_value) return data(*index_tuple) - return te.compute(out_shape, _pad, name=name) + return te.compute(out_shape, _pad, name=name, attrs=attrs) @tvm.te.tag_scope(tag=tag.INJECTIVE + ",pad") diff --git a/python/tvm/topi/utils.py b/python/tvm/topi/utils.py index be3df2be5f6a..0e39a6ce9a4b 100644 --- a/python/tvm/topi/utils.py +++ b/python/tvm/topi/utils.py @@ -17,14 +17,15 @@ # pylint: disable=invalid-name """Common topi utilities""" from __future__ import absolute_import as _abs -from numbers import Integral -import numpy as np +from numbers import Integral +import numpy as np import tvm from tvm import te -from tvm.tir import layout, bijective_layout -from . import tag, cpp +from tvm.tir import bijective_layout, layout + +from . import cpp, tag class InvalidShapeError(ValueError): @@ -347,7 +348,15 @@ def select_array(i, j): ) return now - return te.compute(matrix.shape, select_array, name=name, attrs={"const_matrix": True}) + return te.compute( + matrix.shape, + select_array, + name=name, + attrs={ + "const_matrix": True, + "schedule_rule": "meta_schedule.compute_inline", + }, + ) def get_max_power2_factor(n, max_value=None): diff --git a/src/meta_schedule/schedule_rule/multi_level_tiling.cc b/src/meta_schedule/schedule_rule/multi_level_tiling.cc index d0bfff40fcbe..84ba0dd034a4 100644 --- a/src/meta_schedule/schedule_rule/multi_level_tiling.cc +++ b/src/meta_schedule/schedule_rule/multi_level_tiling.cc @@ -230,12 +230,19 @@ inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const if (config.req == ReuseType::kNoReuse) { return {std::move(state)}; } + std::vector levels = config.levels; + ReuseType req = config.req; + if (Optional> ann = tir::GetAnn>( + state.sch->GetSRef(state.block_rv), "meta_schedule.write_cache_level")) { + req = ReuseType::kMustReuse; + levels = std::vector(ann.value().begin(), ann.value().end()); + } std::vector results; - if (config.req == ReuseType::kMayReuse) { + if (req == ReuseType::kMayReuse) { // Case 1. If the write cache is already there, we don't need to add another. Array consumer_rvs = state.sch->GetConsumers(state.block_rv); if (consumer_rvs.size() == 1 && IsWriteCache(state.sch->GetSRef(consumer_rvs[0]))) { - for (int level : config.levels) { + for (int level : levels) { State new_state = state; new_state.sch = state.sch->Copy(); new_state.sch->Seed(state.sch->ForkSeed()); @@ -256,7 +263,7 @@ inline std::vector MultiLevelTilingNode::AddWriteReuse(State state) const // Case 3. Add one write cache BlockRV write_cache = state.sch->CacheWrite(/*block_rv=*/state.block_rv, /*read_buffer_index=*/0, /*storage_scope=*/config.scope); - for (int level : config.levels) { + for (int level : levels) { State new_state = state; new_state.sch = state.sch->Copy(); new_state.sch->Seed(state.sch->ForkSeed()); diff --git a/src/meta_schedule/schedule_rule/winograd.cc b/src/meta_schedule/schedule_rule/winograd.cc new file mode 100644 index 000000000000..44db6f2f404c --- /dev/null +++ b/src/meta_schedule/schedule_rule/winograd.cc @@ -0,0 +1,96 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include "../utils.h" + +namespace tvm { +namespace tir { + +TVM_REGISTER_GLOBAL("meta_schedule.compute_inline") + .set_body_typed([](Schedule sch, BlockRV block) -> Array { + sch->ComputeInline(block); + return {sch}; + }); + +inline BlockRV GetOnlyProducer(Schedule sch, BlockRV block) { + Array producers = sch->GetProducers(block); + ICHECK_EQ(producers.size(), 1); + return producers[0]; +} + +inline LoopRV ScheduleDataPack(Schedule sch, BlockRV block) { + Array factors{nullptr}; + Array loops = sch->GetLoops(block); + ICHECK_EQ(loops.size(), 6); + + factors = sch->SamplePerfectTile(loops[2], /*n=*/2, /*max_innermost_factor=*/64); + Array t0 = sch->Split(loops[2], {factors.begin(), factors.end()}); + ICHECK_EQ(t0.size(), 2); + + factors = sch->SamplePerfectTile(loops[3], /*n=*/2, /*max_innermost_factor=*/64); + Array t1 = sch->Split(loops[3], {factors.begin(), factors.end()}); + ICHECK_EQ(t1.size(), 2); + + sch->Unroll(loops[0]); + sch->Unroll(loops[1]); + sch->Unroll(loops[4]); + sch->Unroll(loops[5]); + sch->Reorder({ + t0[0], + t1[0], + t0[1], + t1[1], + loops[0], + loops[1], + loops[4], + loops[5], + }); + return t1[1]; +} + +TVM_REGISTER_GLOBAL("meta_schedule.winograd_inverse") + .set_body_typed([](Schedule sch, BlockRV block) -> Array { + ScheduleDataPack(sch, block); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cpu") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetOnlyProducer(sch, data_pack); + BlockRV data_pad = GetOnlyProducer(sch, input_tile); + ScheduleDataPack(sch, data_pack); + sch->ComputeAt(input_tile, /*loop_rv=*/sch->SampleComputeLocation(input_tile), + /*preserve_unit_loops=*/true); + sch->ComputeAt(data_pad, /*loop_rv=*/sch->SampleComputeLocation(data_pad), + /*preserve_unit_loops=*/true); + return {sch}; + }); + +TVM_REGISTER_GLOBAL("meta_schedule.winograd_data_pack.cuda") + .set_body_typed([](Schedule sch, BlockRV data_pack) -> Array { + BlockRV input_tile = GetOnlyProducer(sch, data_pack); + BlockRV data_pad = GetOnlyProducer(sch, input_tile); + LoopRV loop = ScheduleDataPack(sch, data_pack); + sch->ComputeAt(input_tile, /*loop_rv=*/loop, /*preserve_unit_loops=*/true); + sch->SetScope(input_tile, /*buffer_index=*/0, /*storage_scope=*/"local"); + sch->ComputeInline(data_pad); + return {sch}; + }); + +} // namespace tir +} // namespace tvm diff --git a/src/meta_schedule/space_generator/post_order_apply.cc b/src/meta_schedule/space_generator/post_order_apply.cc index bc616327eb3b..cae42bee4fe4 100644 --- a/src/meta_schedule/space_generator/post_order_apply.cc +++ b/src/meta_schedule/space_generator/post_order_apply.cc @@ -31,6 +31,7 @@ class BlockCollector : public tir::StmtVisitor { private: /*! \brief Entry point */ Array Run() { + std::vector results; for (const auto& kv : sch_->mod()->functions) { const GlobalVar& gv = kv.first; // `gv->name_hint` is the name of the function const BaseFunc& base_func = kv.second; // this can be PrimFunc or relay::Function @@ -39,12 +40,12 @@ class BlockCollector : public tir::StmtVisitor { block_names_.clear(); blocks_to_collect_.clear(); VisitStmt(func->body); - for (const String& block_name : blocks_to_collect_) { - results_.push_back(sch_->GetBlock(block_name, func_name_)); + for (const String& name : blocks_to_collect_) { + results.push_back(sch_->GetBlock(name, func_name_)); } } } - return results_; + return results; } /*! \brief Constructor */ explicit BlockCollector(const tir::Schedule& sch) : sch_(sch) {} @@ -64,8 +65,6 @@ class BlockCollector : public tir::StmtVisitor { std::unordered_set block_names_; /* \brief The list of blocks to collect in order */ Array blocks_to_collect_; - /*! \brief Function name & blocks of collection */ - Array results_; /*! \brief Name of the current PrimFunc */ String func_name_; }; @@ -95,10 +94,10 @@ class PostOrderApplyNode : public SpaceGeneratorNode { Array GenerateDesignSpace(const IRModule& mod_) final { using ScheduleAndUnvisitedBlocks = std::pair>; - tir::Schedule sch = tir::Schedule::Traced( // - /*mod=*/mod_, // - /*rand_state=*/ForkSeed(&this->rand_state_), // - /*debug_mode=*/tir::kVerifySRefTree | tir::kVerifyCachedFlags, // + tir::Schedule sch = tir::Schedule::Traced( + /*mod=*/mod_, + /*rand_state=*/ForkSeed(&this->rand_state_), + /*debug_mode=*/0, /*error_render_level=*/tir::ScheduleErrorRenderLevel::kDetail); std::vector stack; @@ -106,12 +105,19 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // Enumerate the schedule rules first because you can // always concat multiple schedule rules as one Array all_blocks = BlockCollector::Collect(sch); - for (ScheduleRule sch_rule : sch_rules_) { - for (const tir::Schedule& sch : result) { - stack.emplace_back(sch, all_blocks); + Array> rules{NullOpt}; + rules.insert(rules.end(), sch_rules_.begin(), sch_rules_.end()); + for (Optional sch_rule : rules) { + if (sch_rule.defined()) { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, all_blocks); + } + } else { + for (const tir::Schedule& sch : result) { + stack.emplace_back(sch, Array{all_blocks.rbegin(), all_blocks.rend()}); + } } result.clear(); - while (!stack.empty()) { // get the stack.top() tir::Schedule sch; @@ -126,12 +132,24 @@ class PostOrderApplyNode : public SpaceGeneratorNode { // otherwise, get the last block that is not visited tir::BlockRV block_rv = blocks.back(); blocks.pop_back(); - if (sch->HasBlock(block_rv)) { - Array applied = sch_rule->Apply(sch, /*block=*/block_rv); - for (const tir::Schedule& sch : applied) { - stack.emplace_back(sch, blocks); - } + if (!sch->HasBlock(block_rv)) { + stack.emplace_back(sch, blocks); + continue; + } + Optional ann = tir::GetAnn(sch->GetSRef(block_rv), "schedule_rule"); + if (ann.defined() == sch_rule.defined() || (ann.defined() && ann.value() == "None")) { + stack.emplace_back(sch, blocks); + continue; + } + Array applied{nullptr}; + if (sch_rule.defined()) { + applied = sch_rule.value()->Apply(sch, /*block=*/block_rv); } else { + const runtime::PackedFunc* f = runtime::Registry::Get(ann.value()); + CHECK(f) << "ValueError: Custom schedule rule not found: " << ann.value(); + applied = (*f)(sch, block_rv); + } + for (const tir::Schedule& sch : applied) { stack.emplace_back(sch, blocks); } } diff --git a/src/runtime/threading_backend.cc b/src/runtime/threading_backend.cc index 3d9dd8ec9605..748b0b035094 100644 --- a/src/runtime/threading_backend.cc +++ b/src/runtime/threading_backend.cc @@ -278,7 +278,7 @@ int ThreadGroup::Configure(AffinityMode mode, int nthreads, bool exclude_worker0 void Yield() { std::this_thread::yield(); } /*! - * \bief Set the maximum number of available cores. + * \brief Set the maximum number of available cores. */ void SetMaxConcurrency(int value) { if (value < 0) { diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py new file mode 100644 index 000000000000..04dcf957780c --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cpu.py @@ -0,0 +1,209 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import tvm +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.meta_schedule.testing.conv2d_winograd_cpu import conv2d_winograd_cpu +from tvm.meta_schedule.tune import DefaultLLVM +from tvm.target import Target +from tvm.tir.schedule import Schedule, Trace + + +def _get_mod(): + # pylint: disable=invalid-name + def inline(sch: Schedule): + b1 = sch.get_block(name="A") + b2 = sch.get_block(name="B") + sch.compute_inline(block=b1) + sch.compute_inline(block=b2) + + def input_tile_data_pad(sch: Schedule): + b78 = sch.get_block(name="input_tile") + l80 = sch.sample_compute_location(block=b78, decision=4) + sch.compute_at(block=b78, loop=l80, preserve_unit_loops=True) + + b81 = sch.get_block(name="data_pad") + l83 = sch.sample_compute_location(block=b81, decision=-2) + sch.compute_at(block=b81, loop=l83, preserve_unit_loops=True) + + def data_pack(sch: Schedule): + b18 = sch.get_block(name="data_pack") + l19, l20, l21, l22, l23, l24 = sch.get_loops(block=b18) + sch.unroll(loop=l19) + sch.unroll(loop=l20) + v25, v26 = sch.sample_perfect_tile( + n=2, + loop=l21, + max_innermost_factor=64, + decision=[9, 1], + ) + l27, l28 = sch.split(loop=l21, factors=[v25, v26]) + v29, v30 = sch.sample_perfect_tile( + n=2, + loop=l22, + max_innermost_factor=64, + decision=[32, 4], + ) + l31, l32 = sch.split(loop=l22, factors=[v29, v30]) + sch.unroll(loop=l23) + sch.unroll(loop=l24) + sch.reorder(l27, l31, l28, l32, l19, l20, l23, l24) + + def bgemm(sch: Schedule): + bgemm = sch.get_block(name="bgemm") + write_cache = sch.cache_write( + block=bgemm, + write_buffer_index=0, + storage_scope="global", + ) + sch.annotate( + block_or_loop=bgemm, + ann_key="meta_schedule.tiling_structure", + ann_val="SSRSRS", + ) + # b33, b34 = b34, b33 + l35, l36, l37, l38, l39 = sch.get_loops(block=bgemm) + v40, v41, v42, v43 = sch.sample_perfect_tile( + n=4, + loop=l35, + max_innermost_factor=64, + decision=[1, 2, 3, 1], + ) + l44, l45, l46, l47 = sch.split(loop=l35, factors=[v40, v41, v42, v43]) + v48, v49, v50, v51 = sch.sample_perfect_tile( + n=4, + loop=l36, + max_innermost_factor=64, + decision=[1, 1, 1, 6], + ) + l52, l53, l54, l55 = sch.split(loop=l36, factors=[v48, v49, v50, v51]) + v56, v57, v58, v59 = sch.sample_perfect_tile( + n=4, + loop=l37, + max_innermost_factor=64, + decision=[1, 1, 1, 9], + ) + l60, l61, l62, l63 = sch.split(loop=l37, factors=[v56, v57, v58, v59]) + v64, v65, v66, v67 = sch.sample_perfect_tile( + n=4, + loop=l38, + max_innermost_factor=64, + decision=[2, 1, 16, 4], + ) + l68, l69, l70, l71 = sch.split(loop=l38, factors=[v64, v65, v66, v67]) + v72, v73 = sch.sample_perfect_tile( + n=2, + loop=l39, + max_innermost_factor=64, + decision=[16, 8], + ) + l74, l75 = sch.split(loop=l39, factors=[v72, v73]) + sch.reorder( + # fmt: off + l44, l52, l60, l68, + l45, l53, l61, l69, + l74, + l46, l54, l62, l70, + l75, + l47, l55, l63, l71, + # fmt: on + ) + sch.reverse_compute_at(block=write_cache, loop=l69, preserve_unit_loops=True) + + def inverse(sch: Schedule): + b3 = sch.get_block(name="inverse") + l4, l5, l6, l7, l8, l9 = sch.get_loops(block=b3) + sch.unroll(loop=l4) + sch.unroll(loop=l5) + v10, v11 = sch.sample_perfect_tile( + n=2, + loop=l6, + max_innermost_factor=64, + decision=[1, 9], + ) + l12, l13 = sch.split(loop=l6, factors=[v10, v11]) + v14, v15 = sch.sample_perfect_tile( + n=2, + loop=l7, + max_innermost_factor=64, + decision=[2, 64], + ) + l16, l17 = sch.split(loop=l7, factors=[v14, v15]) + sch.unroll(loop=l8) + sch.unroll(loop=l9) + sch.reorder(l12, l16, l13, l17, l4, l5, l8, l9) + + # pylint: enable=invalid-name + + sch = Schedule(mod=conv2d_winograd_cpu) + inline(sch) + data_pack(sch) + input_tile_data_pad(sch) + bgemm(sch) + inverse(sch) + return sch.mod + + +def test_conv2d_winograd_cpu(): + mod = conv2d_winograd_cpu + mod = IRModule({"main": mod}) + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=DefaultLLVM._sch_rules(), # pylint: disable=protected-access + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + (sch,) = post_order_apply.generate_design_space(mod) + + decisions = dict( + zip( + [i for i in sch.trace.insts[:-4] if i.kind.name.startswith("Sample")], + [ + # data_pack + [9, 1], + [32, 4], + # input_tile + 4, + # data_pad + -2, + # inverse + [1, 9], + [2, 64], + # bgemm + [1, 2, 3, 1], + [1, 1, 1, 6], + [1, 1, 1, 9], + [2, 1, 16, 4], + [16, 8], + ], + ) + ) + trace = Trace(sch.trace.insts[:-4], decisions=decisions) + sch = Schedule(mod=mod) + trace.apply_to_schedule(sch, remove_postproc=False) + answer = sch.mod + expected = _get_mod() + tvm.ir.assert_structural_equal(answer, expected) + + +if __name__ == "__main__": + test_conv2d_winograd_cpu() diff --git a/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py new file mode 100644 index 000000000000..afe6548d6fe3 --- /dev/null +++ b/tests/python/unittest/test_meta_schedule_custom_rule_winograd_cuda.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-docstring + +import tvm +from tvm.ir import IRModule +from tvm.meta_schedule import TuneContext +from tvm.meta_schedule.space_generator import PostOrderApply +from tvm.meta_schedule.testing.conv2d_winograd_cuda import conv2d_winograd_cuda +from tvm.meta_schedule.tune import DefaultCUDA +from tvm.target import Target +from tvm.tir.schedule import Schedule, Trace + + +def _get_mod(): + # pylint: disable=invalid-name + def inline(sch: Schedule): + b125 = sch.get_block(name="A") + sch.compute_inline(block=b125) + b126 = sch.get_block(name="B") + sch.compute_inline(block=b126) + + def input_tile_data_pad(sch: Schedule): + b115 = sch.get_block(name="input_tile") + (b116,) = sch.get_consumers(block=b115) + _, _, _, l120, _, _, _, _ = sch.get_loops(block=b116) + sch.compute_at(block=b115, loop=l120, preserve_unit_loops=True) + sch.set_scope(block=b115, buffer_index=0, storage_scope="local") + + b127 = sch.get_block(name="data_pad") + sch.compute_inline(block=b127) + + def data_pack(sch: Schedule): + b16 = sch.get_block(name="data_pack") + l17, l18, l19, l20, l21, l22 = sch.get_loops(block=b16) + sch.unroll(loop=l17) + sch.unroll(loop=l18) + v23, v24 = sch.sample_perfect_tile( + n=2, + loop=l19, + max_innermost_factor=64, + decision=[3, 3], + ) + l25, l26 = sch.split(loop=l19, factors=[v23, v24]) + v27, v28 = sch.sample_perfect_tile( + n=2, + loop=l20, + max_innermost_factor=64, + decision=[64, 2], + ) + l29, l30 = sch.split(loop=l20, factors=[v27, v28]) + sch.unroll(loop=l21) + sch.unroll(loop=l22) + sch.reorder(l25, l29, l26, l30, l17, l18, l21, l22) + + def bgemm(sch: Schedule): + b31 = sch.get_block(name="bgemm") + sch.annotate( + block_or_loop=b31, + ann_key="meta_schedule.tiling_structure", + ann_val="SSSRRSRS", + ) + b32 = sch.cache_write(block=b31, write_buffer_index=0, storage_scope="local") + b31, b32 = b32, b31 + l33, l34, l35, l36, l37 = sch.get_loops(block=b32) + v38, v39, v40, v41, v42 = sch.sample_perfect_tile( + n=5, + loop=l33, + max_innermost_factor=64, + decision=[1, 1, 1, 1, 6], + ) + l43, l44, l45, l46, l47 = sch.split(loop=l33, factors=[v38, v39, v40, v41, v42]) + v48, v49, v50, v51, v52 = sch.sample_perfect_tile( + n=5, + loop=l34, + max_innermost_factor=64, + decision=[1, 1, 1, 3, 2], + ) + l53, l54, l55, l56, l57 = sch.split(loop=l34, factors=[v48, v49, v50, v51, v52]) + v58, v59, v60, v61, v62 = sch.sample_perfect_tile( + n=5, + loop=l35, + max_innermost_factor=64, + decision=[3, 1, 1, 1, 3], + ) + l63, l64, l65, l66, l67 = sch.split(loop=l35, factors=[v58, v59, v60, v61, v62]) + v68, v69, v70, v71, v72 = sch.sample_perfect_tile( + n=5, + loop=l36, + max_innermost_factor=64, + decision=[4, 2, 1, 4, 4], + ) + l73, l74, l75, l76, l77 = sch.split(loop=l36, factors=[v68, v69, v70, v71, v72]) + v78, v79, v80 = sch.sample_perfect_tile( + n=3, + loop=l37, + max_innermost_factor=64, + decision=[32, 1, 4], + ) + l81, l82, l83 = sch.split(loop=l37, factors=[v78, v79, v80]) + sch.reorder( + # fmt: off + l43, l53, l63, l73, + l44, l54, l64, l74, + l45, l55, l65, l75, + l81, + l82, + l46, l56, l66, l76, + l83, + l47, l57, l67, l77, + # fmt: on + ) + l84 = sch.fuse(l43, l53, l63, l73) + sch.bind(loop=l84, thread_axis="blockIdx.x") + l85 = sch.fuse(l44, l54, l64, l74) + sch.bind(loop=l85, thread_axis="vthread.x") + l86 = sch.fuse(l45, l55, l65, l75) + sch.bind(loop=l86, thread_axis="threadIdx.x") + + b87 = sch.cache_read(block=b32, read_buffer_index=1, storage_scope="shared") + sch.compute_at(block=b87, loop=l81, preserve_unit_loops=True) + _, _, _, _, l92, l93, l94, l95 = sch.get_loops(block=b87) + sch.fuse(l92, l93, l94, l95) + v97 = sch.sample_categorical( + candidates=[1, 2, 3, 4], + probs=[0.25, 0.25, 0.25, 0.25], + decision=1, + ) + sch.annotate( + block_or_loop=b87, + ann_key="meta_schedule.cooperative_fetch", + ann_val=v97, + ) + + b101 = sch.cache_read(block=b32, read_buffer_index=2, storage_scope="shared") + sch.compute_at(block=b101, loop=l81, preserve_unit_loops=True) + _, _, _, _, l106, l107, l108, l109 = sch.get_loops(block=b101) + sch.fuse(l106, l107, l108, l109) + v110 = sch.sample_categorical( + candidates=[1, 2, 3, 4], + probs=[0.25, 0.25, 0.25, 0.25], + decision=1, + ) + sch.annotate( + block_or_loop=b101, + ann_key="meta_schedule.cooperative_fetch", + ann_val=v110, + ) + + sch.reverse_compute_at(block=b31, loop=l86, preserve_unit_loops=True) + + def inverse(sch: Schedule): + b1 = sch.get_block(name="inverse") + l2, l3, l4, l5, l6, l7 = sch.get_loops(block=b1) + sch.unroll(loop=l2) + sch.unroll(loop=l3) + v8, v9 = sch.sample_perfect_tile( + n=2, + loop=l4, + max_innermost_factor=64, + decision=[3, 3], + ) + l10, l11 = sch.split(loop=l4, factors=[v8, v9]) + v12, v13 = sch.sample_perfect_tile( + n=2, + loop=l5, + max_innermost_factor=64, + decision=[2, 64], + ) + l14, l15 = sch.split(loop=l5, factors=[v12, v13]) + sch.unroll(loop=l6) + sch.unroll(loop=l7) + sch.reorder(l10, l14, l11, l15, l2, l3, l6, l7) + + # pylint: enable=invalid-name + + sch = Schedule(mod=conv2d_winograd_cuda) + inline(sch) + data_pack(sch) + input_tile_data_pad(sch) + bgemm(sch) + inverse(sch) + + return sch.mod + + +def test_conv2d_winograd_cuda(): + mod = conv2d_winograd_cuda + mod = IRModule({"main": mod}) + context = TuneContext( + mod=mod, + target=Target("cuda"), + task_name="Custom Search Space Task", + sch_rules=DefaultCUDA._sch_rules(), # pylint: disable=protected-access + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + (sch,) = post_order_apply.generate_design_space(mod) + decisions = dict( + zip( + [i for i in sch.trace.insts[:-2] if i.kind.name.startswith("Sample")], + [ + # data_pack + [3, 3], + [64, 2], + # inverse + [3, 3], + [2, 64], + # bgemm + [1, 1, 1, 1, 6], + [1, 1, 1, 3, 2], + [3, 1, 1, 1, 3], + [4, 2, 1, 4, 4], + [32, 1, 4], + 1, + 1, + ], + ) + ) + trace = Trace(sch.trace.insts[:-2], decisions=decisions) + sch = Schedule(mod=mod) + trace.apply_to_schedule(sch, remove_postproc=False) + answer = sch.mod + expected = _get_mod() + tvm.ir.assert_structural_equal(answer, expected) + + +if __name__ == "__main__": + test_conv2d_winograd_cuda() diff --git a/tests/python/unittest/test_meta_schedule_post_order_apply.py b/tests/python/unittest/test_meta_schedule_post_order_apply.py index 556468adb982..40bb82f95929 100644 --- a/tests/python/unittest/test_meta_schedule_post_order_apply.py +++ b/tests/python/unittest/test_meta_schedule_post_order_apply.py @@ -22,6 +22,7 @@ import pytest import tvm +from tvm._ffi import register_func from tvm.error import TVMError from tvm.meta_schedule import TuneContext from tvm.meta_schedule.schedule_rule import PyScheduleRule @@ -31,7 +32,6 @@ from tvm.target import Target from tvm.tir.schedule import BlockRV, Schedule - # pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument, # fmt: off @@ -121,6 +121,23 @@ def main(a: T.handle, d: T.handle) -> None: D[vi, vj] = (B[vi, vj] + T.float32(3)) * T.float32(5) +@tvm.script.ir_module +class MatmulCustomized: + @T.prim_func + def main(a: T.handle, b: T.handle, c: T.handle) -> None: + T.func_attr({"global_symbol": "main"}) + A = T.match_buffer(a, (1024, 1024), "float32") + B = T.match_buffer(b, (1024, 1024), "float32") + C = T.match_buffer(c, (1024, 1024), "float32") + with T.block("root"): + for i, j, k in T.grid(1024, 1024, 1024): + with T.block("matmul"): + T.block_attr({"schedule_rule": "tvm.meta_schedule.test.custom_search_space"}) + vi, vj, vk = T.axis.remap("SSR", [i, j, k]) + with T.init(): + C[vi, vj] = 0.0 + C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj] + # fmt: on # pylint: enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument @@ -344,5 +361,31 @@ def correct_trace(a, b, c, d): ) +def test_meta_schedule_custom_search_space(): + mod = MatmulCustomized + context = TuneContext( + mod=mod, + target=Target("llvm"), + task_name="Custom Search Space Task", + sch_rules=[], + ) + post_order_apply = PostOrderApply() + post_order_apply.initialize_with_tune_context(context) + with pytest.raises(ValueError, match="Custom schedule rule not found"): + post_order_apply.generate_design_space(mod) + + called = False + + def custom_search_space_func(sch: Schedule, _: BlockRV) -> List[Schedule]: + nonlocal called + called = True + return [sch] + + register_func("tvm.meta_schedule.test.custom_search_space", custom_search_space_func) + + post_order_apply.generate_design_space(mod) + assert called + + if __name__ == "__main__": sys.exit(pytest.main([__file__] + sys.argv[1:])) From 727c49275d8cacc830cb0f0bc8cffaf591ed9519 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Mar 2022 19:27:40 -0800 Subject: [PATCH 2/3] Fix lint --- python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py | 2 +- python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py index 01a32794f769..0b38f704a10f 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py @@ -18,7 +18,7 @@ from tvm.script import tir as T -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument,chained-comparison,misplaced-comparison-constant @T.prim_func diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py index 59ae737b1348..10744653c16b 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py @@ -18,7 +18,7 @@ from tvm.script import tir as T -# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument +# pylint: disable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,no-self-use,unused-argument,chained-comparison,misplaced-comparison-constant @T.prim_func From 1b2036f6e2ce498b748a271efcedc1e88f04db57 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 3 Mar 2022 19:59:47 -0800 Subject: [PATCH 3/3] suppress mypy --- .../meta_schedule/cost_model/random_model.py | 2 +- .../testing/conv2d_winograd_cpu.py | 50 ++++++++--------- .../testing/conv2d_winograd_cuda.py | 53 ++++++++++--------- 3 files changed, 53 insertions(+), 52 deletions(-) diff --git a/python/tvm/meta_schedule/cost_model/random_model.py b/python/tvm/meta_schedule/cost_model/random_model.py index 77926c45b972..bc178f76ac90 100644 --- a/python/tvm/meta_schedule/cost_model/random_model.py +++ b/python/tvm/meta_schedule/cost_model/random_model.py @@ -131,6 +131,6 @@ def predict( np.random.set_state(self.random_state) # TODO(@zxybazh): Use numpy's RandState object: # https://numpy.org/doc/1.16/reference/generated/numpy.random.RandomState.html#numpy.random.RandomState - result = np.random.rand(len(candidates)) * self.max_range + result = np.random.rand(len(candidates)) * self.max_range # type: ignore self.random_state = np.random.get_state() return result diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py index 0b38f704a10f..bfd5f4557ce8 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cpu.py @@ -23,9 +23,9 @@ @T.prim_func def conv2d_winograd_cpu( - X: T.Buffer[(1, 14, 14, 128), "float32"], - W: T.Buffer[(6, 6, 128, 128), "float32"], - conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], + X: T.Buffer[(1, 14, 14, 128), "float32"], # type: ignore + W: T.Buffer[(6, 6, 128, 128), "float32"], # type: ignore + conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], # type: ignore ) -> None: # body data_pad = T.alloc_buffer([1, 16, 16, 128]) @@ -42,7 +42,7 @@ def conv2d_winograd_cpu( T.reads([X[i0_1, i1_1, i2_1, i3_1]]) T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, + 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, # type: ignore X[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32", @@ -53,17 +53,17 @@ def conv2d_winograd_cpu( T.block_attr({"schedule_rule": "None"}) T.reads( data_pad[ - T.floordiv(p, 9), - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), - ((T.floormod(p, 3) * 4) + nu), + T.floordiv(p, 9), # type: ignore + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore + ((T.floormod(p, 3) * 4) + nu), # type: ignore ci, ] ) T.writes([input_tile[eps, nu, p, ci]]) input_tile[eps, nu, p, ci] = data_pad[ - T.floordiv(p, 9), - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), - ((T.floormod(p, 3) * 4) + nu), + T.floordiv(p, 9), # type: ignore + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore + ((T.floormod(p, 3) * 4) + nu), # type: ignore ci, ] for i0_3, i1_3 in T.grid(6, 6): @@ -72,7 +72,7 @@ def conv2d_winograd_cpu( T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) T.writes([B[i, j]]) # fmt: off - B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) # type: ignore # fmt: on for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): with T.block("data_pack"): @@ -85,11 +85,11 @@ def conv2d_winograd_cpu( data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[ - T.min(r_a, r_b) : ( - T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) + T.min(r_a, r_b) : ( # type: ignore + T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) # type: ignore ), - T.min(eps_1, nu_1) : ( - T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) + T.min(eps_1, nu_1) : ( # type: ignore + T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) # type: ignore ), ], ] @@ -124,7 +124,7 @@ def conv2d_winograd_cpu( T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) T.writes([A[i_1, j_1]]) # fmt: off - A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) # type: ignore # fmt: on for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): with T.block("inverse"): @@ -137,10 +137,10 @@ def conv2d_winograd_cpu( inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[ - T.min(r_a_1, r_b_1) : ( - T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) + T.min(r_a_1, r_b_1) : ( # type: ignore + T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) # type: ignore ), - T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), + T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), # type: ignore ], ] ) @@ -156,17 +156,17 @@ def conv2d_winograd_cpu( T.reads( [ inverse[ - T.floormod(h, 4), - T.floormod(w, 4), - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + T.floormod(h, 4), # type: ignore + T.floormod(w, 4), # type: ignore + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore co_2, ] ] ) T.writes([conv2d_winograd[n, h, w, co_2]]) conv2d_winograd[n, h, w, co_2] = inverse[ - T.floormod(h, 4), - T.floormod(w, 4), - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + T.floormod(h, 4), # type: ignore + T.floormod(w, 4), # type: ignore + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore co_2, ] diff --git a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py index 10744653c16b..530eadafc0f3 100644 --- a/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py +++ b/python/tvm/meta_schedule/testing/conv2d_winograd_cuda.py @@ -22,11 +22,12 @@ @T.prim_func -def conv2d_winograd_cuda( - placeholder: T.Buffer[(1, 14, 14, 128), "float32"], - placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"], - conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], +def conv2d_winograd_cuda( # type: ignore + placeholder: T.Buffer[(1, 14, 14, 128), "float32"], # type: ignore + placeholder_1: T.Buffer[(6, 6, 128, 128), "float32"], # type: ignore + conv2d_winograd: T.Buffer[(1, 12, 12, 128), "float32"], # type: ignore ) -> None: + # type: ignore data_pad = T.alloc_buffer([1, 16, 16, 128]) input_tile = T.alloc_buffer([6, 6, 9, 128]) B = T.alloc_buffer([6, 6]) @@ -41,7 +42,7 @@ def conv2d_winograd_cuda( T.reads([placeholder[i0_1, i1_1, i2_1, i3_1]]) T.writes([data_pad[i0_1, i1_1, i2_1, i3_1]]) data_pad[i0_1, i1_1, i2_1, i3_1] = T.if_then_else( - 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, + 0 <= i1_1 and i1_1 < 14 and 0 <= i2_1 and i2_1 < 14, # type: ignore placeholder[i0_1, i1_1, i2_1, i3_1], T.float32(0), dtype="float32", @@ -53,18 +54,18 @@ def conv2d_winograd_cuda( T.reads( [ data_pad[ - T.floordiv(p, 9), - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), - ((T.floormod(p, 3) * 4) + nu), + T.floordiv(p, 9), # type: ignore + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore + ((T.floormod(p, 3) * 4) + nu), # type: ignore ci, ] ] ) T.writes([input_tile[eps, nu, p, ci]]) input_tile[eps, nu, p, ci] = data_pad[ - T.floordiv(p, 9), - ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), - ((T.floormod(p, 3) * 4) + nu), + T.floordiv(p, 9), # type: ignore + ((T.floordiv(T.floormod(p, 9), 3) * 4) + eps), # type: ignore + ((T.floormod(p, 3) * 4) + nu), # type: ignore ci, ] for i0_3, i1_3 in T.grid(6, 6): @@ -73,7 +74,7 @@ def conv2d_winograd_cuda( T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) T.writes([B[i, j]]) # fmt: off - B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) + B[i, j] = T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 5) and (T.floormod(j, 6) == 0)), T.float32(0), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 5)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 4)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 3)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 2)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 4) and (T.floormod(j, 6) == 0)), T.float32(1), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 5)), T.float32(-2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 4)), T.float32(-0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 3)), T.float32(2), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 2)), T.float32(2.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 1)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 3) and (T.floormod(j, 6) == 0)), T.float32(1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 5)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 4)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 3)), T.float32(-1), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 2)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 1)), T.float32(-2.5), T.Select(((T.floormod(i, 6) == 2) and (T.floormod(j, 6) == 0)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 5)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 4)), T.float32(0.5), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 3)), T.float32(-2), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 2)), T.float32(-1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 1)), T.float32(1), T.Select(((T.floormod(i, 6) == 1) and (T.floormod(j, 6) == 0)), T.float32(-1.5), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 5)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 4)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 3)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 2)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 1)), T.float32(0), T.Select(((T.floormod(i, 6) == 0) and (T.floormod(j, 6) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))))))))))))))) # type: ignore # fmt: on for i0_4, i1_4, i2_3, i3_3, i4, i5 in T.grid(6, 6, 9, 128, 6, 6): with T.block("data_pack"): @@ -86,11 +87,11 @@ def conv2d_winograd_cuda( data_pack[eps_1, nu_1, p_1, ci_1], input_tile[r_a, r_b, p_1, ci_1], B[ - T.min(r_a, r_b) : ( - T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) + T.min(r_a, r_b) : ( # type: ignore + T.min(r_a, r_b) + ((T.max(r_a, r_b) + 1) - T.min(r_a, r_b)) # type: ignore ), - T.min(eps_1, nu_1) : ( - T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) + T.min(eps_1, nu_1) : ( # type: ignore + T.min(eps_1, nu_1) + ((T.max(eps_1, nu_1) + 1) - T.min(eps_1, nu_1)) # type: ignore ), ], ] @@ -124,7 +125,7 @@ def conv2d_winograd_cuda( T.block_attr({"schedule_rule": "meta_schedule.compute_inline"}) T.writes([A[i_1, j_1]]) # fmt: off - A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) + A[i_1, j_1] = T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 5) and (T.floormod(j_1, 4) == 0)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 3)), T.float32(-8), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 2)), T.float32(4), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 1)), T.float32(-2), T.Select(((T.floormod(i_1, 6) == 4) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 3)), T.float32(0.125), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 2)), T.float32(0.25), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 1)), T.float32(0.5), T.Select(((T.floormod(i_1, 6) == 3) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 3)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 1)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 2) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 3)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 2)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 1)), T.float32(-1), T.Select(((T.floormod(i_1, 6) == 1) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 3)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 2)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 1)), T.float32(0), T.Select(((T.floormod(i_1, 6) == 0) and (T.floormod(j_1, 4) == 0)), T.float32(1), T.float32(0))))))))))))))))))))))))) # type: ignore # fmt: on for i0_7, i1_7, i2_5, i3_5, i4_2, i5_1 in T.grid(4, 4, 9, 128, 6, 6): with T.block("inverse"): @@ -137,10 +138,10 @@ def conv2d_winograd_cuda( inverse[vh, vw, p_3, co_1], bgemm[r_a_1, r_b_1, p_3, co_1], A[ - T.min(r_a_1, r_b_1) : ( - T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) + T.min(r_a_1, r_b_1) : ( # type: ignore + T.min(r_a_1, r_b_1) + ((T.max(r_a_1, r_b_1) + 1) - T.min(r_a_1, r_b_1)) # type: ignore ), - T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), + T.min(vh, vw) : (T.min(vh, vw) + ((T.max(vh, vw) + 1) - T.min(vh, vw))), # type: ignore ], ] ) @@ -156,17 +157,17 @@ def conv2d_winograd_cuda( T.reads( [ inverse[ - T.floormod(h, 4), - T.floormod(w, 4), - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + T.floormod(h, 4), # type: ignore + T.floormod(w, 4), # type: ignore + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore co_2, ] ] ) T.writes([conv2d_winograd[n, h, w, co_2]]) conv2d_winograd[n, h, w, co_2] = inverse[ - T.floormod(h, 4), - T.floormod(w, 4), - (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), + T.floormod(h, 4), # type: ignore + T.floormod(w, 4), # type: ignore + (((n * 9) + (T.floordiv(h, 4) * 3)) + T.floordiv(w, 4)), # type: ignore co_2, ]