diff --git a/taichi/backends/metal/codegen_metal.cpp b/taichi/backends/metal/codegen_metal.cpp
index c88f66fafe0f9..534f162dc330c 100644
--- a/taichi/backends/metal/codegen_metal.cpp
+++ b/taichi/backends/metal/codegen_metal.cpp
@@ -946,11 +946,11 @@ class KernelCodegen : public IRVisitor {
     {
       ScopedIndent s2(current_appender());
       emit("const int parent_idx_ = (ii / child_num_slots);");
-      emit("if (parent_idx_ >= num_active(&parent_list)) return;");
+      emit("if (parent_idx_ >= parent_list.num_active()) return;");
       emit("const int child_idx_ = (ii % child_num_slots);");
       emit(
-          "const auto parent_elem_ = get(&parent_list, "
-          "parent_idx_);");
+          "const auto parent_elem_ = "
+          "parent_list.get(parent_idx_);");
       emit("ListgenElement {};", kListgenElemVarName);
       // No need to add mem_offset_in_parent, because place() always starts at 0
diff --git a/taichi/backends/metal/kernel_manager.cpp b/taichi/backends/metal/kernel_manager.cpp
index cc32942998d2c..1e0a1637bfcbb 100644
--- a/taichi/backends/metal/kernel_manager.cpp
+++ b/taichi/backends/metal/kernel_manager.cpp
@@ -752,7 +752,7 @@ class KernelManager::Impl {
       ListManager root_lm;
       root_lm.lm_data = rtm_list_head + root_id;
       root_lm.mem_alloc = alloc;
-      append(&root_lm, root_elem);
+      root_lm.append(root_elem);
     }
     did_modify_range(runtime_buffer_.get(), /*location=*/0,
diff --git a/taichi/backends/metal/shaders/runtime_kernels.metal.h b/taichi/backends/metal/shaders/runtime_kernels.metal.h
index ce94e51e68dbc..9b6d5467c2b05 100644
--- a/taichi/backends/metal/shaders/runtime_kernels.metal.h
+++ b/taichi/backends/metal/shaders/runtime_kernels.metal.h
@@ -53,7 +53,7 @@ STR(
       child_list.lm_data =
           (reinterpret_cast<device Runtime *>(runtime_addr)->snode_lists +
            child_snode_id);
-      clear(&child_list);
+      child_list.clear();
     }

     kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
@@ -83,13 +83,13 @@ STR(
       const int max_num_elems = args[2];
       for (int ii = utid_; ii < max_num_elems; ii += grid_size) {
         const int parent_idx = (ii / num_slots);
-        if (parent_idx >= num_active(&parent_list)) {
+        if (parent_idx >= parent_list.num_active()) {
           // Since |parent_idx| increases monotonically, we can return directly
           // once it goes beyond the number of active parent elements.
           return;
         }
         const int child_idx = (ii % num_slots);
-        const auto parent_elem = get(&parent_list, parent_idx);
+        const auto parent_elem = parent_list.get(parent_idx);
         ListgenElement child_elem;
         child_elem.root_mem_offset = parent_elem.root_mem_offset +
                                      child_idx * child_stride +
@@ -99,7 +99,7 @@ STR(
           refine_coordinates(parent_elem,
                              runtime->snode_extractors[parent_snode_id],
                              child_idx, &child_elem);
-          append(&child_list, child_elem);
+          child_list.append(child_elem);
         }
       }
     }
diff --git a/taichi/backends/metal/shaders/runtime_structs.metal.h b/taichi/backends/metal/shaders/runtime_structs.metal.h
index 578648fca93dc..bb5b79eb0bc59 100644
--- a/taichi/backends/metal/shaders/runtime_structs.metal.h
+++ b/taichi/backends/metal/shaders/runtime_structs.metal.h
@@ -51,11 +51,6 @@ STR(
       atomic_int chunks[kTaichiNumChunks];
     };

-    struct ListManager {
-      device ListManagerData *lm_data;
-      device MemoryAllocator *mem_alloc;
-    };
-
     // This class is very similar to metal::SNodeDescriptor
     struct SNodeMeta {
       enum Type { Root = 0, Dense = 1, Bitmasked = 2, Dynamic = 3 };
diff --git a/taichi/backends/metal/shaders/runtime_utils.metal.h b/taichi/backends/metal/shaders/runtime_utils.metal.h
index 36e53cb790c04..d993a8023c772 100644
--- a/taichi/backends/metal/shaders/runtime_utils.metal.h
+++ b/taichi/backends/metal/shaders/runtime_utils.metal.h
@@ -31,8 +31,7 @@
 // clang-format off
 METAL_BEGIN_RUNTIME_UTILS_DEF
 STR(
-    using PtrOffset = int32_t;
-    constant constexpr int kAlignment = 8;
+    using PtrOffset = int32_t; constant constexpr int kAlignment = 8;

     [[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator * ma,
                                                   int32_t size) {
@@ -46,86 +45,98 @@ STR(
       return reinterpret_cast<device char *>(ma + 1) + offs;
     }

-    [[maybe_unused]] int num_active(thread ListManager *l) {
-      return atomic_load_explicit(&(l->lm_data->next),
-                                  metal::memory_order_relaxed);
-    }
+    struct ListManager {
+      device ListManagerData *lm_data;
+      device MemoryAllocator *mem_alloc;

-    [[maybe_unused]] void clear(thread ListManager *l) {
-      atomic_store_explicit(&(l->lm_data->next), 0,
-                            metal::memory_order_relaxed);
-    }
+      inline int num_active() {
+        return atomic_load_explicit(&(lm_data->next),
+                                    metal::memory_order_relaxed);
+      }

-    [[maybe_unused]] PtrOffset mtl_listmgr_ensure_chunk(thread ListManager *l,
-                                                        int i) {
-      device ListManagerData *list = l->lm_data;
-      PtrOffset offs = 0;
-      const int kChunkBytes =
-          (list->element_stride << list->log2_num_elems_per_chunk);
-
-      while (true) {
-        int stored = 0;
-        // If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
-        // from requesting memory again. Once allocated, set chunks[i] to the
-        // actual address offset, which is guaranteed to be greater than 1.
-        const bool is_me = atomic_compare_exchange_weak_explicit(
-            list->chunks + i, &stored, 1, metal::memory_order_relaxed,
-            metal::memory_order_relaxed);
-        if (is_me) {
-          offs = mtl_memalloc_alloc(l->mem_alloc, kChunkBytes);
-          atomic_store_explicit(list->chunks + i, offs,
-                                metal::memory_order_relaxed);
-          break;
-        } else if (stored > 1) {
-          offs = stored;
-          break;
-        }
-        // |stored| == 1, just spin
+      inline void resize(int sz) {
+        atomic_store_explicit(&(lm_data->next), sz,
+                              metal::memory_order_relaxed);
       }
-      return offs;
-    }
+
+      inline void clear() {
+        resize(0);
+      }

-    [[maybe_unused]] device char *mtl_listmgr_get_elem_from_chunk(
-        thread ListManager *l,
-        int i,
-        PtrOffset chunk_ptr_offs) {
-      device ListManagerData *list = l->lm_data;
-      device char *chunk_ptr = reinterpret_cast<device char *>(
-          mtl_memalloc_to_ptr(l->mem_alloc, chunk_ptr_offs));
-      const uint32_t mask = ((1 << list->log2_num_elems_per_chunk) - 1);
-      return chunk_ptr + ((i & mask) * list->element_stride);
-    }
+      struct ReserveElemResult {
+        int elem_idx;
+        PtrOffset chunk_ptr_offs;
+      };
+
+      ReserveElemResult reserve_new_elem() {
+        const int elem_idx = atomic_fetch_add_explicit(
+            &lm_data->next, 1, metal::memory_order_relaxed);
+        const int chunk_idx = elem_idx >> lm_data->log2_num_elems_per_chunk;
+        const PtrOffset chunk_ptr_offs = ensure_chunk(chunk_idx);
+        return {elem_idx, chunk_ptr_offs};
+      }
+
+      device char *append() {
+        auto reserved = reserve_new_elem();
+        return get_elem_from_chunk(reserved.elem_idx, reserved.chunk_ptr_offs);
+      }

-    [[maybe_unused]] device char *append(thread ListManager *l) {
-      device ListManagerData *list = l->lm_data;
-      const int elem_idx = atomic_fetch_add_explicit(
-          &list->next, 1, metal::memory_order_relaxed);
-      const int chunk_idx = elem_idx >> list->log2_num_elems_per_chunk;
-      const PtrOffset chunk_ptr_offs = mtl_listmgr_ensure_chunk(l, chunk_idx);
-      return mtl_listmgr_get_elem_from_chunk(l, elem_idx, chunk_ptr_offs);
-    }
+      template <typename T>
+      void append(thread const T &elem) {
+        device char *ptr = append();
+        thread char *elem_ptr = (thread char *)(&elem);

-    template <typename T>
-    [[maybe_unused]] void append(thread ListManager *l, thread const T &elem) {
-      device char *ptr = append(l);
-      thread char *elem_ptr = (thread char *)(&elem);
-
-      for (int i = 0; i < l->lm_data->element_stride; ++i) {
-        *ptr = *elem_ptr;
-        ++ptr;
-        ++elem_ptr;
+        for (int i = 0; i < lm_data->element_stride; ++i) {
+          *ptr = *elem_ptr;
+          ++ptr;
+          ++elem_ptr;
+        }
       }
-    }

-    template <typename T>
-    [[maybe_unused]] T get(thread ListManager *l, int i) {
-      device ListManagerData *list = l->lm_data;
-      const int chunk_idx = i >> list->log2_num_elems_per_chunk;
-      const PtrOffset chunk_ptr_offs = atomic_load_explicit(
-          list->chunks + chunk_idx, metal::memory_order_relaxed);
-      return *reinterpret_cast<device T *>(
-          mtl_listmgr_get_elem_from_chunk(l, i, chunk_ptr_offs));
-    }
+      template <typename T>
+      T get(int i) {
+        const int chunk_idx = i >> lm_data->log2_num_elems_per_chunk;
+        const PtrOffset chunk_ptr_offs = atomic_load_explicit(
+            lm_data->chunks + chunk_idx, metal::memory_order_relaxed);
+        return *reinterpret_cast<device T *>(
+            get_elem_from_chunk(i, chunk_ptr_offs));
+      }
+
+     private:
+      PtrOffset ensure_chunk(int i) {
+        PtrOffset offs = 0;
+        const int kChunkBytes =
+            (lm_data->element_stride << lm_data->log2_num_elems_per_chunk);
+
+        while (true) {
+          int stored = 0;
+          // If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
+          // from requesting memory again. Once allocated, set chunks[i] to the
+          // actual address offset, which is guaranteed to be greater than 1.
+          const bool is_me = atomic_compare_exchange_weak_explicit(
+              lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
+              metal::memory_order_relaxed);
+          if (is_me) {
+            offs = mtl_memalloc_alloc(mem_alloc, kChunkBytes);
+            atomic_store_explicit(lm_data->chunks + i, offs,
+                                  metal::memory_order_relaxed);
+            break;
+          } else if (stored > 1) {
+            offs = stored;
+            break;
+          }
+          // |stored| == 1, just spin
+        }
+        return offs;
+      }
+
+      device char *get_elem_from_chunk(int i, PtrOffset chunk_ptr_offs) {
+        device char *chunk_ptr = reinterpret_cast<device char *>(
+            mtl_memalloc_to_ptr(mem_alloc, chunk_ptr_offs));
+        const uint32_t mask = ((1 << lm_data->log2_num_elems_per_chunk) - 1);
+        return chunk_ptr + ((i & mask) * lm_data->element_stride);
+      }
+    };

     [[maybe_unused]] int is_active(device byte *addr, SNodeMeta meta, int i) {
       if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
@@ -207,8 +218,7 @@ STR(
       device auto *n_ptr = reinterpret_cast<device atomic_int *>(
           addr + (meta.num_slots * meta.element_stride));
       return atomic_load_explicit(n_ptr, metal::memory_order_relaxed);
-    }
-)
+    })
 METAL_END_RUNTIME_UTILS_DEF
 // clang-format on

diff --git a/taichi/backends/metal/struct_metal.cpp b/taichi/backends/metal/struct_metal.cpp
index 9d976998f5e5f..8b5ab05ffd477 100644
--- a/taichi/backends/metal/struct_metal.cpp
+++ b/taichi/backends/metal/struct_metal.cpp
@@ -26,7 +26,7 @@ namespace shaders {
 }  // namespace shaders

-constexpr size_t kListManagerSize = sizeof(shaders::ListManager);
+constexpr size_t kListManagerDataSize = sizeof(shaders::ListManagerData);
 constexpr size_t kSNodeMetaSize = sizeof(shaders::SNodeMeta);
 constexpr size_t kSNodeExtractorsSize = sizeof(shaders::SNodeExtractors);

@@ -226,8 +226,8 @@ class StructCompiler {
   }

   size_t compute_runtime_size() {
-    size_t result = (max_snodes_) *
-                    (kSNodeMetaSize + kSNodeExtractorsSize + kListManagerSize);
+    size_t result = (max_snodes_) * (kSNodeMetaSize + kSNodeExtractorsSize +
+                                     kListManagerDataSize);
     result += sizeof(uint32_t) * kNumRandSeeds;
     return result;
   }
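
Usage note (illustrative, not part of the patch): after this refactoring, a call site binds a ListManager to its backing storage and then invokes the member functions directly, as the hunks above already do for root_lm, parent_list, and child_list. The sketch below restates that pattern in the style of the shader-side runtime code; the runtime, snode_id, and alloc bindings are placeholder names, and get is spelled with an explicit template argument here.

    // Point the thin wrapper at a per-SNode ListManagerData slot and at the
    // shared MemoryAllocator (both live in device memory).
    ListManager list;
    list.lm_data = runtime->snode_lists + snode_id;
    list.mem_alloc = alloc;

    list.clear();                     // resize(0): resets the atomic |next| counter
    ListgenElement elem;
    list.append(elem);                // reserves a slot, then copies element_stride bytes
    const int n = list.num_active();  // number of elements appended so far
    if (n > 0) {
      ListgenElement first = list.get<ListgenElement>(0);
      (void)first;
    }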