Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Metal] Support assert() #1959

Merged
merged 2 commits into from
Oct 16, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 54 additions & 5 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,9 @@ constexpr char kContextBufferName[] = "ctx_addr";
constexpr char kContextVarName[] = "kernel_ctx_";
constexpr char kRuntimeBufferName[] = "runtime_addr";
constexpr char kRuntimeVarName[] = "runtime_";
constexpr char kPrintBufferName[] = "print_addr";
constexpr char kPrintAssertBufferName[] = "print_assert_addr";
constexpr char kPrintAllocVarName[] = "print_alloc_";
constexpr char kAssertRecorderVarName[] = "assert_rec_";
constexpr char kLinearLoopIndexName[] = "linear_loop_idx_";
constexpr char kListgenElemVarName[] = "listgen_elem_";
constexpr char kRandStateVarName[] = "rand_state_";
Expand All @@ -60,7 +61,7 @@ std::string buffer_to_name(BuffersEnum b) {
case BuffersEnum::Runtime:
return kRuntimeBufferName;
case BuffersEnum::Print:
return kPrintBufferName;
return kPrintAssertBufferName;
default:
TI_NOT_IMPLEMENTED;
break;
Expand Down Expand Up @@ -638,6 +639,46 @@ class KernelCodegen : public IRVisitor {
emit("}}");
}

void visit(AssertStmt *stmt) override {
used_features()->assertion = true;

const auto &args = stmt->args;
// +1 because the assertion message template itself takes one slot
const auto num_args = args.size() + 1;
TI_ASSERT_INFO(num_args <= shaders::kMetalMaxNumAssertArgs,
"[Metal] Too many args in assert()");
emit("if (!({})) {{", stmt->cond->raw_name());
{
ScopedIndent s(current_appender());
// Only record the message for the first-time assertion failure.
emit("if ({}.mark_first_failure()) {{", kAssertRecorderVarName);
{
ScopedIndent s2(current_appender());
emit("{}.set_num_args({});", kAssertRecorderVarName, num_args);
const std::string asst_var_name = stmt->raw_name() + "_msg_";
emit("PrintMsg {}({}.msg_buf_addr(), {});", asst_var_name,
kAssertRecorderVarName, num_args);
const int msg_str_id = print_strtab_->put(stmt->text);
emit("{}.pm_set_str(/*i=*/0, {});", asst_var_name, msg_str_id);
for (int i = 1; i < num_args; ++i) {
auto *arg = args[i - 1];
const auto ty = arg->element_type();
if (ty == PrimitiveType::i32 || ty == PrimitiveType::f32) {
emit("{}.pm_set_{}({}, {});", asst_var_name,
data_type_short_name(ty), i, arg->raw_name());
} else {
TI_ERROR(
"[Metal] assert() only supports i32 or f32 scalars for now.");
}
}
}
emit("}}");
// This has failed, no point executing the rest of the kernel.
emit("return;");
}
emit("}}");
}

void visit(StackAllocaStmt *stmt) override {
TI_ASSERT(stmt->width() == 1);

Expand Down Expand Up @@ -1124,9 +1165,17 @@ class KernelCodegen : public IRVisitor {
fmt::arg("rtm", kRuntimeVarName),
fmt::arg("lidx", kLinearLoopIndexName),
fmt::arg("nums", kNumRandSeeds));
// Init PrintMsgAllocator
emit("device auto* {} = reinterpret_cast<device PrintMsgAllocator*>({});",
kPrintAllocVarName, kPrintBufferName);
// Init AssertRecorder.
emit("AssertRecorder {}({});", kAssertRecorderVarName,
kPrintAssertBufferName);
// Init PrintMsgAllocator.
// The print buffer comes after (AssertRecorder + assert message buffer),
// therefore we skip by +|kMetalAssertBufferSize|.
emit(
"device auto* {} = reinterpret_cast<device PrintMsgAllocator*>({} + "
"{});",
kPrintAllocVarName, kPrintAssertBufferName,
shaders::kMetalAssertBufferSize);
}
// We do not need additional indentation, because |func_ir| itself is a
// block, which will be indented automatically.
Expand Down
87 changes: 69 additions & 18 deletions taichi/backends/metal/kernel_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#include "taichi/util/action_recorder.h"
#include "taichi/python/print_buffer.h"
#include "taichi/util/file_sequence_writer.h"
#include "taichi/util/str.h"

#ifdef TI_PLATFORM_OSX
#include <sys/mman.h>
Expand Down Expand Up @@ -53,7 +54,7 @@ class BufferMemoryView {
BufferMemoryView(size_t size, MemoryPool *mem_pool) {
// Both |ptr_| and |size_| must be aligned to page size.
size_ = iroundup(size, taichi_page_size);
ptr_ = mem_pool->allocate(size_, /*alignment=*/taichi_page_size);
ptr_ = (char *)mem_pool->allocate(size_, /*alignment=*/taichi_page_size);
TI_ASSERT(ptr_ != nullptr);
std::memset(ptr_, 0, size_);
}
Expand All @@ -66,13 +67,13 @@ class BufferMemoryView {
inline size_t size() const {
return size_;
}
inline void *ptr() const {
inline char *ptr() const {
return ptr_;
}

private:
size_t size_;
void *ptr_;
char *ptr_;
};

// MetalRuntime maintains a series of MTLBuffers that are shared across all the
Expand Down Expand Up @@ -578,14 +579,13 @@ class KernelManager::Impl {
"Failed to allocate Metal runtime buffer, requested {} bytes",
runtime_mem_->size());
print_mem_ = std::make_unique<BufferMemoryView>(
sizeof(shaders::PrintMsgAllocator) + shaders::kMetalPrintBufferSize,
mem_pool_);
shaders::kMetalPrintAssertBufferSize, mem_pool_);
print_buffer_ = new_mtl_buffer_no_copy(device_.get(), print_mem_->ptr(),
print_mem_->size());
TI_ASSERT(print_buffer_ != nullptr);

init_runtime(params.root_id);
init_print_buffer();
clear_print_assert_buffer();
}

void register_taichi_kernel(const std::string &taichi_kernel_name,
Expand Down Expand Up @@ -640,24 +640,30 @@ class KernelManager::Impl {
for (const auto &mk : ctk.compiled_mtl_kernels) {
mk->launch(input_buffers, cur_command_buffer_.get());
}
const bool used_print = ctk.ti_kernel_attribs.used_features.print;
if (ctx_blitter || used_print) {

const auto &used = ctk.ti_kernel_attribs.used_features;
const bool used_print_assert = (used.print || used.assertion);
if (ctx_blitter || used_print_assert) {
// TODO(k-ye): One optimization is to synchronize only when we absolutely
// need to transfer the data back to host. This includes the cases where
// an arg is 1) an array, or 2) used as return value.
std::vector<MTLBuffer *> buffers_to_blit;
if (ctx_blitter) {
buffers_to_blit.push_back(ctx_blitter->ctx_buffer());
}
if (used_print) {
if (used_print_assert) {
clear_print_assert_buffer();
buffers_to_blit.push_back(print_buffer_.get());
}
blit_buffers_and_sync(buffers_to_blit);

if (ctx_blitter) {
ctx_blitter->metal_to_host();
}
if (used_print) {
if (used.assertion) {
check_assertion_failure();
}
if (used.print) {
flush_print_buffers();
}
}
Expand Down Expand Up @@ -801,9 +807,10 @@ class KernelManager::Impl {
runtime_mem_->size());
}

void init_print_buffer() {
// TODO(k-ye): Do we need this at all?
did_modify_range(print_buffer_.get(), /*location=*/0, print_mem_->size());
void clear_print_assert_buffer() {
const auto sz = print_mem_->size();
std::memset(print_mem_->ptr(), 0, sz);
did_modify_range(print_buffer_.get(), /*location=*/0, sz);
}

void blit_buffers_and_sync(
Expand All @@ -828,10 +835,51 @@ class KernelManager::Impl {
profiler_->stop();
}

void check_assertion_failure() {
// TODO: Copy this to program's result_buffer, and let the Taichi runtime
// handle the assertion failures uniformly.
auto *asst_rec =
reinterpret_cast<shaders::AssertRecorderData *>(print_mem_->ptr());
if (!asst_rec->flag) {
return;
}
auto *msg_ptr = reinterpret_cast<int32_t *>(asst_rec + 1);
shaders::PrintMsg msg(msg_ptr, asst_rec->num_args);
using MsgType = shaders::PrintMsg::Type;
TI_ASSERT(msg.pm_get_type(0) == MsgType::Str);
const auto fmt_str = print_strtable_.get(msg.pm_get_data(0));
const auto err_str = format_error_message(fmt_str, [&msg](int argument_id) {
// +1 to skip the first arg, which is the error message template.
const int32 x = msg.pm_get_data(argument_id + 1);
return taichi_union_cast_with_different_sizes<uint64>(x);
});
// Note that we intentionally comment out the flag reset below, because it
// is ineffective at all. This is a very tricky part:
// 1. Under .managed storage mode, we need to call [didModifyRange:] to sync
// buffer data from CPU -> GPU. So ideally, after resetting the flag, we
// should just do so.
// 2. However, during the assertion (TI_ERROR), the stack unwinding seems to
// have deviated from the normal execution path. As a result, if we put
// [didModifyRange:] after TI_ERROR, it doesn't get executed...
// 3. The reason we put [didModifyRange:] after TI_ERROR is because we
// should do so after flush_print_buffers():
//
// check_assertion_failure(); <-- Code below is skipped...
// flush_print_buffers();
// memset(print_mem_->ptr(), 0, print_mem_->size());
// did_modify_range(print_buffer_);
//
// As a workaround, we put [didModifyRange:] before sync, where the program
// is still executing normally.
// asst_rec->flag = 0;
TI_ERROR("Assertion failure: {}", err_str);
}

void flush_print_buffers() {
auto *pa =
reinterpret_cast<shaders::PrintMsgAllocator *>(print_mem_->ptr());
const int used_sz = std::min(pa->next, shaders::kMetalPrintBufferSize);
auto *pa = reinterpret_cast<shaders::PrintMsgAllocator *>(
print_mem_->ptr() + shaders::kMetalAssertBufferSize);
const int used_sz =
std::min(pa->next, shaders::kMetalPrintMsgsMaxQueueSize);
using MsgType = shaders::PrintMsg::Type;
char *buf = reinterpret_cast<char *>(pa + 1);
const char *buf_end = buf + used_sz;
Expand All @@ -857,11 +905,13 @@ class KernelManager::Impl {
buf += shaders::mtl_compute_print_msg_bytes(num_entries);
}

if (pa->next >= shaders::kMetalPrintBufferSize) {
if (pa->next >= shaders::kMetalPrintMsgsMaxQueueSize) {
py_cout << "...(maximum print buffer reached)\n";
}

pa->next = 0;
// Comment out intentionally since it is ineffective otherwise. See
// check_assertion_failure() for the explanation.
// pa->next = 0;
}

static int compute_num_elems_per_chunk(int n) {
Expand Down Expand Up @@ -902,6 +952,7 @@ class KernelManager::Impl {
nsobj_unique_ptr<MTLBuffer> global_tmps_buffer_;
std::unique_ptr<BufferMemoryView> runtime_mem_;
nsobj_unique_ptr<MTLBuffer> runtime_buffer_;
// TODO: Rename these to 'print_assert_{mem|buffer}_'
std::unique_ptr<BufferMemoryView> print_mem_;
nsobj_unique_ptr<MTLBuffer> print_buffer_;
std::unordered_map<std::string, std::unique_ptr<CompiledTaichiKernel>>
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/metal/kernel_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,8 @@ struct TaichiKernelAttributes {
struct UsedFeatures {
// Whether print() is called inside this kernel.
bool print = false;
// Whether assert is called inside this kernel.
bool assertion = false;
// Whether this kernel accesses (read or write) sparse SNodes.
bool sparse = false;
// Whether [[thread_index_in_simdgroup]] is used. This is only supported
Expand Down
Loading