diff --git a/tests/python/unittest/test_tir_schedule_reduction.py b/tests/python/unittest/test_tir_schedule_reduction.py new file mode 100644 index 0000000000..9ea6e327f2 --- /dev/null +++ b/tests/python/unittest/test_tir_schedule_reduction.py @@ -0,0 +1,222 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
# pylint: disable=missing-function-docstring,missing-module-docstring
"""Tests for the TensorIR ``rfactor`` schedule primitive.

Each test applies ``rfactor`` to one loop of a reduction block, checks the
scheduled module against a hand-written expected PrimFunc, and then builds
the result for LLVM and compares its output with a NumPy reference.
"""
import numpy as np
import tvm
import tvm.testing
from tvm import tir
from tvm.script import ty

# pylint: disable=no-member,invalid-name,unused-variable


# Input: C[vi, vj] = sum_k A[vi, vk] * B[vj, vk], where the reduction index vk
# is rebuilt from three loops of extents 4, 8 and 4 (4 * 8 * 4 == 128).
@tvm.script.tir
def transformed_matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4):
        with tir.block([128, 128, tir.reduce_axis(0, 128)], "update") as [vi, vj, vk]:
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vk, (((i2_outer*32) + (i2_inner_outer*4)) + i2_inner_inner))
            tir.reads([C[vi, vj], A[vi, vk], B[vj, vk]])
            tir.writes([C[vi, vj]])
            with tir.init():
                C[vi, vj] = 0.0
            C[vi, vj] = (C[vi, vj] + (A[vi, vk]*B[vj, vk]))


# Expected result of rfactor on the innermost loop of `transformed_matmul`
# with factor axis 0: block "update_rf" accumulates partial sums into
# C_rf[4, 128, 128], and block "update" reduces those 4 partial sums into C.
@tvm.script.tir
def matmul_rfactor(a: ty.handle, b: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [128, 128])
    B = tir.match_buffer(b, [128, 128])
    C = tir.match_buffer(c, [128, 128])
    C_rf = tir.alloc_buffer([4, 128, 128])

    for i0, i1, i2_outer, i2_inner_outer, i2_inner_inner in tir.grid(128, 128, 4, 8, 4):
        with tir.block([4, 128, 128, tir.reduce_axis(0, 4), tir.reduce_axis(0, 8)], "update_rf") as [vi2_inner_inner, vi, vj, vi2_outer, vi2_inner_outer]:
            tir.bind(vi2_inner_inner, i2_inner_inner)
            tir.bind(vi, i0)
            tir.bind(vj, i1)
            tir.bind(vi2_outer, i2_outer)
            tir.bind(vi2_inner_outer, i2_inner_outer)
            with tir.init():
                C_rf[vi2_inner_inner, vi, vj] = 0.0
            C_rf[vi2_inner_inner, vi, vj] = (C_rf[vi2_inner_inner, vi, vj] + (A[vi, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)]*B[vj, (((vi2_outer*32) + (vi2_inner_outer*4)) + vi2_inner_inner)]))

    for i0_1, i1_1, i2_inner_inner_1 in tir.grid(128, 128, 4):
        with tir.block([128, 128, tir.reduce_axis(0, 4)], "update") as [vi_1, vj_1, vi2_inner_inner_1]:
            tir.bind(vi_1, i0_1)
            tir.bind(vj_1, i1_1)
            tir.bind(vi2_inner_inner_1, i2_inner_inner_1)
            with tir.init():
                C[vi_1, vj_1] = 0.0
            C[vi_1, vj_1] = (C[vi_1, vj_1] + C_rf[vi2_inner_inner_1, vi_1, vj_1])


# Input: C[b] = sum over (i, j) of A[b, i, j] ** 2, written as a single block
# with no explicit loops.
@tvm.script.tir
def square_sum(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])

    with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
        with tir.init():
            C[b] = 0.0
        C[b] = C[b] + A[b, i, j] * A[b, i, j]


