diff --git a/python/tvm/script/parser.py b/python/tvm/script/parser.py
index fe71b064320f..1d76815dc3b8 100644
--- a/python/tvm/script/parser.py
+++ b/python/tvm/script/parser.py
@@ -554,12 +554,14 @@ def transform_Assign(self, node):
                 4.1 var = T.allocate()
         """
 
+        print("parsing ", node.rhs.func_name)
         if isinstance(node.rhs, ast.Call):
             # Pattern 1 & Pattern 4
             if isinstance(node.rhs.func_name, ast.Op):
                 func = None
             else:
                 func = self.transform(node.rhs.func_name)
+            print(func)
 
             if isinstance(func, WithScopeHandler):
                 if not func.concise_scope or not func.def_symbol:
@@ -577,6 +579,31 @@ def transform_Assign(self, node):
                 arg_list = self.parse_arg_list(func, node.rhs)
                 func.handle(node, self.context, arg_list, node.rhs.func_name.span)
                 return self.parse_body(node)
+            elif callable(func):
+                args = [self.transform(arg) for arg in node.rhs.params]
+                out = func(*args)
+                print(out)
+                print(node.lhs)
+                assert len(out) == len(node.lhs)
+
+                lhs_vars = []
+                for ast_var, value in zip(node.lhs, out):
+                    var = tvm.te.var(
+                        ast_var.id.name,
+                        "int32",
+                        span=tvm_span_from_synr(ast_var.span),
+                    )
+                    self.context.update_symbol(var.name, var, node)
+                    lhs_vars.append(var)
+
+                body = self.parse_body(node)
+
+                for var, value in reversed(list(zip(lhs_vars, out))):
+                    self.context.remove_symbol(var.name)
+                    body = tvm.tir.LetStmt(var, value, body, span=tvm_span_from_synr(node.span))
+
+                return body
+
         if isinstance(node.rhs, (ast.Call, ast.Constant)):
             # Pattern 4 of let binding
             value = self.transform(node.rhs)
@@ -593,6 +620,7 @@ def transform_Assign(self, node):
             if node.ty is None and hasattr(value, "dtype"):
                 var_ty = value.dtype
             else:
+                print(node.ty, ast_var)
                 var_ty = self.parse_type(node.ty, ast_var)
 
             var = tvm.te.var(
diff --git a/tests/python/unittest/test_mma_16x8x16_4k_tune.py b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
index 0023251f3ee6..ddeb931ff9ed 100644
--- a/tests/python/unittest/test_mma_16x8x16_4k_tune.py
+++ b/tests/python/unittest/test_mma_16x8x16_4k_tune.py
@@ -8,6 +8,11 @@
 import numpy as np
 
 
+def shared_16x16_to_ldmatrix_32x8_layout(i, j):
+    thread_id = 4 * (i % 8) + (j % 8) // 2
+    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)
+
+
 @T.prim_func
 def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
     A_shared = T.match_buffer(a, (16, 16), "float16", align=128, offset_factor=16, scope="shared")
@@ -21,11 +26,15 @@ def ldmatrix_a_desc(a: T.handle, c: T.handle) -> None:
         with T.block("A_shared_warp"):
             v0, v1 = T.axis.remap("SS", [ax0, ax1])
             T.reads(A_shared[v0, v1])
-            T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
-            A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
-                v0, v1
-            ]
+            thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)
+            T.writes(A_warp[thread_id, y])
+            A_warp[thread_id, y] = A_shared[v0, v1]
+
+            # T.writes(A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2])
+            # A_warp[v0 % 8 * 4 + v1 % 8 // 2, v1 // 8 * 4 + v0 // 8 * 2 + v1 % 2] = A_shared[
+            #     v0, v1
+            # ]
 
 
 @T.prim_func
 def ldmatrix_a_impl(a: T.handle, c: T.handle) -> None:
@@ -390,22 +399,39 @@ def tile_wmma_fragment(block_read, height):
         sch.reorder(i0, j0, i1, j1)
         return i1
 
-    def shared_16x16_to_ldmatrix_32x8_layout(i, j):
-        i_0 = i // 16
-        j_0 = j // 16
-
-        i = i % 16
-        j = j % 16
-
-        thread_id = 4 * (i % 8) + (j % 8) // 2
-        return i_0, j_0, thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 8) % 2
-
     loop_a = tile_wmma_fragment(A_warp, 16)
     loop_b = tile_wmma_fragment(B_warp, 16)
 
-    sch.transform_layout(A_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(B_warp, 0, "write", index_map=shared_16x16_to_ldmatrix_32x8_layout)
-    sch.transform_layout(C_warp, 0, "read", index_map=shared_16x16_to_ldmatrix_32x8_layout)
+    sch.transform_layout(
+        A_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        B_warp,
+        0,
+        "write",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
+    sch.transform_layout(
+        C_warp,
+        0,
+        "read",
+        index_map=lambda i, j: (
+            i // 16,
+            j // 16,
+            *shared_16x16_to_ldmatrix_32x8_layout(i % 16, j % 16),
+        ),
+    )
 
     sch.tensorize(loop_a, "mma.ldmatrix_a")
     sch.tensorize(loop_b, "mma.ldmatrix_b")
@@ -438,44 +464,44 @@ def shared_16x16_to_ldmatrix_32x8_layout(i, j):
 schedule(sch)
 print(sch.mod.script())
 
-if tune:
-    with tempfile.TemporaryDirectory() as work_dir:
-        sch = ms.tune_tir(
-            mod=workload,
-            target=tvm.target.Target("nvidia/geforce-rtx-3070"),
-            config=ms.TuneConfig(
-                strategy="evolutionary",
-                num_trials_per_iter=32,
-                max_trials_per_task=128,
-                max_trials_global=128,
-            ),
-            work_dir=work_dir,
-            space=ms.space_generator.ScheduleFn(schedule),
-        )
-        if sch is None:
-            print("No valid schedule found!")
-        else:
-            print(sch.mod.script())
-            print(sch.trace)
-else:
-    target = "cuda"
-    f = tvm.build(sch.mod["main"], target=target, name="dense")
-
-dev = tvm.device("cuda", 0)
-a_np = np.random.uniform(size=(N, K)).astype("float16")
-b_np = np.random.uniform(size=(K, M)).astype("float16")
-c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
-a = tvm.nd.array(a_np, dev)
-b = tvm.nd.array(b_np, dev)
-c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
-f = tvm.build(sch.mod["main"], target="cuda", name="dense")
-
-print(f.imported_modules[0].get_source())
-f(a, b, c)
-tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
-print("ok")
-
-evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
-gflops = (N * M * K) * 2 / 1e9
-time_ms = evaluator(a, b, c).mean * 1e3
-print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
+# if tune:
+#     with tempfile.TemporaryDirectory() as work_dir:
+#         sch = ms.tune_tir(
+#             mod=workload,
+#             target=tvm.target.Target("nvidia/geforce-rtx-3070"),
+#             config=ms.TuneConfig(
+#                 strategy="evolutionary",
+#                 num_trials_per_iter=32,
+#                 max_trials_per_task=128,
+#                 max_trials_global=128,
+#             ),
+#             work_dir=work_dir,
+#             space=ms.space_generator.ScheduleFn(schedule),
+#         )
+#         if sch is None:
+#             print("No valid schedule found!")
+#         else:
+#             print(sch.mod.script())
+#             print(sch.trace)
+# else:
+#     target = "cuda"
+#     f = tvm.build(sch.mod["main"], target=target, name="dense")
+
+# dev = tvm.device("cuda", 0)
+# a_np = np.random.uniform(size=(N, K)).astype("float16")
+# b_np = np.random.uniform(size=(K, M)).astype("float16")
+# c_np = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+# a = tvm.nd.array(a_np, dev)
+# b = tvm.nd.array(b_np, dev)
+# c = tvm.nd.array(np.zeros((M, N), dtype="float32"), dev)
+# f = tvm.build(sch.mod["main"], target="cuda", name="dense")
+
+# print(f.imported_modules[0].get_source())
+# f(a, b, c)
+# tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-3)
+# print("ok")
+
+# evaluator = f.time_evaluator(f.entry_name, dev, number=1000)
+# gflops = (N * M * K) * 2 / 1e9
+# time_ms = evaluator(a, b, c).mean * 1e3
+# print("matmul with tensor core: %f ms, %f GFLOPS" % (time_ms, gflops / (time_ms / 1e3)))
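
A quick standalone check of the index map added at the top of test_mma_16x8x16_4k_tune.py (plain Python, not part of the patch): it should send the 16x16 shared-memory tile onto the 32-thread x 8-registers-per-thread ldmatrix layout with no collisions.

# Sanity-check sketch: shared_16x16_to_ldmatrix_32x8_layout should be a
# bijection from 256 tile elements onto 32 threads x 8 local registers.
def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

def check_layout_is_bijective():
    seen = set()
    for i in range(16):
        for j in range(16):
            thread_id, local_id = shared_16x16_to_ldmatrix_32x8_layout(i, j)
            assert 0 <= thread_id < 32 and 0 <= local_id < 8
            seen.add((thread_id, local_id))
    # 256 distinct (thread, register) slots: every element lands in a unique slot.
    assert len(seen) == 16 * 16

check_layout_is_bijective()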
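
The parser.py hunk lets a TVMScript assignment call a plain Python function, as in the test's `thread_id, y = shared_16x16_to_ldmatrix_32x8_layout(v0, v1)`: each returned expression is bound with a tvm.tir.LetStmt wrapped around the parsed body, folded in reverse so the first target is the outermost binding. The sketch below mirrors that fold outside the parser; the helper name bind_outputs_with_let and the evaluate(0) placeholder body are illustrative, not part of the patch.

# Illustrative sketch of the `elif callable(func)` lowering, assuming int32 targets
# as in the patch. Not the parser itself; it only reproduces the reversed() fold.
import tvm
from tvm import tir

def shared_16x16_to_ldmatrix_32x8_layout(i, j):
    thread_id = 4 * (i % 8) + (j % 8) // 2
    return thread_id, 4 * (j // 8) + (i // 8) * 2 + (j % 2)

def bind_outputs_with_let(names, values, body):
    # One fresh int32 var per target, then nested LetStmt bindings around `body`.
    lhs_vars = [tvm.te.var(name, "int32") for name in names]
    for var, value in reversed(list(zip(lhs_vars, values))):
        body = tir.LetStmt(var, value, body)
    return body

i = tvm.te.var("i", "int32")
j = tvm.te.var("j", "int32")
out = shared_16x16_to_ldmatrix_32x8_layout(i, j)  # two int32 PrimExprs
stmt = bind_outputs_with_let(["tx", "ty"], out, tir.Evaluate(tir.IntImm("int32", 0)))
print(stmt)  # nested let bindings for tx and ty around the placeholder body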