[async] [metal] Support async mode on Metal #1920

Merged 2 commits on Oct 4, 2020
63 changes: 31 additions & 32 deletions taichi/backends/metal/codegen_metal.cpp
@@ -3,13 +3,14 @@
#include <functional>
#include <string>

#include "taichi/backends/metal/api.h"
#include "taichi/backends/metal/constants.h"
#include "taichi/backends/metal/env_config.h"
#include "taichi/backends/metal/features.h"
#include "taichi/ir/ir.h"
#include "taichi/ir/transforms.h"
#include "taichi/util/line_appender.h"
#include "taichi/math/arithmetic.h"
#include "taichi/backends/metal/api.h"
#include "taichi/util/line_appender.h"

TLANG_NAMESPACE_BEGIN
namespace metal {
@@ -83,21 +84,26 @@ class KernelCodegen : public IRVisitor {
};

public:
struct Config {
bool allow_simdgroup = true;
};
// TODO(k-ye): Create a Params to hold these ctor params.
KernelCodegen(const std::string &taichi_kernel_name,
const std::string &root_snode_type_name,
Kernel *kernel,
const CompiledStructs *compiled_structs,
PrintStringTable *print_strtab,
const CodeGen::Config &config)
const Config &config,
OffloadedStmt *offloaded)
: mtl_kernel_prefix_(taichi_kernel_name),
root_snode_type_name_(root_snode_type_name),
kernel_(kernel),
compiled_structs_(compiled_structs),
needs_root_buffer_(compiled_structs_->root_size > 0),
ctx_attribs_(*kernel_),
print_strtab_(print_strtab),
cgen_config_(config) {
cgen_config_(config),
offloaded_(offloaded) {
ti_kernel_attribus_.name = taichi_kernel_name;
ti_kernel_attribus_.is_jit_evaluator = kernel->is_evaluator;
// allow_undefined_visitor = true;
@@ -769,7 +775,8 @@ class KernelCodegen : public IRVisitor {

void generate_kernels() {
SectionGuard sg(this, Section::Kernels);
kernel_->ir->accept(this);
IRNode *ast = offloaded_ ? offloaded_ : kernel_->ir.get();
ast->accept(this);

if (used_features()->sparse) {
emit("");
@@ -1212,7 +1219,8 @@ class KernelCodegen : public IRVisitor {
const bool needs_root_buffer_;
const KernelContextAttributes ctx_attribs_;
PrintStringTable *const print_strtab_;
const CodeGen::Config &cgen_config_;
const Config &cgen_config_;
OffloadedStmt *const offloaded_;

bool is_top_level_{true};
int mtl_kernel_count_{0};
@@ -1226,36 +1234,27 @@ class KernelCodegen : public IRVisitor {

} // namespace

CodeGen::CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
const Config &config)
: kernel_(kernel),
kernel_mgr_(kernel_mgr),
compiled_structs_(compiled_structs),
id_(Program::get_kernel_id()),
taichi_kernel_name_(fmt::format("mtl_k{:04d}_{}", id_, kernel_->name)),
config_(config) {
}
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded) {
const auto id = Program::get_kernel_id();
const auto taichi_kernel_name(
fmt::format("mtl_k{:04d}_{}", id, kernel->name));

FunctionType CodeGen::compile() {
auto &config = kernel_->program.config;
config.demote_dense_struct_fors = true;
irpass::compile_to_executable(kernel_->ir.get(), config,
/*vectorize=*/false, kernel_->grad,
/*ad_use_stack=*/true, config.print_ir,
/*lower_global_access=*/true,
/*make_thread_local=*/config.make_thread_local);
KernelCodegen::Config cgen_config;
cgen_config.allow_simdgroup = EnvConfig::instance().is_simdgroup_enabled();

KernelCodegen codegen(
taichi_kernel_name_, kernel_->program.snode_root->node_type_name, kernel_,
compiled_structs_, kernel_mgr_->print_strtable(), config_);
taichi_kernel_name, kernel->program.snode_root->node_type_name, kernel,
compiled_structs, kernel_mgr->print_strtable(), cgen_config, offloaded);

const auto source_code = codegen.run();
kernel_mgr_->register_taichi_kernel(taichi_kernel_name_, source_code,
codegen.ti_kernels_attribs(),
codegen.kernel_ctx_attribs());
return [kernel_mgr = kernel_mgr_,
kernel_name = taichi_kernel_name_](Context &ctx) {
kernel_mgr->register_taichi_kernel(taichi_kernel_name, source_code,
codegen.ti_kernels_attribs(),
codegen.kernel_ctx_attribs());
return [kernel_mgr, kernel_name = taichi_kernel_name](Context &ctx) {
kernel_mgr->launch_taichi_kernel(kernel_name, &ctx);
};
}
32 changes: 8 additions & 24 deletions taichi/backends/metal/codegen_metal.h
@@ -16,30 +16,14 @@
TLANG_NAMESPACE_BEGIN
namespace metal {

class CodeGen {
public:
struct Config {
bool allow_simdgroup = true;
};

CodeGen(Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
const Config &config);

FunctionType compile();

private:
void lower();
FunctionType gen(const SNode &root_snode, KernelManager *runtime);

Kernel *const kernel_;
KernelManager *const kernel_mgr_;
const CompiledStructs *const compiled_structs_;
const int id_;
const std::string taichi_kernel_name_;
const Config config_;
};
// If |offloaded| is nullptr, this compiles the AST in |kernel|. Otherwise it
// compiles just |offloaded|. These ASTs must have already been lowered at the
// CHI level.
FunctionType compile_to_metal_executable(
Kernel *kernel,
KernelManager *kernel_mgr,
const CompiledStructs *compiled_structs,
OffloadedStmt *offloaded = nullptr);

} // namespace metal
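
Side note (not part of the diff): a minimal sketch of how the new free-function entry point is meant to be called. The wrapper names compile_whole_kernel and compile_single_task are illustrative only; the arguments mirror the call site added in program.cpp further down.

// Sketch only, assuming the Metal backend objects owned by Program.
#include "taichi/backends/metal/codegen_metal.h"

TLANG_NAMESPACE_BEGIN

// offloaded == nullptr: compile every offloaded task in |kernel|'s lowered AST.
FunctionType compile_whole_kernel(Kernel *kernel,
                                  metal::KernelManager *kernel_mgr,
                                  const metal::CompiledStructs *structs) {
  return metal::compile_to_metal_executable(kernel, kernel_mgr, structs,
                                            /*offloaded=*/nullptr);
}

// Non-null |offloaded|: compile just one already-lowered task, which is what
// the async engine's compilation workers do for each task they flush.
FunctionType compile_single_task(Kernel *kernel,
                                 metal::KernelManager *kernel_mgr,
                                 const metal::CompiledStructs *structs,
                                 OffloadedStmt *offloaded) {
  return metal::compile_to_metal_executable(kernel, kernel_mgr, structs,
                                            offloaded);
}

TLANG_NAMESPACE_END

Either returned FunctionType is launched the same way: it captures the registered kernel name and forwards the Context to KernelManager::launch_taichi_kernel, so callers do not need to distinguish the two modes after compilation.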

17 changes: 10 additions & 7 deletions taichi/program/async_engine.cpp
@@ -128,7 +128,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
auto cloned_stmt = ker.ir_handle.clone();
stmt = cloned_stmt->as<OffloadedStmt>();

compilation_workers.enqueue([async_func, stmt, kernel]() {
compilation_workers.enqueue([async_func, stmt, kernel, this]() {
{
// Final lowering
using namespace irpass;
@@ -143,8 +143,7 @@ void ExecutionQueue::enqueue(const TaskLaunchRecord &ker) {
is_extension_supported(config.arch, Extension::bls) &&
config.make_block_local);
}
auto codegen = KernelCodeGen::create(kernel->arch, kernel, stmt);
auto func = codegen->codegen();
auto func = this->compile_to_backend_(*kernel, stmt);
async_func->set(func);
});
ir_bank_->insert_to_trash_bin(std::move(cloned_stmt));
@@ -161,14 +160,18 @@ void ExecutionQueue::synchronize() {
launch_worker.flush();
}

ExecutionQueue::ExecutionQueue(IRBank *ir_bank)
ExecutionQueue::ExecutionQueue(
IRBank *ir_bank,
const BackendExecCompilationFunc &compile_to_backend)
: compilation_workers(4), // TODO: remove 4
launch_worker(1),
ir_bank_(ir_bank) {
ir_bank_(ir_bank),
compile_to_backend_(compile_to_backend) {
}

AsyncEngine::AsyncEngine(Program *program)
: queue(&ir_bank_),
AsyncEngine::AsyncEngine(Program *program,
const BackendExecCompilationFunc &compile_to_backend)
: queue(&ir_bank_, compile_to_backend),
program(program),
sfg(std::make_unique<StateFlowGraph>(&ir_bank_)) {
}
12 changes: 10 additions & 2 deletions taichi/program/async_engine.h
@@ -1,6 +1,7 @@
#include <atomic>
#include <condition_variable>
#include <deque>
#include <functional>
#include <future>
#include <mutex>
#include <thread>
@@ -70,6 +71,10 @@ class ParallelExecutor {
std::condition_variable flush_cv_;
};

// Compiles the offloaded and optimized IR to the target backend's executable.
using BackendExecCompilationFunc =
std::function<FunctionType(Kernel &, OffloadedStmt *)>;

// In charge of (parallel) compilation to binary and (serial) kernel launching
class ExecutionQueue {
public:
@@ -78,7 +83,8 @@ class ExecutionQueue {
ParallelExecutor compilation_workers; // parallel compilation
ParallelExecutor launch_worker; // serial launching

explicit ExecutionQueue(IRBank *ir_bank);
explicit ExecutionQueue(IRBank *ir_bank,
const BackendExecCompilationFunc &compile_to_backend);

void enqueue(const TaskLaunchRecord &ker);

@@ -117,6 +123,7 @@ class ExecutionQueue {
std::unordered_map<uint64, AsyncCompiledFunc> compiled_funcs_;

IRBank *ir_bank_; // not owned
BackendExecCompilationFunc compile_to_backend_;
};

// An engine for asynchronous execution and optimization
@@ -130,7 +137,8 @@ class AsyncEngine {
std::unique_ptr<StateFlowGraph> sfg;
std::deque<TaskLaunchRecord> task_queue;

explicit AsyncEngine(Program *program);
explicit AsyncEngine(Program *program,
const BackendExecCompilationFunc &compile_to_backend);

bool fuse(); // return true when modified

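
Side note (not part of the diff): BackendExecCompilationFunc is the dependency-injection point that lets ExecutionQueue stop calling KernelCodeGen::create directly, so the async engine no longer needs to know whether the backend is LLVM-based or Metal. A hedged sketch of the wiring; make_async_engine is an illustrative helper, and the lambda body matches what Program installs in program.cpp further down.

// Sketch only: the owner supplies the backend compiler as a callable.
#include <memory>

#include "taichi/program/async_engine.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

std::unique_ptr<AsyncEngine> make_async_engine(Program *program) {
  // The engine only sees this std::function; the concrete codegen
  // (CPU/CUDA KernelCodeGen or metal::compile_to_metal_executable) is
  // chosen inside Program::compile_to_backend_executable.
  BackendExecCompilationFunc compile_to_backend =
      [program](Kernel &kernel, OffloadedStmt *offloaded) {
        return program->compile_to_backend_executable(kernel, offloaded);
      };
  return std::make_unique<AsyncEngine>(program, compile_to_backend);
}

TLANG_NAMESPACE_END

This keeps ExecutionQueue, the StateFlowGraph, and the fusion passes backend-agnostic; only the final per-task codegen step differs between architectures.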
2 changes: 1 addition & 1 deletion taichi/program/extension.cpp
@@ -17,7 +17,7 @@ bool is_extension_supported(Arch arch, Extension ext) {
{Arch::cuda,
{Extension::sparse, Extension::async_mode, Extension::data64,
Extension::adstack, Extension::bls, Extension::assertion}},
{Arch::metal, {Extension::adstack}},
{Arch::metal, {Extension::adstack, Extension::async_mode}},
{Arch::opengl, {Extension::extfunc}},
{Arch::cc, {Extension::data64, Extension::extfunc, Extension::adstack}},
};
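
Side note (not part of the diff): with async_mode added to Metal's entry above, the generic capability check is what gates async mode at runtime (Program's constructor asserts on it, and the updated ti.test decorators in the Python tests query it). A small hedged sketch; can_run_async is an illustrative name, and the header path is assumed from the .cpp location.

// Sketch only: the kind of guard that now passes for Arch::metal.
#include "taichi/program/extension.h"

TLANG_NAMESPACE_BEGIN

bool can_run_async(Arch arch) {
  return is_extension_supported(arch, Extension::async_mode);
}

TLANG_NAMESPACE_END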
9 changes: 4 additions & 5 deletions taichi/program/kernel.cpp
@@ -69,9 +69,8 @@ void Kernel::compile() {
void Kernel::lower(bool to_executable) { // TODO: is a "Lowerer" class
// necessary for each backend?
TI_ASSERT(!lowered);
if (arch_is_cpu(arch) || arch == Arch::cuda) {
if (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::metal) {
CurrentKernelGuard _(program, this);
auto codegen = KernelCodeGen::create(arch, this);
auto config = program.config;
bool verbose = config.print_ir;
if ((is_accessor && !config.print_accessor_ir) ||
@@ -81,9 +80,9 @@ if (to_executable) {
if (to_executable) {
irpass::compile_to_executable(
ir.get(), config, /*vectorize*/ arch_is_cpu(arch), grad,
/*ad_use_stack*/ true, verbose, /*lower_global_access*/ to_executable,
/*make_thread_local*/ config.make_thread_local,
/*make_block_local*/
/*ad_use_stack=*/true, verbose, /*lower_global_access=*/to_executable,
/*make_thread_local=*/config.make_thread_local,
/*make_block_local=*/
is_extension_supported(config.arch, Extension::bls) &&
config.make_block_local);
} else {
35 changes: 22 additions & 13 deletions taichi/program/program.cpp
@@ -11,7 +11,6 @@
#include "taichi/backends/cuda/cuda_context.h"
#endif
#include "taichi/backends/metal/codegen_metal.h"
#include "taichi/backends/metal/env_config.h"
#include "taichi/backends/opengl/codegen_opengl.h"
#include "taichi/backends/cpu/codegen_cpu.h"
#include "taichi/struct/struct.h"
@@ -146,8 +145,11 @@ Program::Program(Arch desired_arch) {

if (config.async_mode) {
TI_WARN("Running in async mode. This is experimental.");
TI_ASSERT(arch_is_cpu(config.arch) || config.arch == Arch::cuda);
async_engine = std::make_unique<AsyncEngine>(this);
TI_ASSERT(is_extension_supported(config.arch, Extension::async_mode));
async_engine = std::make_unique<AsyncEngine>(
this, [this](Kernel &kernel, OffloadedStmt *offloaded) {
return this->compile_to_backend_executable(kernel, offloaded);
});
}

// TODO: allow users to run in debug mode without out-of-bound checks
@@ -196,17 +198,10 @@ FunctionType Program::compile(Kernel &kernel) {
auto start_t = Time::get_time();
TI_AUTO_PROF;
FunctionType ret = nullptr;
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda ||
kernel.arch == Arch::metal) {
kernel.lower();
auto codegen = KernelCodeGen::create(kernel.arch, &kernel);
ret = codegen->compile();
} else if (kernel.arch == Arch::metal) {
metal::CodeGen::Config cgen_config;
cgen_config.allow_simdgroup =
metal::EnvConfig::instance().is_simdgroup_enabled();
metal::CodeGen codegen(&kernel, metal_kernel_mgr_.get(),
&metal_compiled_structs_.value(), cgen_config);
ret = codegen.compile();
ret = compile_to_backend_executable(kernel, /*offloaded=*/nullptr);
} else if (kernel.arch == Arch::opengl) {
opengl::OpenglCodeGen codegen(kernel.name, &opengl_struct_compiled_.value(),
opengl_kernel_launcher_.get());
@@ -223,6 +218,20 @@
return ret;
}

FunctionType Program::compile_to_backend_executable(Kernel &kernel,
OffloadedStmt *offloaded) {
if (arch_is_cpu(kernel.arch) || kernel.arch == Arch::cuda) {
auto codegen = KernelCodeGen::create(kernel.arch, &kernel, offloaded);
return codegen->compile();
} else if (kernel.arch == Arch::metal) {
return metal::compile_to_metal_executable(&kernel, metal_kernel_mgr_.get(),
&metal_compiled_structs_.value(),
offloaded);
}
TI_NOT_IMPLEMENTED;
return nullptr;
}

// For CPU and CUDA archs only
void Program::initialize_runtime_system(StructCompiler *scomp) {
// auto tlctx = llvm_context_host.get();
6 changes: 6 additions & 0 deletions taichi/program/program.h
@@ -181,8 +181,14 @@ class Program {
void end_function_definition() {
}

// TODO: This function is doing two things: 1) compiling CHI IR, and 2)
// offloading them to each backend. We should probably separate the logic?
FunctionType compile(Kernel &kernel);

// Just does the per-backend executable compilation without kernel lowering.
FunctionType compile_to_backend_executable(Kernel &kernel,
OffloadedStmt *stmt);

void initialize_runtime_system(StructCompiler *scomp);

void materialize_layout();
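
Side note (not part of the diff): how the two entry points divide the work after this change. compile() still lowers the kernel and then delegates, while compile_to_backend_executable() performs only the per-backend codegen, which is why the async engine can call it once per offloaded task. A hedged sketch with illustrative helper names:

// Sketch only: synchronous vs. async compilation paths through Program.
#include "taichi/program/kernel.h"
#include "taichi/program/program.h"

TLANG_NAMESPACE_BEGIN

// Synchronous path (what Program::compile now does for CPU, CUDA and Metal):
// lower the whole kernel, then compile its entire lowered AST.
FunctionType compile_sync(Program &program, Kernel &kernel) {
  kernel.lower();
  return program.compile_to_backend_executable(kernel, /*offloaded=*/nullptr);
}

// Async path: the engine has already lowered and optimized a single
// OffloadedStmt, so only backend codegen is left.
FunctionType compile_async_task(Program &program, Kernel &kernel,
                                OffloadedStmt *task) {
  return program.compile_to_backend_executable(kernel, task);
}

TLANG_NAMESPACE_END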
6 changes: 3 additions & 3 deletions tests/python/test_fuse_dense.py
@@ -3,17 +3,17 @@
template_fuse_reduction


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_fuse_dense_x2y2z():
template_fuse_dense_x2y2z(size=100 * 1024**2)


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_fuse_reduction():
template_fuse_reduction(size=10 * 1024**2)


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=ti.extension.async_mode, async_mode=True)
def test_no_fuse_sigs_mismatch():
n = 4096
x = ti.field(ti.i32, shape=(n, ))
3 changes: 2 additions & 1 deletion tests/python/test_fuse_dynamic.py
@@ -53,6 +53,7 @@ def y_to_z():
assert z[i] == x[i] + 5


@ti.archs_with([ti.cpu], async_mode=True)
@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
async_mode=True)
def test_fuse_dynamic_x2y2z():
benchmark_fuse_dynamic_x2y2z()
3 changes: 2 additions & 1 deletion tests/python/test_sfg.py
@@ -3,7 +3,8 @@
import pytest


@ti.test(require=ti.extension.async_mode, async_mode=True)
@ti.test(require=[ti.extension.async_mode, ti.extension.sparse],
async_mode=True)
def test_remove_clear_list_from_fused_serial():
x = ti.field(ti.i32)
y = ti.field(ti.i32)
Expand Down