Skip to content

Commit

Permalink
[metal] Refactor sparse shader impl in prep for pointer SNode (#1994)
Browse files Browse the repository at this point in the history
* [metal] Refactor sparse shader impl in prep for pointer SNode

* Update taichi/backends/metal/shaders/runtime_structs.metal.h

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>

Co-authored-by: Yuanming Hu <yuanming-hu@users.noreply.github.com>
  • Loading branch information
k-ye and yuanming-hu authored Oct 26, 2020
1 parent 6dcb69b commit 16af509
Show file tree
Hide file tree
Showing 6 changed files with 115 additions and 74 deletions.
33 changes: 14 additions & 19 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ constexpr char kPrintAssertBufferName[] = "print_assert_addr";
constexpr char kPrintAllocVarName[] = "print_alloc_";
constexpr char kAssertRecorderVarName[] = "assert_rec_";
constexpr char kLinearLoopIndexName[] = "linear_loop_idx_";
constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kElementCoordsVarName[] = "elem_coords_";
constexpr char kRandStateVarName[] = "rand_state_";
constexpr char kMemAllocVarName[] = "mem_alloc_";
constexpr char kTlsBufferName[] = "tls_buffer_";
Expand Down Expand Up @@ -372,7 +372,7 @@ class KernelCodegen : public IRVisitor {
TI_ASSERT(stmt->index == 0);
emit("const int {} = {};", stmt_name, kLinearLoopIndexName);
} else if (type == TaskType::struct_for) {
emit("const int {} = {}.coords[{}];", stmt_name, kListgenElemVarName,
emit("const int {} = {}.at[{}];", stmt_name, kElementCoordsVarName,
stmt->index);
} else {
TI_NOT_IMPLEMENTED;
Expand Down Expand Up @@ -568,8 +568,6 @@ class KernelCodegen : public IRVisitor {
} else if (stmt->task_type == Type::gc) {
// Ignored
} else {
// struct_for is automatically lowered to ranged_for for dense snodes
// (#378). So we only need to support serial and range_for tasks.
TI_ERROR("Unsupported offload type={} on Metal arch", stmt->task_name());
}
is_top_level_ = true;
Expand Down Expand Up @@ -1015,32 +1013,29 @@ class KernelCodegen : public IRVisitor {
{
ScopedIndent s2(current_appender());
emit("const int parent_idx_ = (ii / child_num_slots);");
emit("if (parent_idx_ >= parent_list.num_active()) return;");
emit("if (parent_idx_ >= parent_list.num_active()) break;");
emit("const int child_idx_ = (ii % child_num_slots);");
emit(
"const auto parent_elem_ = "
"parent_list.get<ListgenElement>(parent_idx_);");
emit("device auto *parent_addr_ = {} + parent_elem_.root_mem_offset;",
kRootBufferName);
emit("if (!is_active(parent_addr_, parent_meta, child_idx_)) continue;");
emit("ListgenElement {};", kListgenElemVarName);
// No need to add mem_offset_in_parent, because place() always starts at 0
emit(
"{}.root_mem_offset = parent_elem_.root_mem_offset + child_idx_ * "
"child_stride;",
kListgenElemVarName);
"device auto *parent_addr_ = mtl_lgen_snode_addr(parent_elem_, {}, "
"{}, {});",
kRootBufferName, kRuntimeVarName, kMemAllocVarName);
emit("if (!is_active(parent_addr_, parent_meta, child_idx_)) continue;");
emit("ElementCoords {};", kElementCoordsVarName);
emit(
"refine_coordinates(parent_elem_, {}->snode_extractors[{}], "
"refine_coordinates(parent_elem_.coords, {}->snode_extractors[{}], "
"child_idx_, &{});",
kRuntimeVarName, sn_id, kListgenElemVarName);
kRuntimeVarName, sn_id, kElementCoordsVarName);

current_kernel_attribs_ = &ka;
const auto mtl_func_name = mtl_kernel_func_name(mtl_kernel_name);
std::vector<FuncParamLiteral> extra_func_params = {
{"thread const ListgenElement&", kListgenElemVarName},
{"thread const ElementCoords &", kElementCoordsVarName},
};
std::vector<std::string> extra_args = {
kListgenElemVarName,
kElementCoordsVarName,
};
if (used_tls) {
extra_func_params.push_back({"thread char*", kTlsBufferName});
Expand All @@ -1053,11 +1048,11 @@ class KernelCodegen : public IRVisitor {
current_kernel_attribs_ = nullptr;
}
emit("}}"); // closes for loop
current_appender().pop_indent();

if (used_tls) {
generate_tls_epilogue(stmt);
}

current_appender().pop_indent();
emit("}}\n"); // closes kernel

mtl_kernels_attribs()->push_back(ka);
Expand Down
2 changes: 1 addition & 1 deletion taichi/backends/metal/data_types.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ MetalDataType to_metal_type(DataType dt) {
METAL_CASE(u64);
METAL_CASE(unknown);
else {
TI_NOT_IMPLEMENTED;
TI_ERROR("[Metal] type={} not supported", data_type_name(dt));
}
#undef METAL_CASE
return MetalDataType::unknown;
Expand Down
5 changes: 3 additions & 2 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -793,9 +793,9 @@ class KernelManager::Impl {
mem_alloc->next = shaders::kAlignment;
// root list data are static
ListgenElement root_elem;
root_elem.root_mem_offset = 0;
root_elem.mem_offset = 0;
for (int i = 0; i < taichi_max_num_indices; ++i) {
root_elem.coords[i] = 0;
root_elem.coords.at[i] = 0;
}
ListManager root_lm;
root_lm.lm_data = rtm_list_begin + root_id;
Expand Down Expand Up @@ -880,6 +880,7 @@ class KernelManager::Impl {
print_mem_->ptr() + shaders::kMetalAssertBufferSize);
const int used_sz =
std::min(pa->next, shaders::kMetalPrintMsgsMaxQueueSize);
TI_TRACE("Print buffer used bytes: {}", used_sz);
using MsgType = shaders::PrintMsg::Type;
char *buf = reinterpret_cast<char *>(pa + 1);
const char *buf_end = buf + used_sz;
Expand Down
23 changes: 8 additions & 15 deletions taichi/backends/metal/shaders/runtime_kernels.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,6 @@

static_assert(false, "Do not include");

// Just a mock to illustrate what the Runtime looks like, do not use.
// The actual Runtime struct has to be emitted by codegen, because it depends
// on the number of SNodes.
struct Runtime {
SNodeMeta *snode_metas;
SNodeExtractors *snode_extractors;
ListManager *snode_lists;
};

#define METAL_BEGIN_RUNTIME_KERNELS_DEF
#define METAL_END_RUNTIME_KERNELS_DEF

Expand Down Expand Up @@ -77,15 +68,17 @@ STR(
}
const int child_idx = (ii % num_slots);
const auto parent_elem = parent_list.get<ListgenElement>(parent_idx);
device byte *parent_addr = root_addr + parent_elem.root_mem_offset;
device byte *parent_addr =
mtl_lgen_snode_addr(parent_elem, root_addr, runtime, mem_alloc);
if (is_active(parent_addr, parent_meta, child_idx)) {
ListgenElement child_elem;
child_elem.root_mem_offset = parent_elem.root_mem_offset +
child_idx * child_stride +
child_meta.mem_offset_in_parent;
refine_coordinates(parent_elem,
child_elem.mem_offset =
parent_elem.mem_offset + child_idx * child_stride;
child_elem.mem_offset += child_meta.mem_offset_in_parent;

refine_coordinates(parent_elem.coords,
runtime->snode_extractors[parent_snode_id],
child_idx, &child_elem);
child_idx, &(child_elem.coords));
child_list.append(child_elem);
}
}
Expand Down
87 changes: 64 additions & 23 deletions taichi/backends/metal/shaders/runtime_structs.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,11 +34,6 @@ STR(

struct MemoryAllocator { atomic_int next; };

struct ListgenElement {
int32_t coords[kTaichiMaxNumIndices];
int32_t root_mem_offset = 0;
};

// ListManagerData manages a list of elements with adjustable size.
struct ListManagerData {
int32_t element_stride = 0;
Expand All @@ -54,17 +49,11 @@ STR(
// NodeManagerData stores the actual data needed to implement NodeManager
// in Metal buffers.
//
// There are several level of indirections here to retrieve an allocated
// element from a NodeManager. The actual allocated elements are not
// embedded in the memory region of NodeManagerData. Instead, all this data
// structure does is to maintain a few lists (ListManagerData).
//
// However, these lists do not store the actual data, either. Instead, their
// elements are just 32-bit integers, which are memory offsets (PtrOffset)
// in a Metal buffer. That buffer to which these offsets point holds the
// actual data.
// The actual allocated elements are not embedded in the memory region of
// NodeManagerData. Instead, all this data structure does is to maintain a
// few lists (ListManagerData). In particular, |data_list| stores the actual
// data, while |free_list| and |recycle_list| are only meant for GC.
struct NodeManagerData {
using ElemIndex = int32_t;
// Stores the actual data.
ListManagerData data_list;
// For GC
Expand All @@ -73,14 +62,51 @@ STR(
atomic_int free_list_used;
// Need this field to bookkeep some data during GC
int recycled_list_size_backup;
// The first 8 index values are reserved to encode special status:
// * 0 : nullptr
// * 1 : spinning for allocation
// * 2-7: unused for now
//
/// For each allocated index, it is added by |index_offset| to skip over
/// these reserved values.
constant static constexpr ElemIndex kIndexOffset = 8;

// Use this type instead of the raw index type (int32_t), because the
// raw value needs to be shifted by |kIndexOffset| in order for the
// spinning memory allocation algorithm to work.
struct ElemIndex {
// The first 8 index values are reserved to encode special status:
// * 0 : nullptr
// * 1 : spinning for allocation
// * 2-7: unused for now
//
/// For each allocated index, it is added by |index_offset| to skip over
/// these reserved values.
constant static constexpr int32_t kIndexOffset = 8;

ElemIndex() = default;

static ElemIndex from_index(int i) {
return ElemIndex(i + kIndexOffset);
}

static ElemIndex from_raw(int r) {
return ElemIndex(r);
}

inline int32_t index() const {
return raw_ - kIndexOffset;
}

inline int32_t raw() const {
return raw_;
}

inline bool is_valid() const {
return raw_ >= kIndexOffset;
}

inline static bool is_valid(int raw) {
return ElemIndex::from_raw(raw).is_valid();
}

private:
explicit ElemIndex(int r) : raw_(r) {
}
int32_t raw_ = 0;
};
};

// This class is very similar to metal::SNodeDescriptor
Expand All @@ -102,6 +128,21 @@ STR(

Extractor extractors[kTaichiMaxNumIndices];
};

struct ElementCoords { int32_t at[kTaichiMaxNumIndices]; };

struct ListgenElement {
ElementCoords coords;
// Memory offset from a given address.
// * If in_root_buffer() is true, this is from the root buffer address.
// * O/W this is from the |id|-th NodeManager's |elem_idx|-th element.
int32_t mem_offset = 0;

inline bool in_root_buffer() const {
// Placeholder impl
return true;
}
};
// clang-format off
)
METAL_END_RUNTIME_STRUCTS_DEF
Expand Down
39 changes: 25 additions & 14 deletions taichi/backends/metal/shaders/runtime_utils.metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,16 @@

#else

// Just a mock to illustrate what the Runtime looks like, do not use.
// The actual Runtime struct has to be emitted by codegen, because it depends
// on the number of SNodes.
struct Runtime {
SNodeMeta *snode_metas = nullptr;
SNodeExtractors *snode_extractors = nullptr;
ListManagerData *snode_lists = nullptr;
uint32_t *rand_seeds = nullptr;
};

#define METAL_BEGIN_RUNTIME_UTILS_DEF
#define METAL_END_RUNTIME_UTILS_DEF

Expand Down Expand Up @@ -148,10 +158,6 @@ STR(
device NodeManagerData *nm_data;
device MemoryAllocator *mem_alloc;

static inline bool is_valid(ElemIndex i) {
return i >= NodeManagerData::kIndexOffset;
}

ElemIndex allocate() {
ListManager free_list;
free_list.lm_data = &(nm_data->free_list);
Expand All @@ -165,21 +171,19 @@ STR(
if (cur_used < free_list.num_active()) {
return free_list.get<ElemIndex>(cur_used);
}
// Shift by |kIndexOffset| to skip special encoded values.
return data_list.reserve_new_elem().elem_idx +
NodeManagerData::kIndexOffset;

return ElemIndex::from_index(data_list.reserve_new_elem().elem_idx);
}

device byte *get(ElemIndex i) {
ListManager data_list;
data_list.lm_data = &(nm_data->data_list);
data_list.mem_alloc = mem_alloc;

return data_list.get_ptr(i - NodeManagerData::kIndexOffset);
return data_list.get_ptr(i.index());
}

void recycle(ElemIndex i) {
// Precondition: |i| is shifted by |kIndexOffset|.
ListManager recycled_list;
recycled_list.lm_data = &(nm_data->recycled_list);
recycled_list.mem_alloc = mem_alloc;
Expand Down Expand Up @@ -322,16 +326,23 @@ STR(
}

[[maybe_unused]] void refine_coordinates(
thread const ListgenElement &parent_elem,
device const SNodeExtractors &child_extrators,
int l,
thread ListgenElement *child_elem) {
thread const ElementCoords &parent,
device const SNodeExtractors &child_extrators, int l,
thread ElementCoords *child) {
for (int i = 0; i < kTaichiMaxNumIndices; ++i) {
device const auto &ex = child_extrators.extractors[i];
const int mask = ((1 << ex.num_bits) - 1);
const int addition = (((l >> ex.acc_offset) & mask) << ex.start);
child_elem->coords[i] = (parent_elem.coords[i] | addition);
child->at[i] = (parent.at[i] | addition);
}
}

// Gets the address of an SNode cell identified by |lgen|.
[[maybe_unused]] device byte *mtl_lgen_snode_addr(
thread const ListgenElement &lgen, device byte *root_addr,
device Runtime *rtm, device MemoryAllocator *mem_alloc) {
// Placeholder impl
return root_addr + lgen.mem_offset;
})
METAL_END_RUNTIME_UTILS_DEF
// clang-format on
Expand Down

0 comments on commit 16af509

Please sign in to comment.