From 261595e3a6390cba6b59bd474b8644a9b9e18af2 Mon Sep 17 00:00:00 2001 From: Siyuan Feng Date: Sat, 1 May 2021 23:02:30 +0800 Subject: [PATCH] [TensorIR][Pass][M1c] FlattenBuffer Co-authored-by: Tianqi Chen Co-authored-by: Ruihang Lai --- include/tvm/tir/transform.h | 8 + python/tvm/script/intrin.py | 2 +- python/tvm/script/scope_handler.py | 8 +- python/tvm/tir/transform/transform.py | 13 + src/printer/tvmscript_printer.cc | 3 +- src/tir/transforms/flatten_buffer.cc | 187 ++++++++++++++ .../test_tir_transform_flatten_buffer.py | 243 ++++++++++++++++++ .../unittest/test_tvmscript_roundtrip.py | 26 ++ 8 files changed, 485 insertions(+), 5 deletions(-) create mode 100644 src/tir/transforms/flatten_buffer.cc create mode 100644 tests/python/unittest/test_tir_transform_flatten_buffer.py diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index a236c5075d4bc..2de255da3fa24 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -406,6 +406,14 @@ TVM_DLL Pass ConvertBlocksToOpaque(); */ TVM_DLL Pass CompactBufferAllocation(); +/*! + * \brief Flatten the multi-dimensional BufferLoad and BufferStore + * to single dimensional Load/Store. Also remove Block to + * ensure that the flattened TIR can not be scheduled again. + * \return The pass. 
+ */ +TVM_DLL Pass FlattenBuffer(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/script/intrin.py b/python/tvm/script/intrin.py index 48f50a2da442e..2ca11ff30b4b4 100644 --- a/python/tvm/script/intrin.py +++ b/python/tvm/script/intrin.py @@ -121,7 +121,7 @@ def floormod(x, y, span): @register -def load(dtype, var, index, predicate=True, span=None): +def load(dtype, var, index, predicate=None, span=None): return tvm.tir.Load(dtype, var, index, predicate, span) diff --git a/python/tvm/script/scope_handler.py b/python/tvm/script/scope_handler.py index fec9ec0cc11c0..11eecc9831a4f 100644 --- a/python/tvm/script/scope_handler.py +++ b/python/tvm/script/scope_handler.py @@ -154,15 +154,17 @@ class LaunchThread(WithScopeHandler): def __init__(self): def launch_thread(env_var, extent, span): extent = tvm.runtime.convert(extent, span=span) + thread_id = self.context.func_var_env_dict[env_var] + attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent" return tvm.tir.AttrStmt( IterVar( - None, + (0, extent), env_var, getattr(IterVar, "ThreadIndex"), - self.context.func_var_env_dict[env_var], + thread_id, span=span, ), - "thread_extent", + attr_key, extent, self.body, span=span, diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 2ae75d2d0a637..be55b48da71e4 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -610,3 +610,16 @@ def CompactBufferAllocation(): The result pass """ return _ffi_api.CompactBufferAllocation() + + +def FlattenBuffer(): + """Flatten the multi-dimensional BufferLoad and BufferStore + to single dimensional Load/Store. Also remove Block to + ensure that the flattened TIR can not be scheduled again. 
+ + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.FlattenBuffer() diff --git a/src/printer/tvmscript_printer.cc b/src/printer/tvmscript_printer.cc index f6586ce7000a5..f2ab41294d80b 100644 --- a/src/printer/tvmscript_printer.cc +++ b/src/printer/tvmscript_printer.cc @@ -620,7 +620,8 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) { } } // concise thread env - if (op->node->IsInstance() && op->attr_key == "thread_extent") { + if (op->node->IsInstance() && + (op->attr_key == "thread_extent" || op->attr_key == "virtual_thread")) { const auto* iter_var = Downcast(op->node).get(); ICHECK(!iter_var->dom.defined()); var_not_in_headers.insert(iter_var->var.get()); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc new file mode 100644 index 0000000000000..05d46005bc23d --- /dev/null +++ b/src/tir/transforms/flatten_buffer.cc @@ -0,0 +1,187 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! 
+ * \file flatten_buffer.cc + */ + +#include +#include +#include +#include +#include + +#include "../../support/utils.h" + +namespace tvm { +namespace tir { + +PrimExpr BufferArea(const Buffer& buffer) { + PrimExpr area = Integer(1); + for (const PrimExpr& dim : buffer->shape) { + area = area * dim; + } + return area; +} + +/*! + * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store + */ +class BufferFlattener : public StmtExprMutator { + public: + static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); } + + private: + Stmt VisitStmt_(const BlockRealizeNode* op) final { + // We have converted blocks into opaque blocks in previous passes. + ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please " + "call pass ConvertBlocksToOpaque before."; + // Step 1. Visit the body + Block new_block = Downcast(this->VisitStmt(op->block)); + PrimExpr predicate = this->VisitExpr(op->predicate); + // Step 2. Transform the `predicate` to if-then-else + Stmt body = new_block->body; + if (!is_one(predicate)) { + body = IfThenElse(predicate, std::move(body)); + } + // Step 3. Handle allocations in reverse order + for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) { + const Buffer& buffer = new_block->alloc_buffers[i - 1]; + body = MakeAllocStmt(buffer, std::move(body)); + } + return body; + } + + Stmt VisitStmt_(const ForNode* op) final { + // Step 1. Update unit loop info. + PrimExpr min = this->VisitExpr(op->min); + PrimExpr extent = this->VisitExpr(op->extent); + if (is_one(extent) && op->annotations.empty()) { + // handling unit loop + unit_loop_vars_[op->loop_var] = min; + } + // Step 2. Visit recursively + Stmt body = this->VisitStmt(op->body); + // Step 3. Create new For loop accordingly + if (op->kind == ForKind::kThreadBinding) { + // Case 1.
Thread binding + ICHECK(op->thread_binding.defined()); + String thread_tag = op->thread_binding.value()->thread_tag; + body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body); + } else if (is_one(extent) && op->annotations.empty()) { + // Case 2. Unit loop + return body; + } else { + // Case 3. An ordinary loop + body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body)); + } + // Step 4. Handle annotations + for (const auto& annotation : op->annotations) { + const String& ann_key = annotation.first; + const ObjectRef& ann_value = annotation.second; + if (attr::IsPragmaKey(ann_key)) { + body = AttrStmt(op->loop_var, ann_key, Downcast(ann_value), std::move(body)); + } + } + return body; + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(op)); + return store->buffer.vstore(store->indices, store->value); + } + + PrimExpr VisitExpr_(const VarNode* op) final { + Var var = GetRef(op); + auto it = unit_loop_vars_.find(var); + if (it == unit_loop_vars_.end()) { + return std::move(var); + } else { + return it->second; + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(op)); + return load->buffer.vload(load->indices, load->dtype); + } + + // This part will not be upstreamed to mainline.
+ PrimExpr VisitExpr_(const CallNode* op) final { + if (op->op.same_as(builtin::get_elem_offset())) { + // Handle `get_elem_offset` + ICHECK_EQ(op->args.size(), 1); + PrimExpr arg = op->args[0]; + ICHECK(arg->IsInstance()); + arg = this->VisitExpr(arg); + const auto* load = arg.as(); + ICHECK(load != nullptr); + return load->index; + } + return StmtExprMutator::VisitExpr_(op); + } + + static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) { + String storage_scope = buffer->scope; + if (storage_scope.empty()) { + storage_scope = "global"; + } + PrimExpr area = BufferArea(buffer); + body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body)); + body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body)); + return body; + } + + static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag, + Stmt body) { + IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent), + /*var=*/std::move(var), + /*iter_type=*/IterVarType::kThreadIndex, + /*thread_tag=*/thread_tag); + String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent; + return AttrStmt(/*node=*/std::move(iter_var), + /*attr_key=*/std::move(attr_key), + /*value=*/std::move(extent), + /*body=*/std::move(body)); + } + + /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. 
*/ + std::unordered_map unit_loop_vars_; +}; + +PrimFunc FlattenBuffer(PrimFunc f) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = BufferFlattener::Flatten(f); + return f; +} + +namespace transform { + +Pass FlattenBuffer() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return FlattenBuffer(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py new file mode 100644 index 0000000000000..f618e964f4e4d --- /dev/null +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -0,0 +1,243 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.FlattenBuffer()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed, True) + + +@tvm.script.tir +def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer([1, 16], "float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[0, j]) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + +@tvm.script.tir +def flattened_elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in tir.serial(0, 16): + B_new = tir.allocate([16], "float32", "global") + for j in tir.serial(0, 16): + B_new[j] = tir.load("float32", A.data, ((i * 16) + j)) + tir.float32(1) + for j in tir.serial(0, 16): + C.data[((i * 16) + j)] = tir.load("float32", B_new, j) * tir.float32(2) + + +@tvm.script.tir +def compacted_gpu_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i0 in tir.thread_binding(0, 4, thread="blockIdx.x"): + for i1 in tir.thread_binding(0, 2, thread="threadIdx.x"): + for i2 in tir.thread_binding(0, 2, thread="vthread"): + with tir.block([]): + tir.reads(A[i0 * 4 + i1 * 2 + i2, 0:16]) + tir.writes(C[i0 * 4 + i1 * 2 + i2, 0:16]) + B = tir.alloc_buffer([1, 16], "float32", scope="local") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i0 * 4 + i1 * 2 + i2, j]) + tir.writes(B[0, j]) + B[0, j] = A[i0 * 4 + 
i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i0 * 4 + i1 * 2 + i2, j]) + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + +@tvm.script.tir +def flattened_gpu_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + + i0 = tir.env_thread("blockIdx.x") + i1 = tir.env_thread("threadIdx.x") + i2 = tir.env_thread("vthread") + + tir.launch_thread(i0, 4) + tir.launch_thread(i1, 2) + tir.launch_thread(i2, 2) + B = tir.allocate([16], "float32", "local") + for j in range(0, 16): + B[j] = tir.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + tir.float32(1) + for j in range(0, 16): + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = tir.load("float32", B, j) * tir.float32(2) + + +@tvm.script.tir +def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n, m), "float32") + C = tir.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + with tir.block([]): + tir.reads(A[i, m]) + tir.writes(C[i, m]) + B = tir.alloc_buffer((m,), "float32") + for j in range(0, m): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[j]) + B[j] = A[i, j] + 1.0 + for j in range(0, m): + with tir.block([]) as []: + tir.reads(B[j]) + tir.writes(C[i, j]) + C[i, j] = B[j] * 2.0 + + +@tvm.script.tir +def flattened_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32, m: ty.int32) -> None: + A = tir.match_buffer(a, (n, m), "float32") + C = tir.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + B = tir.allocate([m], "float32", "global") + for j in range(0, m): + B[j] = tir.load("float32", A.data, i * m + j) + tir.float32(1) + for j in range(0, m): + C.data[i * m + j] = tir.load("float32", B, j) * tir.float32(2) + + +@tvm.script.tir +def compacted_predicate_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + C = tir.match_buffer(c, (32), 
"float32") + + for i, j in tir.grid(5, 7): + with tir.block([]) as []: + tir.reads(A[i * 7 + j]) + tir.writes(C[i * 7 + j]) + tir.where(i * 7 + j < 32) + C[i * 7 + j] = A[i * 7 + j] + 1.0 + + +@tvm.script.tir +def flattened_predicate_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + C = tir.match_buffer(c, (32), "float32") + + for i, j in tir.grid(5, 7): + if i * 7 + j < 32: + C.data[i * 7 + j] = tir.load("float32", A.data, i * 7 + j) + tir.float32(1) + + +@tvm.script.tir +def compacted_unit_loop_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + C = tir.match_buffer(c, (32), "float32") + + for x, y, z in tir.grid(4, 1, 8): + with tir.block([]) as []: + tir.reads(A[x * 8 + y * 8 + z]) + tir.writes(C[x * 8 + y * 8 + z]) + C[x * 8 + y * 8 + z] = A[x * 8 + y * 8 + z] + 1.0 + + +@tvm.script.tir +def flattened_unit_loop_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + C = tir.match_buffer(c, (32), "float32") + + for x, z in tir.grid(4, 8): + C.data[x * 8 + z] = tir.load("float32", A.data, x * 8 + z) + tir.float32(1) + + +@tvm.script.tir +def compacted_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + D = tir.match_buffer(d, (32), "float32") + + for i in range(0, 32): + with tir.block([]) as []: + tir.reads(A[i]) + tir.writes(D[i]) + B = tir.alloc_buffer((32,)) + C = tir.alloc_buffer((32,)) + B[i] = A[i] + 1.0 + C[i] = A[i] + B[i] + D[i] = C[i] * 2.0 + + +@tvm.script.tir +def flattened_multi_alloc_func(a: ty.handle, d: ty.handle) -> None: + A = tir.match_buffer(a, (32), "float32") + D = tir.match_buffer(d, (32), "float32") + + for i in range(0, 32): + B = tir.allocate((32,), "float32", "global") + C = tir.allocate((32,), "float32", "global") + B[i] = tir.load("float32", A.data, i) + tir.float32(1) + C[i] = tir.load("float32", A.data, i) + tir.load("float32", B, i) + D.data[i] = tir.load("float32", C, i) * 
tir.float32(2) + + +def test_elementwise(): + _check(compacted_elementwise_func, flattened_elementwise_func) + + +def test_gpu_workload(): + _check(compacted_gpu_func, flattened_gpu_func) + + +def test_symbolic_shape(): + _check(compacted_symbolic_func, flattened_symbolic_func) + + +def test_predicate(): + _check(compacted_predicate_func, flattened_predicate_func) + + +def test_unit_loops(): + _check(compacted_unit_loop_func, flattened_unit_loop_func) + + +def test_multi_alloc(): + _check(compacted_multi_alloc_func, flattened_multi_alloc_func) + + +if __name__ == "__main__": + test_elementwise() + test_gpu_workload() + test_symbolic_shape() + test_predicate() + test_unit_loops() + test_multi_alloc() diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index cbdcbbb2e6f0a..e84902f0540e1 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -2662,6 +2662,31 @@ def test_opt_conv_tensorcore_mod_host(): tvm.ir.assert_structural_equal(mod, rt_mod, True) +@tvm.script.tir +def vthread_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + + i0 = tir.env_thread("blockIdx.x") + i1 = tir.env_thread("threadIdx.x") + i2 = tir.env_thread("vthread") + + tir.launch_thread(i0, 4) + tir.launch_thread(i1, 2) + tir.launch_thread(i2, 2) + B = tir.allocate([16], "float32", "local") + for j in range(0, 16): + B[j] = tir.load("float32", A.data, i0 * 64 + i1 * 32 + i2 * 16 + j) + tir.float32(1) + for j in range(0, 16): + C.data[i0 * 64 + i1 * 32 + i2 * 16 + j] = tir.load("float32", B, j) * tir.float32(2) + + +def test_vthread(): + func = vthread_func + rt_func = tvm.script.from_source(tvm.script.asscript(func, True)) + tvm.ir.assert_structural_equal(func, rt_func, True) + + @tvm.script.tir def matmul(a: ty.handle, b: ty.handle, c: ty.handle) -> None: A = tir.match_buffer(a, [128, 
128]) @@ -2870,6 +2895,7 @@ def test_opaque_block(): test_opt_conv_tensorcore_normalize() test_opt_conv_tensorcore_lower() test_opt_conv_tensorcore_mod_host() + test_vthread() test_module_define() test_matmul() test_matmul_original()