# Expected result of rfactor on the last loop of `square_sum` with factor
# axis 1: block "C_rf" writes partial sums to C_rf[16, 256], and block "C"
# reduces over the 256 partial sums.
@tvm.script.tir
def square_sum_rfactor(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 256, 256])
    C = tir.match_buffer(c, [16])
    C_rf = tir.alloc_buffer([16, 256])

    for i0, i1, i2 in tir.grid(16, 256, 256):
        with tir.block([256, 16, tir.reduce_axis(0, 256)], "C_rf") as [vi2, b, i]:
            tir.bind(vi2, i2)
            tir.bind(b, i0)
            tir.bind(i, i1)
            with tir.init():
                C_rf[b, vi2] = 0.0
            C_rf[b, vi2] = (C_rf[b, vi2] + (A[b, i, vi2]*A[b, i, vi2]))

    for i0_1, i2_1 in tir.grid(16, 256):
        with tir.block([16, tir.reduce_axis(0, 256)], "C") as [b_1, vi2_1]:
            tir.bind(b_1, i0_1)
            tir.bind(vi2_1, i2_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = (C[b_1] + C_rf[b_1, vi2_1])


# Input: square-sum followed by an element-wise sqrt. Both reduction indices
# are recovered by floordiv/floormod from one loop of extent 65536; the
# innermost loop has extent 1.
@tvm.script.tir
def transformed_square_sum_square_root(a: ty.handle, d: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block([16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C") as [b, i, j]:
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            tir.reads([C[b], A[b, i, j]])
            tir.writes([C[b]])
            with tir.init():
                C[b] = 0.0
            C[b] = (C[b] + (A[b, i, j]*A[b, i, j]))
    for i0_1 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_1]:
            tir.bind(b_1, i0_1)
            tir.reads([C[b_1]])
            tir.writes([D[b_1]])
            D[b_1] = tir.sqrt(C[b_1], dtype="float32")


# Expected result of rfactor on the extent-1 innermost loop of
# `transformed_square_sum_square_root` with factor axis 0: partial sums go to
# C_rf[1, 16], block "C" folds that single slice into C, block "D" takes sqrt.
@tvm.script.tir
def square_sum_square_root_rfactor(a: ty.handle, d: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 256, 256])
    D = tir.match_buffer(d, [16])
    C = tir.alloc_buffer([16])
    C_rf = tir.alloc_buffer([1, 16])

    for i0, i1_i2_fused_outer, i1_i2_fused_inner in tir.grid(16, 65536, 1):
        with tir.block([1, 16, tir.reduce_axis(0, 256), tir.reduce_axis(0, 256)], "C_rf") as [vi1_i2_fused_inner, b, i, j]:
            tir.bind(vi1_i2_fused_inner, i1_i2_fused_inner)
            tir.bind(b, i0)
            tir.bind(i, tir.floordiv(i1_i2_fused_outer, 256))
            tir.bind(j, tir.floormod(i1_i2_fused_outer, 256))
            with tir.init():
                C_rf[vi1_i2_fused_inner, b] = 0.0
            C_rf[vi1_i2_fused_inner, b] = (C_rf[vi1_i2_fused_inner, b] + (A[b, i, j]*A[b, i, j]))

    for i0_1, i1_i2_fused_inner_1 in tir.grid(16, 1):
        with tir.block([16, tir.reduce_axis(0, 1)], "C") as [b_1, vi1_i2_fused_inner_1]:
            tir.bind(b_1, i0_1)
            tir.bind(vi1_i2_fused_inner_1, i1_i2_fused_inner_1)
            with tir.init():
                C[b_1] = 0.0
            C[b_1] = (C[b_1] + C_rf[vi1_i2_fused_inner_1, b_1])

    for i0_2 in tir.serial(0, 16):
        with tir.block([16], "D") as [b_2]:
            tir.bind(b_2, i0_2)
            D[b_2] = tir.sqrt(C[b_2], dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable


def _build_and_verify(sch: tir.Schedule, inputs: list, expected: np.ndarray) -> None:
    """Build the scheduled function for LLVM, run it, and compare with NumPy.

    Parameters
    ----------
    sch : tir.Schedule
        Schedule whose ``mod["main"]`` is compiled.
    inputs : list of numpy.ndarray
        Input arrays, in the order of the PrimFunc's leading parameters.
    expected : numpy.ndarray
        Reference result. A zero-filled float32 array with this shape is
        passed as the last argument and receives the output.
    """
    func = tvm.build(sch.mod["main"], target="llvm")
    device_inputs = [tvm.nd.array(array) for array in inputs]
    output = tvm.nd.array(np.zeros(expected.shape, dtype="float32"))
    func(*device_inputs, output)
    tvm.testing.assert_allclose(output.asnumpy(), expected, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_matmul():
    """rfactor the innermost reduction loop of the split matmul at axis 0."""
    s = tir.Schedule(transformed_matmul, debug_mode=True)
    C = s.get_block("update")
    _, _, _, _, kii = s.get_loops(C)
    rf_block = s.rfactor(kii, 0)
    tvm.ir.assert_structural_equal(s.mod["main"], matmul_rfactor)
    # The handle returned by rfactor must refer to the new "update_rf" block.
    assert s.get(rf_block).same_as(s.get(s.get_block("update_rf")))

    a_np = np.random.uniform(size=(128, 128)).astype("float32")
    b_np = np.random.uniform(size=(128, 128)).astype("float32")
    # The block reads B[vj, vk], i.e. it multiplies by the transpose of B.
    _build_and_verify(s, [a_np, b_np], np.matmul(a_np, b_np.T))


def test_reduction_rfactor_square_sum():
    """rfactor the last reduction loop of the square-sum at axis 1."""
    s = tir.Schedule(square_sum, debug_mode=True)
    C = s.get_block("C")
    _, _, j = s.get_loops(C)
    rf_block = s.rfactor(j, 1)
    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_rfactor)
    # The handle returned by rfactor must refer to the new "C_rf" block.
    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))

    a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
    _build_and_verify(s, [a_np], np.sum(a_np * a_np, axis=(1, 2)))


def test_reduction_rfactor_square_sum_square_root():
    """rfactor the extent-1 inner loop of a reduction that has a consumer."""
    s = tir.Schedule(transformed_square_sum_square_root, debug_mode=True)
    C = s.get_block("C")
    _, _, fi = s.get_loops(C)
    rf_block = s.rfactor(fi, 0)
    tvm.ir.assert_structural_equal(s.mod["main"], square_sum_square_root_rfactor)
    # The handle returned by rfactor must refer to the new "C_rf" block.
    assert s.get(rf_block).same_as(s.get(s.get_block("C_rf")))

    a_np = np.random.uniform(size=(16, 256, 256)).astype("float32")
    _build_and_verify(s, [a_np], np.sqrt(np.sum(a_np * a_np, axis=(1, 2))))


if __name__ == "__main__":
    test_reduction_rfactor_matmul()
    test_reduction_rfactor_square_sum()
    test_reduction_rfactor_square_sum_square_root()