From 5d3dad7168172361b31de027f4eca75c7701e417 Mon Sep 17 00:00:00 2001 From: Ye Kuang Date: Tue, 13 Oct 2020 21:18:06 +0900 Subject: [PATCH] [metal] Support TLS for struct-for tasks --- taichi/backends/metal/codegen_metal.cpp | 53 ++++++++++++++++++++----- 1 file changed, 44 insertions(+), 9 deletions(-) diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp index 4ffdae912a0dc..c2de4856e403f 100644 --- a/taichi/backends/metal/codegen_metal.cpp +++ b/taichi/backends/metal/codegen_metal.cpp @@ -945,7 +945,12 @@ class KernelCodegen : public IRVisitor { ka.task_type = stmt->task_type; ka.buffers = get_common_buffers(); - emit_mtl_kernel_sig(mtl_kernel_name, ka.buffers); + const bool used_tls = (stmt->tls_prologue != nullptr); + KernelSigExtensions kernel_exts; + kernel_exts.use_simdgroup = (used_tls && cgen_config_.allow_simdgroup); + used_features()->simdgroup = + used_features()->simdgroup || kernel_exts.use_simdgroup; + emit_mtl_kernel_sig(mtl_kernel_name, ka.buffers, kernel_exts); const int sn_id = stmt->snode->id; // struct_for kernels use grid-stride loops @@ -959,6 +964,21 @@ class KernelCodegen : public IRVisitor { current_appender().push_indent(); emit("// struct_for"); emit_runtime_and_memalloc_def(); + + if (used_tls) { + // Using TLS means we will access some SNodes within this kernel. The + // struct of an SNode needs Runtime and MemoryAllocator to construct. + // Using |int32_t| because it aligns to 4bytes. + // + // TODO(k-ye): De-dupe TLS for range-for and struct-for. + emit("// TLS prologue"); + const std::string tls_bufi32_name = "tls_bufi32_"; + emit("int32_t {}[{}];", tls_bufi32_name, (stmt->tls_size + 3) / 4); + emit("thread char* {} = reinterpret_cast({});", + kTlsBufferName, tls_bufi32_name); + stmt->tls_prologue->accept(this); + } + emit("ListManager parent_list;"); emit("parent_list.lm_data = ({}->snode_lists + {});", kRuntimeVarName, sn_id); @@ -995,19 +1015,34 @@ class KernelCodegen : public IRVisitor { current_kernel_attribs_ = &ka; const auto mtl_func_name = mtl_kernel_func_name(mtl_kernel_name); - emit_mtl_kernel_func_def( - mtl_func_name, ka.buffers, - /*extra_params=*/ - {{"thread const ListgenElement&", kListgenElemVarName}}, - stmt->body.get()); - emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, - /*extra_args=*/ - {kListgenElemVarName}, + std::vector extra_func_params = { + {"thread const ListgenElement&", kListgenElemVarName}, + }; + std::vector extra_args = { + kListgenElemVarName, + }; + if (used_tls) { + extra_func_params.push_back({"thread char*", kTlsBufferName}); + extra_args.push_back(kTlsBufferName); + } + emit_mtl_kernel_func_def(mtl_func_name, ka.buffers, extra_func_params, + stmt->body.get()); + emit_call_mtl_kernel_func(mtl_func_name, ka.buffers, extra_args, /*loop_index_expr=*/"ii"); current_kernel_attribs_ = nullptr; } emit("}}"); // closes for loop current_appender().pop_indent(); + + if (used_tls) { + // TODO(k-ye): De-dupe TLS for range-for and struct-for. + TI_ASSERT(stmt->tls_epilogue != nullptr); + inside_tls_epilogue_ = true; + emit("{{ // TLS epilogue"); + stmt->tls_epilogue->accept(this); + inside_tls_epilogue_ = false; + emit("}}"); + } emit("}}\n"); // closes kernel mtl_kernels_attribs()->push_back(ka);