From 54a5539c5311316d0d7b4f1cd2b5210bd5b5cf2b Mon Sep 17 00:00:00 2001
From: masahi
Date: Sat, 31 Jul 2021 23:20:24 +0900
Subject: [PATCH] [CUDA] Support multiple TIR-level dynamic shared memory allocations (#8571)

---
 include/tvm/tir/transform.h                   |   5 +
 python/tvm/driver/build_module.py             |   5 +-
 python/tvm/tir/transform/transform.py         |  12 +
 src/driver/driver_api.cc                      |   1 +
 ...merge_dynamic_shared_memory_allocations.cc | 149 ++++++++++
 ...merge_dynamic_shared_memory_allocations.py | 259 ++++++++++++++++++
 6 files changed, 430 insertions(+), 1 deletion(-)
 create mode 100644 src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
 create mode 100644 tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index ced060b8cc86..d1308fe0059e 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -437,6 +437,11 @@ TVM_DLL Pass LowerMatchBuffer();
  */
 TVM_DLL Pass FlattenBuffer();
 
+/*!
+ * \brief A pass to merge multiple TIR-level dynamic shared memory allocations into one
+ */
+TVM_DLL Pass MergeDynamicSharedMemoryAllocations();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 983a40ab5b3f..a7ebc00c315f 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -161,7 +161,10 @@ def _build_for_device(input_mod, target, target_host):
     mod_mixed = input_mod
     mod_mixed = tvm.tir.transform.Apply(lambda f: f.with_attr("target", target))(mod_mixed)
 
-    opt_mixed = [tvm.tir.transform.VerifyMemory()]
+    opt_mixed = [
+        tvm.tir.transform.VerifyMemory(),
+        tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
+    ]
     if len(mod_mixed.functions) == 1:
         opt_mixed += [tvm.tir.transform.Apply(lambda f: f.with_attr("tir.is_entry_func", True))]
 
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 3e47eb5a4254..537499a27fa9 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -666,3 +666,15 @@ def FlattenBuffer():
         The result pass
     """
     return _ffi_api.FlattenBuffer()  # type: ignore
+
+
+def MergeDynamicSharedMemoryAllocations():
+    """This pass merges multiple TIR-level dynamic shared memory allocations
+    into one allocation.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
+    return _ffi_api.MergeDynamicSharedMemoryAllocations()  # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index 9795bf3bc704..2008fe5e47b8 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -378,6 +378,7 @@ std::pair<IRModule, IRModule> SplitDevHostFuncs(IRModule mod_mixed, const Target
   Array<tvm::transform::Pass> mixed_pass_list = {BindTarget(target),
                                                  tir::transform::VerifyMemory()};
 
+  mixed_pass_list.push_back(tir::transform::MergeDynamicSharedMemoryAllocations());
   if (pass_ctx->GetConfig<Bool>("tir.detect_global_barrier", Bool(false)).value()) {
     mixed_pass_list.push_back(tir::transform::ThreadSync("global"));
   }
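For intuition about the merging scheme implemented in the new pass below: every buffer's size is padded to the widest element type among the dynamic shared allocations, and each buffer is assigned a byte offset into a single uint8 buffer. A small worked sketch in Python, with illustrative numbers chosen by the editor rather than produced by the pass:

# Worked example of the layout DynamicSharedMemoryRewriter computes, assuming
# two "shared.dyn" buffers: A_sh (float16, 256 elements) and B_sh (float32,
# 256 elements). Numbers are illustrative.
align = max(2, 4)                        # widest dtype in bytes: float32 -> 4
a_off_bytes = 0                          # A_sh starts at byte 0
b_off_bytes = 256 * align                # B_sh starts at byte 1024
merged_size = b_off_bytes + 256 * align  # one uint8 allocation of 2048 bytes

# Loads/stores are rewritten to element offsets into the merged buffer,
# i.e. the byte offset divided by the access dtype's size (indexdiv below):
a_elem_off = a_off_bytes // 2            # 0 float16 elements
b_elem_off = b_off_bytes // 4            # 256 float32 elements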
diff --git a/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
new file mode 100644
index 000000000000..e8865b260dc1
--- /dev/null
+++ b/src/tir/transforms/merge_dynamic_shared_memory_allocations.cc
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file merge_dynamic_shared_memory_allocations.cc
+ * \brief Each GPU kernel is allowed to have only one dynamic shared memory allocation.
+ * This pass merges multiple TIR-level dynamic shared memory allocations into one allocation.
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include <unordered_map>
+#include <unordered_set>
+
+#include "../../runtime/thread_storage_scope.h"
+#include "ir_utils.h"
+
+namespace tvm {
+namespace tir {
+
+bool IsDynamicSharedMemory(Var buffer_var) {
+  auto storage_scope = runtime::StorageScope::Create(GetPtrStorageScope(buffer_var));
+  return storage_scope.rank == runtime::StorageRank::kShared && storage_scope.tag == ".dyn";
+}
+
+class AllocateCollector : public StmtExprVisitor {
+ public:
+  void VisitStmt_(const AllocateNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      dyn_shmem_allocs_.insert(op);
+    }
+    StmtExprVisitor::VisitStmt_(op);
+  }
+
+  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+};
+
+class DynamicSharedMemoryRewriter : public StmtExprMutator {
+ public:
+  explicit DynamicSharedMemoryRewriter(
+      const std::unordered_set<const AllocateNode*>& dyn_shmem_allocs)
+      : dyn_shmem_allocs_{dyn_shmem_allocs} {}
+
+  Stmt VisitStmt_(const AttrStmtNode* op) final {
+    if (op->attr_key == attr::thread_extent && !allocated) {
+      // Allocate one merged dynamic shared memory buffer at the beginning of the thread scope
+      int align = 1;
+      for (const auto& alloc : dyn_shmem_allocs_) {
+        ICHECK_EQ(alloc->dtype.lanes(), 1) << "vector dtype allocation not supported.";
+        align = std::max(align, alloc->dtype.bytes());
+      }
+      for (const auto& alloc : dyn_shmem_allocs_) {
+        ICHECK_EQ(alloc->extents.size(), 1);
+        buffer_byte_offsets_[alloc->buffer_var.get()] = merged_alloc_size_;
+        merged_alloc_size_ += alloc->extents[0] * align;
+      }
+
+      allocated = true;
+      auto new_body = Allocate(merged_buf_var_, DataType::UInt(8), {merged_alloc_size_},
+                               const_true(), StmtExprMutator::VisitStmt(op->body));
+      return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
+    }
+    return StmtMutator::VisitStmt_(op);
+  }
+
+  Stmt VisitStmt_(const AllocateNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      return StmtExprMutator::VisitStmt(op->body);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+  PrimExpr VisitExpr_(const LoadNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      auto offset = GetBufferOffset(op->buffer_var, op->dtype);
+      auto index = StmtExprMutator::VisitExpr(op->index);
+      return Load(op->dtype, merged_buf_var_, offset + index, op->predicate, op->span);
+    }
+    return StmtExprMutator::VisitExpr_(op);
+  }
+
+  Stmt VisitStmt_(const StoreNode* op) final {
+    if (IsDynamicSharedMemory(op->buffer_var)) {
+      auto offset = GetBufferOffset(op->buffer_var, op->value->dtype);
+      auto index = StmtExprMutator::VisitExpr(op->index);
+      auto value = StmtExprMutator::VisitExpr(op->value);
+      return Store(merged_buf_var_, value, offset + index, op->predicate, op->span);
+    }
+    return StmtExprMutator::VisitStmt_(op);
+  }
+
+ private:
+  PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
+    auto it = buffer_byte_offsets_.find(buffer_var.get());
+    ICHECK(it != buffer_byte_offsets_.end());
+    return indexdiv(it->second, dtype.bytes());
+  }
+
+  Var merged_buf_var_{"buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
+  std::unordered_set<const AllocateNode*> dyn_shmem_allocs_;
+  PrimExpr merged_alloc_size_{0};
+  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
+  bool allocated{false};
+};
+
+Stmt MergeDynamicSharedMemoryAllocations(Stmt stmt) {
+  AllocateCollector collector;
+  collector(stmt);
+  if (collector.dyn_shmem_allocs_.size() > 1) {
+    return DynamicSharedMemoryRewriter(collector.dyn_shmem_allocs_)(std::move(stmt));
+  }
+  return stmt;
+}
+
+namespace transform {
+
+Pass MergeDynamicSharedMemoryAllocations() {
+  auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = MergeDynamicSharedMemoryAllocations(std::move(n->body));
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.MergeDynamicSharedMemoryAllocations", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.MergeDynamicSharedMemoryAllocations")
+    .set_body_typed(MergeDynamicSharedMemoryAllocations);
+
+}  // namespace transform
+}  // namespace tir
+}  // namespace tvm
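Besides the build-pipeline hook above, the pass is exposed standalone through the Python binding registered here. A minimal sketch of applying it directly, assuming `mod` is an IRModule whose PrimFunc contains more than one "shared.dyn" allocation:

import tvm

# `mod` is assumed to be an IRModule with two or more "shared.dyn" allocations.
mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
# Afterwards, all dynamic shared buffers are backed by one uint8 allocation
# named "buf_dyn_shmem"; a module with at most one such allocation passes
# through unchanged (see the size() > 1 guard above).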
diff --git a/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
new file mode 100644
index 000000000000..9c511f1de6b9
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -0,0 +1,259 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm import te
+import numpy as np
+import tvm.testing
+from tvm.topi.math import cast
+
+
+def run_passes(sch, args):
+    bounds = tvm.te.schedule.InferBound(sch)
+    assert isinstance(bounds, tvm.container.Map)
+    stmt = tvm.te.schedule.ScheduleOps(sch, bounds)
+
+    func = tvm.te.schedule.SchedulePostProcToPrimFunc(args, stmt, None)
+    mod = tvm.IRModule.from_expr(func)
+    return tvm.transform.Sequential(
+        [
+            tvm.tir.transform.StorageFlatten(64),
+            tvm.tir.transform.Simplify(),
+            tvm.tir.transform.VectorizeLoop(),
+            tvm.tir.transform.StorageRewrite(),
+            tvm.tir.transform.MergeDynamicSharedMemoryAllocations(),
+        ]
+    )(mod)
+
+
+def verify_single_allocation(stmt, alloc_size=None):
+    num_alloc = [0]
+    alloc_extents = []
+
+    def verify(n):
+        if (
+            isinstance(n, tvm.tir.Allocate)
+            and n.buffer_var.type_annotation.storage_scope == "shared.dyn"
+        ):
+            num_alloc[0] += 1
+            alloc_extents.append(n.extents[0])
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    assert num_alloc[0] == 1
+
+    if alloc_size:
+        assert alloc_extents[0] == alloc_size
+
+
+@tvm.testing.requires_gpu
+def test_matmul_dyn_shared():
+    n = 1024
+    block = 16
+    A = te.placeholder((n, n), name="A", dtype="float16")
+    B = te.placeholder((n, n), name="B", dtype="float16")
+
+    def syncthread():
+        return tvm.tir.Call(None, "tir.tvm_storage_sync", tvm.runtime.convert(["shared"]))
+
+    def test_matmul_ir(A, B, C):
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ty = te.thread_axis("threadIdx.y")
+        bx = te.thread_axis("blockIdx.x")
+        by = te.thread_axis("blockIdx.y")
+        ib.scope_attr(tx, "thread_extent", block)
+        ib.scope_attr(ty, "thread_extent", block)
+        ib.scope_attr(bx, "thread_extent", n // block)
+        ib.scope_attr(by, "thread_extent", n // block)
+
+        A_sh = ib.allocate(A.dtype, (block, block), scope="shared.dyn", name="A_sh")  # fp16
+        B_sh = ib.allocate(B.dtype, (block, block), scope="shared.dyn", name="B_sh")  # fp16
+        # Create a dynamic shared memory buffer for the accumulation.
+        # This is for testing the merging of dynamic shared memory allocations
+        # with different data types.
+        # In practice, there is no need to allocate a shared memory for C.
+        C_sh = ib.allocate(C.dtype, (block, block), scope="shared.dyn", name="C_sh")  # fp32
+
+        A_ptr = ib.buffer_ptr(A)
+        B_ptr = ib.buffer_ptr(B)
+        C_ptr = ib.buffer_ptr(C)
+
+        C_sh[ty, tx] = 0.0
+
+        with ib.for_range(0, n // block, name="i") as i:
+            A_sh[ty, tx] = A_ptr[by * block + ty, i * block + tx]
+            B_sh[ty, tx] = B_ptr[i * block + ty, bx * block + tx]
+            ib.emit(syncthread())
+
+            with ib.for_range(0, block, name="k") as k:
+                C_sh[ty, tx] += cast(A_sh[ty, k] * B_sh[k, tx], "float32")
+
+            ib.emit(syncthread())
+
+        C_ptr[by * block + ty, bx * block + tx] = C_sh[ty, tx]
+
+        return ib.get()
+
+    C = te.extern(
+        A.shape,
+        [A, B],
+        lambda ins, outs: test_matmul_ir(ins[0], ins[1], outs[0]),
+        name="matmul",
+        dtype="float32",
+    )
+    s = te.create_schedule(C.op)
+    mod = run_passes(s, [A, B, C])
+    expected_alloc_size = block * block * 3 * 4
+    verify_single_allocation(mod["main"].body, expected_alloc_size)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fmatmul = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        size = (n, n)
+        a_np = np.random.uniform(size=size).astype(A.dtype)
+        b_np = np.random.uniform(size=size).astype(B.dtype)
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(size, dtype=C.dtype), dev)
+        fmatmul(a, b, c)
+        np_ref = np.dot(a_np.astype("float32"), b_np.astype("float32"))
+        tvm.testing.assert_allclose(c.numpy(), np_ref, 1e-4, 1e-4)
+
+    for target in ["cuda", "nvptx"]:
+        check_target(target)
+
+
+@tvm.testing.requires_gpu
+def test_dyn_shared_vectorized_store():
+    """Test vectorized store into dynamic shared memory"""
+    n = te.size_var("n")
+    A = te.placeholder((n,), name="A", dtype="float16")
+    B = te.placeholder((n,), name="B", dtype="float32")
+
+    def test_device_ir(A, B, C):
+        n = A.shape[0]
+        ib = tvm.tir.ir_builder.create()
+
+        values_per_thread = 4
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", tvm.tir.indexdiv(n, values_per_thread))
+
+        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn")  # fp16
+        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn")  # fp32
+
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+
+        with ib.for_range(0, values_per_thread, kind="vectorize") as i:
+            A_sh[tx * values_per_thread + i] = Aptr[tx * values_per_thread + i]
+            B_sh[tx * values_per_thread + i] = Bptr[tx * values_per_thread + i]
+
+        with ib.for_range(0, values_per_thread) as i:
+            Cptr[tx * values_per_thread + i] = (
+                cast(A_sh[tx * values_per_thread + i], "float32") + B_sh[tx * values_per_thread + i]
+            )
+
+        return ib.get()
+
+    C = te.extern(
+        (n,),
+        [A, B],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], outs[0]),
+        name="vadd",
+        dtype="float32",
+    )
+    s = te.create_schedule(C.op)
+
+    mod = run_passes(s, [A, B, C])
+    verify_single_allocation(mod["main"].body)
+
+    def check_target(target):
+        if not tvm.testing.device_enabled(target):
+            return
+
+        fadd = tvm.build(s, [A, B, C], target)
+        dev = tvm.device(target, 0)
+
+        for n in [512, 1024]:
+            a = tvm.nd.array(np.random.uniform(size=n).astype(A.dtype), dev)
+            b = tvm.nd.array(np.random.uniform(size=n).astype(B.dtype), dev)
+            c = tvm.nd.array(np.zeros((n,), dtype=C.dtype), dev)
+            fadd(a, b, c)
+            tvm.testing.assert_allclose(
+                c.numpy(), a.numpy().astype("float32") + b.numpy(), 1e-4, 1e-4
+            )
+
+    for target in ["cuda", "nvptx"]:
+        check_target(target)
+
+
+@tvm.testing.requires_gpu
+def test_dyn_shared_reuse_and_merge():
+    n = 64
+    A = te.placeholder((n,), name="A", dtype="float32")
+    B = te.placeholder((n,), name="B", dtype="float32")
+    C = te.placeholder((te.size_var("n_dyn"),), name="C", dtype="float32")
+
+    def test_device_ir(A, B, C, D):
+        ib = tvm.tir.ir_builder.create()
+
+        tx = te.thread_axis("threadIdx.x")
+        ib.scope_attr(tx, "thread_extent", n)
+
+        A_sh = ib.allocate(A.dtype, (n,), scope="shared.dyn", name="A_sh")
+        B_sh = ib.allocate(B.dtype, (n,), scope="shared.dyn", name="B_sh")
+        C_sh = ib.allocate(C.dtype, (C.shape[0],), scope="shared.dyn", name="C_sh")
+
+        Aptr = ib.buffer_ptr(A)
+        Bptr = ib.buffer_ptr(B)
+        Cptr = ib.buffer_ptr(C)
+        Dptr = ib.buffer_ptr(D)
+
+        A_sh[tx] = Aptr[tx]
+        Dptr[tx] = A_sh[tx]
+
+        B_sh[tx] = Bptr[tx]
+        Dptr[tx] += B_sh[tx]
+
+        C_sh[tx] = Cptr[tx]  # C cannot reuse other buffers since its size is dynamic
+        Dptr[tx] += C_sh[tx]
+
+        return ib.get()
+
+    D = te.extern(
+        (n,),
+        [A, B, C],
+        lambda ins, outs: test_device_ir(ins[0], ins[1], ins[2], outs[0]),
+        name="vadd",
+        dtype="float32",
+    )
+    s = te.create_schedule(D.op)
+
+    mod = run_passes(s, [A, B, C, D])
+    # The expected merged allocation:
+    # allocate(buf_dyn_shmem: Pointer(shared.dyn uint8), uint8, [((n_dyn*4) + 256)]);
+    verify_single_allocation(mod["main"].body)
+
+
+if __name__ == "__main__":
+    test_matmul_dyn_shared()
+    test_dyn_shared_vectorized_store()
+    test_dyn_shared_reuse_and_merge()
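For manual inspection of what the pass produced, a sketch mirroring verify_single_allocation above; `s` and the placeholder list are assumed to come from one of the test setups, e.g. test_dyn_shared_reuse_and_merge:

# Sketch: dump the dynamic shared allocations after running the pass, reusing
# the run_passes() helper from the test file above.
mod = run_passes(s, [A, B, C, D])

def dump_dyn_shared_allocs(stmt):
    def visit(node):
        if (
            isinstance(node, tvm.tir.Allocate)
            and node.buffer_var.type_annotation.storage_scope == "shared.dyn"
        ):
            print(node.buffer_var.name, node.dtype, node.extents)

    tvm.tir.stmt_functor.post_order_visit(stmt, visit)

dump_dyn_shared_allocs(mod["main"].body)
# Expected: a single line such as
#   buf_dyn_shmem uint8 [((n_dyn*4) + 256)]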