Skip to content

Commit

Permalink
[async] [metal] Support async mode on Metal (#1920)
Browse files Browse the repository at this point in the history
  • Loading branch information
k-ye authored Oct 4, 2020
1 parent 7db99e0 commit 08bb052
Show file tree
Hide file tree
Showing 11 changed files with 99 additions and 89 deletions.
63 changes: 31 additions & 32 deletions taichi/backends/metal/codegen_metal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@
#include <functional>
#include <string>

#include "taichi/backends/metal/api.h"
#include "taichi/backends/metal/constants.h"
#include "taichi/backends/metal/env_config.h"
#include "taichi/backends/metal/features.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/transforms.h"
#include "taichi/util/line_appender.h"
#include "taichi/math/arithmetic.h"
#include "taichi/backends/metal/api.h"
#include "taichi/util/line_appender.h"

TLANG_NAMESPACE_BEGIN
namespace metal {
Expand Down Expand Up @@ -83,21 +84,26 @@ class KernelCodegen : public IRVisitor {
};

public:
struct Config {
bool allow_simdgroup = true;
};
// TODO(k-ye): Create a Params to hold these ctor params.
KernelCodegen(const std::string &taichi_kernel_name,
const std::string &root_snode_type_name,
Kernel *kernel,
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab,
const CodeGen::Config &config)
const Config &config,
OffloadedStmt *offloaded)
: mtl_kernel_prefix_(taichi_kernel_name),
root_snode_type_name_(root_snode_type_name),
kernel_(kernel),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
ctx_attribs_(*kernel_),
print_strtab_(print_strtab),
cgen_config_(config) {
cgen_config_(config),
offloaded_(offloaded) {
ti_kernel_attribus_.name = taichi_kernel_name;
ti_kernel_attribus_.is_jit_evaluator = kernel->is_evaluator;
// allow_undefined_visitor = true;
Expand Down Expand Up @@ -769,7 +775,8 @@ class KernelCodegen : public IRVisitor {

void generate_kernels() {
SectionGuard sg(this, Section::Kernels);
kernel_->ir->accept(this);
IRNode *ast = offloaded_ ? offloaded_ : kernel_->ir.get();
ast->accept(this);

if (used_features()->sparse) {
emit("");
Expand Down Expand Up @@ -1212,7 +1219,8 @@ class KernelCodegen : public IRVisitor {
const bool needs_root_buffer_;
const KernelContextAttributes ctx_attribs_;
PrintStringTable *const print_strtab_;
const CodeGen::Config &cgen_config_;
const Config &cgen_config_;
OffloadedStmt *const offloaded_;

bool is_top_level_{true};
int mtl_kernel_count_{0};
Expand All @@ -1226,36 +1234,27 @@ class KernelCodegen : public IRVisitor {

} // namespace

CodeGen::CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
const Config &config)
: kernel_(kernel),
kernel_mgr_(kernel_mgr),
compiled_structs_(compiled_structs),
id_(Program::get_kernel_id()),
taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_->name)),
config_(config) {
}
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded) {
const auto id = Program::get_kernel_id();
const auto taichi_kernel_name(
fmt::format("mtl_k{:04d}_{}", id, kernel->name));

FunctionType CodeGen::compile() {
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/true, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/config.make_thread_local);
KernelCodegen::Config cgen_config;
cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

KernelCodegen codegen(
taichi_kernel_name_, kernel_->program.snode_root->node_type_name, kernel_,
compiled_structs_, kernel_mgr_->print_strtable(), config_);
taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);

const auto source_code = codegen.run();
kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code,
codegen.ti_kernels_attribs(),
codegen.kernel_ctx_attribs());
return [kernel_mgr = kernel_mgr_,
kernel_name = taichi_kernel_name_](Context &ctx) {
kernel_mgr->register_taichi_kernel(taichi_kernel_name, source_code,
codegen.ti_kernels_attribs(),
codegen.kernel_ctx_attribs());
return [kernel_mgr, kernel_name = taichi_kernel_name](Context &ctx) {
kernel_mgr->launch_taichi_kernel(kernel_name, &ctx);
};
}
Expand Down
32 changes: 8 additions & 24 deletions taichi/backends/metal/codegen_metal.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,30 +16,14 @@
TLANG_NAMESPACE_BEGIN
namespace metal {

class CodeGen {
public:
struct Config {
bool allow_simdgroup = true;
};

CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
const Config &config);

FunctionType compile();

private:
void lower();
FunctionType gen(const SNode &root_snode, KernelManager *runtime);

Kernel *const kernel_;
KernelManager *const kernel_mgr_;
const CompiledStructs *const compiled_structs_;
const int id_;
const std::string taichi_kernel_name_;
const Config config_;
};
// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
// compiles just |offloaded|. These ASTs must have already been lowered at the
// CHI level.
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded = nullptr);

} // namespace metal

Expand Down
17 changes: 10 additions & 7 deletions taichi/program/async_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
auto cloned_stmt = ker.ir_handle.clone();
stmt = cloned_stmt->as<OffloadedStmt>();

