[metal] Refactor runtime ListManager utils #1444

Merged 3 commits on Jul 11, 2020

taichi/backends/metal/codegen_metal.cpp (6 changes: 3 additions & 3 deletions)

@@ -946,11 +946,11 @@ class KernelCodegen : public IRVisitor {
     {
       ScopedIndent s2(current_appender());
       emit("const int parent_idx_ = (ii / child_num_slots);");
-      emit("if (parent_idx_ >= num_active(&parent_list)) return;");
+      emit("if (parent_idx_ >= parent_list.num_active()) return;");
       emit("const int child_idx_ = (ii % child_num_slots);");
       emit(
-          "const auto parent_elem_ = get<ListgenElement>(&parent_list, "
-          "parent_idx_);");
+          "const auto parent_elem_ = "
+          "parent_list.get<ListgenElement>(parent_idx_);");

       emit("ListgenElement {};", kListgenElemVarName);
       // No need to add mem_offset_in_parent, because place() always starts at 0
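
For orientation, the emit() calls above produce Metal source for the listgen loop body. A minimal sketch of the snippet they assemble after this change, assuming a surrounding grid-stride loop over ii and rendering kListgenElemVarName as elem_ (a stand-in name, not from the PR):

    // Sketch of the generated Metal snippet (not verbatim codegen output):
    const int parent_idx_ = (ii / child_num_slots);
    if (parent_idx_ >= parent_list.num_active()) return;
    const int child_idx_ = (ii % child_num_slots);
    const auto parent_elem_ = parent_list.get<ListgenElement>(parent_idx_);
    ListgenElement elem_;  // mem_offset_in_parent not added, per the comment above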

taichi/backends/metal/kernel_manager.cpp (2 changes: 1 addition & 1 deletion)

@@ -752,7 +752,7 @@ class KernelManager::Impl {
       ListManager root_lm;
       root_lm.lm_data = rtm_list_head + root_id;
       root_lm.mem_alloc = alloc;
-      append(&root_lm, root_elem);
+      root_lm.append(root_elem);
     }

     did_modify_range(runtime_buffer_.get(), /*location=*/0,

taichi/backends/metal/shaders/runtime_kernels.metal.h (8 changes: 4 additions & 4 deletions)

@@ -53,7 +53,7 @@ STR(
       child_list.lm_data =
           (reinterpret_cast<device Runtime *>(runtime_addr)->snode_lists +
            child_snode_id);
-      clear(&child_list);
+      child_list.clear();
     }

     kernel void element_listgen(device byte *runtime_addr[[buffer(0)]],
@@ -83,13 +83,13 @@ STR(
       const int max_num_elems = args[2];
       for (int ii = utid_; ii < max_num_elems; ii += grid_size) {
         const int parent_idx = (ii / num_slots);
-        if (parent_idx >= num_active(&parent_list)) {
+        if (parent_idx >= parent_list.num_active()) {
          // Since |parent_idx| increases monotonically, we can return directly
          // once it goes beyond the number of active parent elements.
          return;
        }
        const int child_idx = (ii % num_slots);
-       const auto parent_elem = get<ListgenElement>(&parent_list, parent_idx);
+       const auto parent_elem = parent_list.get<ListgenElement>(parent_idx);
        ListgenElement child_elem;
        child_elem.root_mem_offset = parent_elem.root_mem_offset +
                                     child_idx * child_stride +
@@ -99,7 +99,7 @@ STR(
        refine_coordinates(parent_elem,
                           runtime->snode_extractors[parent_snode_id],
                           child_idx, &child_elem);
-       append(&child_list, child_elem);
+       child_list.append(child_elem);
      }
    }
 }
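
The early return in element_listgen relies on the flattened iteration order: ii sweeps (parent, slot) pairs in parent-major order, so parent_idx never decreases within a thread's grid-stride walk. A small host-side C++ model of that pattern, with hypothetical values (num_slots, grid size, and active counts are stand-ins, not values from the PR):

    #include <cstdio>

    int main() {
      const int num_slots = 4;                  // child slots per parent element
      const int num_active_parents = 3;         // models parent_list.num_active()
      const int max_num_elems = 8 * num_slots;  // upper bound, models args[2]
      const int grid_size = 16;                 // threads in the simulated grid

      for (int utid = 0; utid < grid_size; ++utid) {  // one loop per "thread"
        for (int ii = utid; ii < max_num_elems; ii += grid_size) {
          const int parent_idx = ii / num_slots;
          // parent_idx is non-decreasing in ii, so once it passes the
          // active-parent count this thread can stop (the kernel returns).
          if (parent_idx >= num_active_parents)
            break;
          const int child_idx = ii % num_slots;
          std::printf("thread %2d -> parent %d, slot %d\n", utid, parent_idx,
                      child_idx);
        }
      }
      return 0;
    }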

taichi/backends/metal/shaders/runtime_structs.metal.h (5 changes: 0 additions & 5 deletions)

@@ -51,11 +51,6 @@ STR(
       atomic_int chunks[kTaichiNumChunks];
     };

-    struct ListManager {
-      device ListManagerData *lm_data;
-      device MemoryAllocator *mem_alloc;
-    };
-
     // This class is very similar to metal::SNodeDescriptor
     struct SNodeMeta {
       enum Type { Root = 0, Dense = 1, Bitmasked = 2, Dynamic = 3 };

taichi/backends/metal/shaders/runtime_utils.metal.h (160 changes: 85 additions & 75 deletions)

@@ -31,8 +31,7 @@
 // clang-format off
 METAL_BEGIN_RUNTIME_UTILS_DEF
 STR(
-    using PtrOffset = int32_t;
-    constant constexpr int kAlignment = 8;
+    using PtrOffset = int32_t; constant constexpr int kAlignment = 8;

     [[maybe_unused]] PtrOffset mtl_memalloc_alloc(device MemoryAllocator * ma,
                                                   int32_t size) {
@@ -46,86 +45,98 @@ STR(
       return reinterpret_cast<device char *>(ma + 1) + offs;
     }

-    [[maybe_unused]] int num_active(thread ListManager *l) {
-      return atomic_load_explicit(&(l->lm_data->next),
-                                  metal::memory_order_relaxed);
-    }
-
-    [[maybe_unused]] void clear(thread ListManager *l) {
-      atomic_store_explicit(&(l->lm_data->next), 0,
-                            metal::memory_order_relaxed);
-    }
-
-    [[maybe_unused]] PtrOffset mtl_listmgr_ensure_chunk(thread ListManager *l,
-                                                        int i) {
-      device ListManagerData *list = l->lm_data;
-      PtrOffset offs = 0;
-      const int kChunkBytes =
-          (list->element_stride << list->log2_num_elems_per_chunk);
-
-      while (true) {
-        int stored = 0;
-        // If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
-        // from requesting memory again. Once allocated, set chunks[i] to the
-        // actual address offset, which is guaranteed to be greater than 1.
-        const bool is_me = atomic_compare_exchange_weak_explicit(
-            list->chunks + i, &stored, 1, metal::memory_order_relaxed,
-            metal::memory_order_relaxed);
-        if (is_me) {
-          offs = mtl_memalloc_alloc(l->mem_alloc, kChunkBytes);
-          atomic_store_explicit(list->chunks + i, offs,
-                                metal::memory_order_relaxed);
-          break;
-        } else if (stored > 1) {
-          offs = stored;
-          break;
-        }
-        // |stored| == 1, just spin
-      }
-      return offs;
-    }
-
-    [[maybe_unused]] device char *mtl_listmgr_get_elem_from_chunk(
-        thread ListManager *l,
-        int i,
-        PtrOffset chunk_ptr_offs) {
-      device ListManagerData *list = l->lm_data;
-      device char *chunk_ptr = reinterpret_cast<device char *>(
-          mtl_memalloc_to_ptr(l->mem_alloc, chunk_ptr_offs));
-      const uint32_t mask = ((1 << list->log2_num_elems_per_chunk) - 1);
-      return chunk_ptr + ((i & mask) * list->element_stride);
-    }
-
-    [[maybe_unused]] device char *append(thread ListManager *l) {
-      device ListManagerData *list = l->lm_data;
-      const int elem_idx = atomic_fetch_add_explicit(
-          &list->next, 1, metal::memory_order_relaxed);
-      const int chunk_idx = elem_idx >> list->log2_num_elems_per_chunk;
-      const PtrOffset chunk_ptr_offs = mtl_listmgr_ensure_chunk(l, chunk_idx);
-      return mtl_listmgr_get_elem_from_chunk(l, elem_idx, chunk_ptr_offs);
-    }
-
-    template <typename T>
-    [[maybe_unused]] void append(thread ListManager *l, thread const T &elem) {
-      device char *ptr = append(l);
-      thread char *elem_ptr = (thread char *)(&elem);
-
-      for (int i = 0; i < l->lm_data->element_stride; ++i) {
-        *ptr = *elem_ptr;
-        ++ptr;
-        ++elem_ptr;
-      }
-    }
-
-    template <typename T>
-    [[maybe_unused]] T get(thread ListManager *l, int i) {
-      device ListManagerData *list = l->lm_data;
-      const int chunk_idx = i >> list->log2_num_elems_per_chunk;
-      const PtrOffset chunk_ptr_offs = atomic_load_explicit(
-          list->chunks + chunk_idx, metal::memory_order_relaxed);
-      return *reinterpret_cast<device T *>(
-          mtl_listmgr_get_elem_from_chunk(l, i, chunk_ptr_offs));
-    }
+    struct ListManager {
+      device ListManagerData *lm_data;
+      device MemoryAllocator *mem_alloc;
+
+      inline int num_active() {
+        return atomic_load_explicit(&(lm_data->next),
+                                    metal::memory_order_relaxed);
+      }
+
+      inline void resize(int sz) {
+        atomic_store_explicit(&(lm_data->next), sz,
+                              metal::memory_order_relaxed);
+      }
+
+      inline void clear() {
+        resize(0);
+      }
+
+      struct ReserveElemResult {
+        int elem_idx;
+        PtrOffset chunk_ptr_offs;
+      };
+
+      ReserveElemResult reserve_new_elem() {
+        const int elem_idx = atomic_fetch_add_explicit(
+            &lm_data->next, 1, metal::memory_order_relaxed);
+        const int chunk_idx = elem_idx >> lm_data->log2_num_elems_per_chunk;
+        const PtrOffset chunk_ptr_offs = ensure_chunk(chunk_idx);
+        return {elem_idx, chunk_ptr_offs};
+      }
+
+      device char *append() {
+        auto reserved = reserve_new_elem();
+        return get_elem_from_chunk(reserved.elem_idx, reserved.chunk_ptr_offs);
+      }
+
+      template <typename T>
+      void append(thread const T &elem) {
+        device char *ptr = append();
+        thread char *elem_ptr = (thread char *)(&elem);
+
+        for (int i = 0; i < lm_data->element_stride; ++i) {
+          *ptr = *elem_ptr;
+          ++ptr;
+          ++elem_ptr;
+        }
+      }
+
+      template <typename T>
+      T get(int i) {
+        const int chunk_idx = i >> lm_data->log2_num_elems_per_chunk;
+        const PtrOffset chunk_ptr_offs = atomic_load_explicit(
+            lm_data->chunks + chunk_idx, metal::memory_order_relaxed);
+        return *reinterpret_cast<device T *>(
+            get_elem_from_chunk(i, chunk_ptr_offs));
+      }
+
+     private:
+      PtrOffset ensure_chunk(int i) {
+        PtrOffset offs = 0;
+        const int kChunkBytes =
+            (lm_data->element_stride << lm_data->log2_num_elems_per_chunk);
+
+        while (true) {
+          int stored = 0;
+          // If chunks[i] is unallocated, i.e. 0, mark it as 1 to prevent others
+          // from requesting memory again. Once allocated, set chunks[i] to the
+          // actual address offset, which is guaranteed to be greater than 1.
+          const bool is_me = atomic_compare_exchange_weak_explicit(
+              lm_data->chunks + i, &stored, 1, metal::memory_order_relaxed,
+              metal::memory_order_relaxed);
+          if (is_me) {
+            offs = mtl_memalloc_alloc(mem_alloc, kChunkBytes);
+            atomic_store_explicit(lm_data->chunks + i, offs,
+                                  metal::memory_order_relaxed);
+            break;
+          } else if (stored > 1) {
+            offs = stored;
+            break;
+          }
+          // |stored| == 1, just spin
+        }
+        return offs;
+      }
+
+      device char *get_elem_from_chunk(int i, PtrOffset chunk_ptr_offs) {
+        device char *chunk_ptr = reinterpret_cast<device char *>(
+            mtl_memalloc_to_ptr(mem_alloc, chunk_ptr_offs));
+        const uint32_t mask = ((1 << lm_data->log2_num_elems_per_chunk) - 1);
+        return chunk_ptr + ((i & mask) * lm_data->element_stride);
+      }
+    };

     [[maybe_unused]] int is_active(device byte *addr, SNodeMeta meta, int i) {
       if (meta.type == SNodeMeta::Root || meta.type == SNodeMeta::Dense) {
@@ -207,8 +218,7 @@ STR(
       device auto *n_ptr = reinterpret_cast<device atomic_int *>(
           addr + (meta.num_slots * meta.element_stride));
       return atomic_load_explicit(n_ptr, metal::memory_order_relaxed);
-    }
-)
+    })
 METAL_END_RUNTIME_UTILS_DEF
 // clang-format on
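
Two pieces of the ListManager internals above are easier to see in isolation. First, the chunked addressing: with log2_num_elems_per_chunk = k, element i lives in chunk i >> k at in-chunk slot i & ((1 << k) - 1), so the list grows by allocating fixed-size chunks on demand rather than reallocating. Second, ensure_chunk's lock-free, allocate-once protocol, where chunks[i] encodes three states: 0 (unallocated), 1 (allocation in flight), and >1 (published offset). Below is a host-side C++ model of that protocol; it is a minimal sketch, not the Metal code, and chunks, bump_alloc, and all sizes are stand-ins:

    #include <atomic>
    #include <cstdio>

    namespace {

    std::atomic<int> chunks[64];           // models ListManagerData::chunks
    std::atomic<int> next_free_offset{8};  // models the allocator's cursor

    // Stand-in for mtl_memalloc_alloc: bump-allocates and returns an offset
    // that is strictly greater than 1, as the protocol requires.
    int bump_alloc(int bytes) {
      return next_free_offset.fetch_add(bytes, std::memory_order_relaxed);
    }

    int ensure_chunk(int i, int chunk_bytes) {
      while (true) {
        int stored = 0;
        // Try to claim the unallocated slot (0 -> 1); exactly one thread wins.
        if (chunks[i].compare_exchange_weak(stored, 1,
                                            std::memory_order_relaxed)) {
          const int offs = bump_alloc(chunk_bytes);          // winner allocates
          chunks[i].store(offs, std::memory_order_relaxed);  // publish offset
          return offs;
        }
        if (stored > 1) {
          return stored;  // another thread already published the offset
        }
        // stored == 1: allocation in flight on another thread; spin.
      }
    }

    }  // namespace

    int main() {
      std::printf("chunk 3 lives at offset %d\n", ensure_chunk(3, 256));
      std::printf("second call reuses offset %d\n", ensure_chunk(3, 256));
      return 0;
    }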

taichi/backends/metal/struct_metal.cpp (6 changes: 3 additions & 3 deletions)

@@ -26,7 +26,7 @@ namespace shaders {

 }  // namespace shaders

-constexpr size_t kListManagerSize = sizeof(shaders::ListManager);
+constexpr size_t kListManagerDataSize = sizeof(shaders::ListManagerData);
 constexpr size_t kSNodeMetaSize = sizeof(shaders::SNodeMeta);
 constexpr size_t kSNodeExtractorsSize = sizeof(shaders::SNodeExtractors);

@@ -226,8 +226,8 @@ class StructCompiler {
   }

   size_t compute_runtime_size() {
-    size_t result = (max_snodes_) *
-        (kSNodeMetaSize + kSNodeExtractorsSize + kListManagerSize);
+    size_t result = (max_snodes_) * (kSNodeMetaSize + kSNodeExtractorsSize +
+                                     kListManagerDataSize);
     result += sizeof(uint32_t) * kNumRandSeeds;
     return result;
   }
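
With ListManager now defined in runtime_utils.metal.h, this file sizes the persistent per-SNode header, ListManagerData, instead. A worked sketch of the compute_runtime_size arithmetic, using hypothetical sizes (the real values are the sizeof() of the Metal-side structs and depend on max_snodes_ and kNumRandSeeds):

    #include <cstddef>
    #include <cstdint>

    // Hypothetical layout sizes, for illustration only.
    constexpr std::size_t kSNodeMetaSize = 32;
    constexpr std::size_t kSNodeExtractorsSize = 96;
    constexpr std::size_t kListManagerDataSize = 48;
    constexpr int max_snodes = 64;
    constexpr int kNumRandSeeds = 65536;

    // Mirrors compute_runtime_size(): per-SNode metadata plus RNG seeds.
    constexpr std::size_t runtime_size =
        max_snodes * (kSNodeMetaSize + kSNodeExtractorsSize +
                      kListManagerDataSize) +
        sizeof(std::uint32_t) * kNumRandSeeds;  // 64*176 + 4*65536

    static_assert(runtime_size == 273408, "worked example checks out");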