compilation_workers.enqueue([async_func, stmt, kernel]() {
compilation_workers.enqueue([async_func, stmt, kernel, this]() {
{
// Final lowering
using namespace irpass;
Expand All @@ -143,8 +143,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
is_extension_supported(config.arch, Extension::bls) &&
config.make_block_local);
}
auto codegen = KernelCodeGen::create(kernel->arch, kernel, stmt);
auto func = codegen->codegen();
auto func = this->compile_to_backend_(*kernel, stmt);
async_func->set(func);
});
ir_bank_->insert_to_trash_bin(std::move(cloned_stmt));
Expand All @@ -161,14 +160,18 @@ void ExecutionQueue::synchronize() {
launch_worker.flush();
}

ExecutionQueue::ExecutionQueue(IRBank *ir_bank)
ExecutionQueue::ExecutionQueue(
IRBank *ir_bank,
const BackendExecCompilationFunc &compile_to_backend)
: compilation_workers(4), // TODO: remove 4
launch_worker(1),
ir_bank_(ir_bank) {
ir_bank_(ir_bank),
compile_to_backend_(compile_to_backend) {
}

AsyncEngine::AsyncEngine(Program *program)
: queue(&ir_bank_),
AsyncEngine::AsyncEngine(Program *program,
const BackendExecCompilationFunc &compile_to_backend)
: queue(&ir_bank_, compile_to_backend),
program(program),
sfg(std::make_unique<StateFlowGraph>(&ir_bank_)) {
}
Expand Down
12 changes: 10 additions & 2 deletions taichi/program/async_engine.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
Expand Down Expand Up @@ -70,6 +71,10 @@ class ParallelExecutor {
std::condition_variable flush_cv_;
};

// Compiles the offloaded and optimized IR to the target backend's executable.
using BackendExecCompilationFunc =
std::function<FunctionType(Kernel &, OffloadedStmt *)>;

// In charge of (parallel) compilation to binary and (serial) kernel launching
class ExecutionQueue {
public:
Expand All @@ -78,7 +83,8 @@ class ExecutionQueue {
ParallelExecutor compilation_workers; // parallel compilation
ParallelExecutor launch_worker; // serial launching

explicit ExecutionQueue(IRBank *ir_bank);
explicit ExecutionQueue(IRBank *ir_bank,
const BackendExecCompilationFunc &compile_to_backend);

void enqueue(const TaskLaunchRecord &ker);

Expand Down Expand Up @@ -117,6 +123,7 @@ class ExecutionQueue {
std::unordered_map<uint64, AsyncCompiledFunc> compiled_funcs_;

IRBank *ir_bank_; // not owned
BackendExecCompilationFunc compile_to_backend_;
};

// An engine for asynchronous execution and optimization
Expand All @@ -130,7 +137,8 @@ class AsyncEngine {
std::unique_ptr<StateFlowGraph> sfg;
std::deque<TaskLaunchRecord> task_queue;

explicit AsyncEngine(Program *program);
explicit AsyncEngine(Program *program,
const BackendExecCompilationFunc &compile_to_backend);

bool fuse(); // return true when modified

Expand Down
2 changes: 1 addition & 1 deletion taichi/program/extension.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::cuda,
{Extension::sparse, Extension::async_mode, Extension::data64,
Extension::adstack, Extension::bls, Extension::assertion}},
{Arch::metal, {Extension::adstack}},
{Arch::metal, {Extension::adstack, Extension::async_mode}},
{Arch::opengl, {Extension::extfunc}},
{Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
};
Expand Down
9 changes: 4 additions & 5 deletions taichi/program/kernel.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,9 +69,8 @@ void Kernel::compile() {
void Kernel::lower(bool to_executable) { // TODO: is a "Lowerer" class
// necessary for each backend?
TI_ASSERT(!lowered);
if (arch_is_cpu(arch) || arch == Arch::cuda) {
if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::metal) {
CurrentKernelGuard _(program, this);
auto codegen = KernelCodeGen::create(arch, this);
auto config = program.config;
bool verbose = config.print_ir;
if ((is_accessor && !config.print_accessor_ir) ||
Expand All @@ -81,9 +80,9 @@ void Kernel::lower(bool to_executable) { // TODO: is a "Lowerer" class
if (to_executable) {
irpass::compile_to_executable(
ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
/*ad_use_stack*/ true, verbose, /*lower_global_access*/ to_executable,
/*make_thread_local*/ config.make_thread_local,
/*make_block_local*/
/*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable,
/*make_thread_local=*/config.make_thread_local,
/*make_block_local=*/
is_extension_supported(config.arch, Extension::bls) &&
config.make_block_local);
} else {
Expand Down
35 changes: 22 additions & 13 deletions taichi/program/program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
#include "taichi/backends/cuda/cuda_context.h"
#endif
#include "taichi/backends/metal/codegen_metal.h"
#include "taichi/backends/metal/env_config.h"
#include "taichi/backends/opengl/codegen_opengl.h"
#include "taichi/backends/cpu/codegen_cpu.h"
#include "taichi/struct/struct.h"
Expand Down Expand Up @@ -146,8 +145,11 @@ Program::Program(Arch desired_arch) {

if (config.async_mode) {
TI_WARN("Running in async mode. This is experimental.");
TI_ASSERT(arch_is_cpu(config.arch) || config.arch == Arch::cuda);
async_engine = std::make_unique<AsyncEngine>(this);
TI_ASSERT(is_extension_supported(config.arch, Extension::async_mode));
async_engine = std::make_unique<AsyncEngine>(
this, [this](Kernel &kernel, OffloadedStmt *offloaded) {
return this->compile_to_backend_executable(kernel, offloaded);
});
}

// TODO: allow users to run in debug mode without out-of-bound checks
Expand Down Expand Up @@ -196,17 +198,10 @@ FunctionType Program::compile(Kernel &kernel) {
auto start_t = Time::get_time();
TI_AUTO_PROF;
FunctionType ret = nullptr;
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda ||
kernel.arch == Arch::metal) {
kernel.lower();
auto codegen = KernelCodeGen::create(kernel.arch, &kernel);
ret = codegen->compile();
} else if (kernel.arch == Arch::metal) {
metal::CodeGen::Config cgen_config;
cgen_config.allow_simdgroup =
metal::EnvConfig::instance().is_simdgroup_enabled();
metal::CodeGen codegen(&kernel, metal_kernel_mgr_.get(),
&metal_compiled_structs_.value(), cgen_config);
ret = codegen.compile();
ret = compile_to_backend_executable(kernel, /*offloaded=*/nullptr);
} else if (kernel.arch == Arch::opengl) {
opengl::OpenglCodeGen codegen(kernel.name, &opengl_struct_compiled_.value(),
opengl_kernel_launcher_.get());
Expand All @@ -223,6 +218,20 @@ FunctionType Program::compile(Kernel &kernel) {
return ret;
}

// Compiles already-lowered IR to a backend executable, dispatching on the
// kernel's arch.
//
// |offloaded|: when non-null, only that single OffloadedStmt is compiled
// (this is the path the async engine uses, via the callback registered in
// Program's ctor); when null, the kernel's entire AST is compiled.
//
// Returns the launchable FunctionType for the chosen backend. For any arch
// other than CPU/CUDA/Metal, falls through to TI_NOT_IMPLEMENTED —
// presumably that macro aborts or throws (TODO confirm), making the trailing
// return unreachable.
FunctionType Program::compile_to_backend_executable(Kernel &kernel,
OffloadedStmt *offloaded) {
// CPU and CUDA share the LLVM-based KernelCodeGen path.
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
auto codegen = KernelCodeGen::create(kernel.arch, &kernel, offloaded);
return codegen->compile();
} else if (kernel.arch == Arch::metal) {
// Metal uses a free-function codegen entry point instead of a class.
return metal::compile_to_metal_executable(&kernel, metal_kernel_mgr_.get(),
&metal_compiled_structs_.value(),
offloaded);
}
TI_NOT_IMPLEMENTED;
// Kept to satisfy the compiler's return-path analysis; not reached if
// TI_NOT_IMPLEMENTED terminates.
return nullptr;
}

// For CPU and CUDA archs only
void Program::initialize_runtime_system(StructCompiler *scomp) {
// auto tlctx = llvm_context_host.get();
Expand Down
6 changes: 6 additions & 0 deletions taichi/program/program.h
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,14 @@ class Program {
void end_function_definition() {
}

// TODO: This function is doing two things: 1) compiling CHI IR, and 2)
// offloading them to each backend. We should probably separate the logic?
FunctionType compile(Kernel &kernel);

// Just does the per-backend executable compilation without kernel lowering.
FunctionType compile_to_backend_executable(Kernel &kernel,
OffloadedStmt *stmt);

void initialize_runtime_system(StructCompiler *scomp);

void materialize_layout();
Expand Down
6 changes: 3 additions & 3 deletions tests/python/test_fuse_dense.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,17 @@
template_fuse_reduction


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_fuse_dense_x2y2z():
template_fuse_dense_x2y2z(size=100 * 1024**2)


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_fuse_reduction():
template_fuse_reduction(size=10 * 1024**2)


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_no_fuse_sigs_mismatch():
n = 4096
x = ti.field(ti.i32, shape=(n, ))
Expand Down
3 changes: 2 additions & 1 deletion tests/python/test_fuse_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def y_to_z():
assert z[i] == x[i] + 5


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
async_mode=True)
def test_fuse_dynamic_x2y2z():
benchmark_fuse_dynamic_x2y2z()
3 changes: 2 additions & 1 deletion tests/python/test_sfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import pytest


@ti.test(require=ti.extension.async_mode, async_mode=True)
@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
async_mode=True)
def test_remove_clear_list_from_fused_serial():
x = ti.field(ti.i32)
y = ti.field(ti.i32)
Expand Down

0 comments on commit 08bb052

Please sign in to comment.