[refactor] Split Program::compile() (taichi-dev#7847)
Issue: taichi-dev#7002

### Brief Summary

Split `Program::compile()` into `Program::compile_kernel()` + `Program::launch_kernel()`: compilation (with caching) and kernel launch become two explicit steps.
PGZXB authored Apr 21, 2023 · 1 parent 6498808 · commit 570d249
Showing 18 changed files with 132 additions and 99 deletions.
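
For reference, the split turns every call site from one fused call into an explicit compile-then-launch pair. A minimal sketch of the new pattern (based on the call sites in this diff; it assumes an existing `program`, `config`, `kernel`, and `launch_ctx`):

```cpp
// Old, fused API: compile on first call, then launch.
//   (*kernel)(config, launch_ctx);

// New, split API: compilation (served by the online/offline cache) is
// explicit, and launching consumes the cached compilation result.
const CompiledKernelData &compiled_kernel_data =
    program.compile_kernel(config, program.get_device_caps(), kernel);
program.launch_kernel(compiled_kernel_data, launch_ctx);
```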
24 changes: 20 additions & 4 deletions cpp_examples/autograd.cpp
```diff
@@ -195,10 +195,26 @@ void autograd() {
   ctx_ext.set_arg_external_array_with_shape(2, taichi::uint64(ext_c.data()), n,
                                             {n});
 
-  (*kernel_init)(config, ctx_init);
-  (*kernel_forward)(config, ctx_forward);
-  (*kernel_backward)(config, ctx_backward);
-  (*kernel_ext)(config, ctx_ext);
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_init);
+    program.launch_kernel(compiled_kernel_data, ctx_init);
+  }
+  {
+    const auto &compiled_kernel_data = program.compile_kernel(
+        config, program.get_device_caps(), *kernel_forward);
+    program.launch_kernel(compiled_kernel_data, ctx_forward);
+  }
+  {
+    const auto &compiled_kernel_data = program.compile_kernel(
+        config, program.get_device_caps(), *kernel_backward);
+    program.launch_kernel(compiled_kernel_data, ctx_backward);
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ext);
+    program.launch_kernel(compiled_kernel_data, ctx_ext);
+  }
   for (int i = 0; i < n; i++)
     std::cout << ext_a[i] << " ";
   std::cout << std::endl;
```
26 changes: 19 additions & 7 deletions cpp_examples/run_snode.cpp
```diff
@@ -134,11 +134,23 @@ void run_snode() {
   ctx_ext.set_arg_external_array_with_shape(0, taichi::uint64(ext_arr.data()),
                                             n, {n});
 
-  (*kernel_init)(config, ctx_init);
-  (*kernel_ret)(config, ctx_ret);
-  std::cout << program.fetch_result<int>(0) << std::endl;
-  (*kernel_ext)(config, ctx_ext);
-  for (int i = 0; i < n; i++)
-    std::cout << ext_arr[i] << " ";
-  std::cout << std::endl;
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_init);
+    program.launch_kernel(compiled_kernel_data, ctx_init);
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ret);
+    program.launch_kernel(compiled_kernel_data, ctx_ret);
+    std::cout << program.fetch_result<int>(0) << std::endl;
+  }
+  {
+    const auto &compiled_kernel_data =
+        program.compile_kernel(config, program.get_device_caps(), *kernel_ext);
+    program.launch_kernel(compiled_kernel_data, ctx_ext);
+    for (int i = 0; i < n; i++)
+      std::cout << ext_arr[i] << " ";
+    std::cout << std::endl;
+  }
 }
```
6 changes: 5 additions & 1 deletion python/taichi/lang/kernel_impl.py
```diff
@@ -811,7 +811,11 @@ def call_back():
             )
 
         try:
-            t_kernel(launch_ctx)
+            prog = impl.get_runtime().prog
+            # Compile kernel (& Online Cache & Offline Cache)
+            compiled_kernel_data = prog.compile_kernel(prog.config(), prog.get_device_caps(), t_kernel)
+            # Launch kernel
+            prog.launch_kernel(compiled_kernel_data, launch_ctx)
         except Exception as e:
             e = handle_exception_from_cpp(e)
             raise e from None
```
5 changes: 4 additions & 1 deletion taichi/aot/graph_data.cpp
```diff
@@ -30,7 +30,10 @@ void CompiledGraph::jit_run(
     // Compile & Run (JIT): The compilation result will be cached, so don't
     // worry that the kernels dispatched by this cgraph will be compiled
     // repeatedly.
-    (*dispatch.ti_kernel)(compile_config, launch_ctx);
+    auto *prog = dispatch.ti_kernel->program;
+    const auto &compiled_kernel_data = prog->compile_kernel(
+        compile_config, prog->get_device_caps(), *dispatch.ti_kernel);
+    prog->launch_kernel(compiled_kernel_data, launch_ctx);
   }
 }
 
```
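
The comment above depends on `KernelCompilationManager::load_or_compile` caching its results, so dispatching the same kernel repeatedly does not recompile it. A hedged illustration (it assumes repeated calls return a reference to the same cached entry, which is how `Program::compile_kernel` avoids redundant codegen; `prog`, `compile_config`, and `kernel` are stand-ins):

```cpp
// Sketch only: the second compile_kernel call should be a cache hit.
const auto &first =
    prog->compile_kernel(compile_config, prog->get_device_caps(), kernel);
const auto &second =
    prog->compile_kernel(compile_config, prog->get_device_caps(), kernel);
// Assumption: both references name the same cached CompiledKernelData.
TI_ASSERT(&first == &second);
```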
21 changes: 0 additions & 21 deletions taichi/program/kernel.cpp
```diff
@@ -40,7 +40,6 @@ Kernel::Kernel(Program &program,
   this->ir = std::move(ir);
   this->program = &program;
   is_accessor = false;
-  compiled_ = nullptr;
   ir_is_ast_ = false;  // CHI IR
 
   if (autodiff_mode == AutodiffMode::kNone) {
@@ -56,25 +55,6 @@
   }
 }
 
-void Kernel::compile(const CompileConfig &compile_config) {
-  compiled_ = program->compile(compile_config, *this);
-}
-
-void Kernel::operator()(const CompileConfig &compile_config,
-                        LaunchContextBuilder &ctx_builder) {
-  if (!compiled_) {
-    compile(compile_config);
-  }
-
-  compiled_(ctx_builder);
-
-  const auto arch = compile_config.arch;
-  if (compile_config.debug &&
-      (arch_is_cpu(arch) || arch == Arch::cuda || arch == Arch::amdgpu)) {
-    program->check_runtime_error();
-  }
-}
-
 LaunchContextBuilder Kernel::make_launch_context() {
   return LaunchContextBuilder(this);
 }
@@ -166,7 +146,6 @@ void Kernel::init(Program &program,
   this->program = &program;
 
   is_accessor = false;
-  compiled_ = nullptr;
   context = std::make_unique<FrontendContext>(program.compile_config().arch);
   ir = context->get_root();
   ir_is_ast_ = true;
```
7 changes: 0 additions & 7 deletions taichi/program/kernel.h
```diff
@@ -40,11 +40,6 @@ class TI_DLL_EXPORT Kernel : public Callable {
     return ir_is_ast_;
   }
 
-  void compile(const CompileConfig &compile_config);
-
-  void operator()(const CompileConfig &compile_config,
-                  LaunchContextBuilder &ctx_builder);
-
   LaunchContextBuilder make_launch_context();
 
   template <typename T>
@@ -75,8 +70,6 @@ class TI_DLL_EXPORT Kernel : public Callable {
 
   // True if |ir| is a frontend AST. False if it's already offloaded to CHI IR.
   bool ir_is_ast_{false};
-  // The closure that, if invoked, launches the backend kernel (shader)
-  FunctionType compiled_{nullptr};
   mutable std::string kernel_key_;
 };
 
```
24 changes: 15 additions & 9 deletions taichi/program/program.cpp
```diff
@@ -175,14 +175,24 @@ Function *Program::create_function(const FunctionKey &func_key) {
   return functions_.back().get();
 }
 
-FunctionType Program::compile(const CompileConfig &compile_config,
-                              Kernel &kernel) {
+const CompiledKernelData &Program::compile_kernel(
+    const CompileConfig &compile_config,
+    const DeviceCapabilityConfig &caps,
+    const Kernel &kernel_def) {
   auto start_t = Time::get_time();
   TI_AUTO_PROF;
-  auto ret = program_impl_->compile(compile_config, &kernel);
-  TI_ASSERT(ret);
+  auto &mgr = program_impl_->get_kernel_compilation_manager();
+  const auto &ckd = mgr.load_or_compile(compile_config, caps, kernel_def);
   total_compilation_time_ += Time::get_time() - start_t;
-  return ret;
+  return ckd;
+}
+
+void Program::launch_kernel(const CompiledKernelData &compiled_kernel_data,
+                            LaunchContextBuilder &ctx) {
+  program_impl_->get_kernel_launcher().launch_kernel(compiled_kernel_data, ctx);
+  if (compile_config().debug && arch_uses_llvm(compiled_kernel_data.arch())) {
+    program_impl_->check_runtime_error(result_buffer);
+  }
 }
 
 void Program::materialize_runtime() {
@@ -248,10 +258,6 @@ SNode *Program::get_snode_root(int tree_id) {
   return snode_trees_[tree_id]->root();
 }
 
-void Program::check_runtime_error() {
-  program_impl_->check_runtime_error(result_buffer);
-}
-
 void Program::synchronize() {
   program_impl_->synchronize();
 }
```
13 changes: 9 additions & 4 deletions taichi/program/program.h
```diff
@@ -125,11 +125,16 @@ class TI_DLL_EXPORT Program {
 
   Function *create_function(const FunctionKey &func_key);
 
-  // TODO: This function is doing two things: 1) compiling CHI IR, and 2)
-  // offloading them to each backend. We should probably separate the logic?
-  FunctionType compile(const CompileConfig &compile_config, Kernel &kernel);
+  const CompiledKernelData &compile_kernel(const CompileConfig &compile_config,
+                                           const DeviceCapabilityConfig &caps,
+                                           const Kernel &kernel_def);
 
-  void check_runtime_error();
+  void launch_kernel(const CompiledKernelData &compiled_kernel_data,
+                     LaunchContextBuilder &ctx);
+
+  DeviceCapabilityConfig get_device_caps() {
+    return program_impl_->get_device_caps();
+  }
 
   Kernel &get_snode_reader(SNode *snode);
 
```
13 changes: 0 additions & 13 deletions taichi/program/program_impl.cpp
```diff
@@ -10,19 +10,6 @@ void ProgramImpl::compile_snode_tree_types(SNodeTree *tree) {
   TI_NOT_IMPLEMENTED;
 }
 
-FunctionType ProgramImpl::compile(const CompileConfig &compile_config,
-                                  Kernel *kernel) {
-  // NOTE: Temporary implementation (blocked by cc backend)
-  // TODO(PGZXB): Final solution: compile -> load_or_compile + launch_kernel
-  auto &mgr = get_kernel_compilation_manager();
-  const auto &compiled =
-      mgr.load_or_compile(compile_config, get_device_caps(), *kernel);
-  auto &launcher = get_kernel_launcher();
-  return [&launcher, &compiled](LaunchContextBuilder &ctx_builder) {
-    launcher.launch_kernel(compiled, ctx_builder);
-  };
-}
-
 void ProgramImpl::dump_cache_data_to_disk() {
   auto &mgr = get_kernel_compilation_manager();
   mgr.clean_offline_cache(offline_cache::string_to_clean_cache_policy(
```
14 changes: 4 additions & 10 deletions taichi/program/program_impl.h
```diff
@@ -34,12 +34,6 @@ class ProgramImpl {
  public:
   explicit ProgramImpl(CompileConfig &config);
 
-  /**
-   * Codegen to specific backend
-   */
-  virtual FunctionType compile(const CompileConfig &compile_config,
-                               Kernel *kernel);
-
   /**
    * Allocate runtime buffer, e.g result_buffer or backend specific runtime
    * buffer, e.g. preallocated_device_buffer on CUDA.
@@ -173,17 +167,17 @@ class ProgramImpl {
 
   KernelLauncher &get_kernel_launcher();
 
+  virtual DeviceCapabilityConfig get_device_caps() {
+    return {};
+  }
+
  protected:
   virtual std::unique_ptr<KernelCompiler> make_kernel_compiler() = 0;
 
   virtual std::unique_ptr<KernelLauncher> make_kernel_launcher() {
     TI_NOT_IMPLEMENTED;
   }
 
-  virtual DeviceCapabilityConfig get_device_caps() {
-    return {};
-  }
-
  private:
   std::unique_ptr<KernelCompilationManager> kernel_com_mgr_;
   std::unique_ptr<KernelLauncher> kernel_launcher_;
```
24 changes: 18 additions & 6 deletions taichi/program/snode_rw_accessors_bank.cpp
```diff
@@ -41,14 +41,18 @@ void SNodeRwAccessorsBank::Accessors::write_float(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_float(snode_->num_active_indices, val);
   prog_->synchronize();
-  (*writer_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 float64 SNodeRwAccessorsBank::Accessors::read_float(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  (*reader_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_float({0});
@@ -64,7 +68,9 @@ void SNodeRwAccessorsBank::Accessors::write_int(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_int(snode_->num_active_indices, val);
   prog_->synchronize();
-  (*writer_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 // for int32 and int64
@@ -74,14 +80,18 @@ void SNodeRwAccessorsBank::Accessors::write_uint(const std::vector<int> &I,
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
   launch_ctx.set_arg_uint(snode_->num_active_indices, val);
   prog_->synchronize();
-  (*writer_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *writer_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
 }
 
 int64 SNodeRwAccessorsBank::Accessors::read_int(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  (*reader_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_int({0});
@@ -94,7 +104,9 @@ uint64 SNodeRwAccessorsBank::Accessors::read_uint(const std::vector<int> &I) {
   prog_->synchronize();
   auto launch_ctx = reader_->make_launch_context();
   set_kernel_args(I, snode_->num_active_indices, &launch_ctx);
-  (*reader_)(prog_->compile_config(), launch_ctx);
+  const auto &compiled_kernel_data = prog_->compile_kernel(
+      prog_->compile_config(), prog_->get_device_caps(), *reader_);
+  prog_->launch_kernel(compiled_kernel_data, launch_ctx);
   prog_->synchronize();
   if (arch_uses_llvm(prog_->compile_config().arch)) {
     return launch_ctx.get_struct_ret_uint({0});
```
18 changes: 12 additions & 6 deletions taichi/python/export_lang.cpp
```diff
@@ -327,6 +327,12 @@ void export_lang(py::module &m) {
       .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag)
       .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag);
 
+  py::class_<DeviceCapabilityConfig>(
+      m, "DeviceCapabilityConfig");  // NOLINT(bugprone-unused-raii)
+
+  py::class_<CompiledKernelData>(
+      m, "CompiledKernelData");  // NOLINT(bugprone-unused-raii)
+
   py::class_<Program>(m, "Program")
       .def(py::init<>())
       .def("config", &Program::compile_config,
@@ -443,7 +449,11 @@ void export_lang(py::module &m) {
             program->fill_ndarray_fast_u32(ndarray, val);
           })
       .def("get_graphics_device",
-           [](Program *program) { return program->get_graphics_device(); });
+           [](Program *program) { return program->get_graphics_device(); })
+      .def("compile_kernel", &Program::compile_kernel,
+           py::return_value_policy::reference)
+      .def("launch_kernel", &Program::launch_kernel)
+      .def("get_device_caps", &Program::get_device_caps);
 
   py::class_<AotModuleBuilder>(m, "AotModuleBuilder")
       .def("add_field", &AotModuleBuilder::add_field)
@@ -695,11 +705,7 @@ void export_lang(py::module &m) {
           [](Kernel *self) -> ASTBuilder * {
             return &self->context->builder();
           },
-          py::return_value_policy::reference)
-      .def("__call__", [](Kernel *kernel, LaunchContextBuilder &launch_ctx) {
-        py::gil_scoped_release release;
-        kernel->operator()(kernel->program->compile_config(), launch_ctx);
-      });
+          py::return_value_policy::reference);
 
   py::class_<LaunchContextBuilder>(m, "KernelLaunchContext")
       .def("set_arg_int", &LaunchContextBuilder::set_arg_int)
```
3 changes: 2 additions & 1 deletion taichi/runtime/program_impls/dx/dx_program.h
```diff
@@ -78,10 +78,11 @@ class Dx11ProgramImpl : public ProgramImpl {
     return "1" + std::string(has_buffer_ptr ? "b" : "-");
   };
 
+  DeviceCapabilityConfig get_device_caps() override;
+
  protected:
   std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
   std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
-  DeviceCapabilityConfig get_device_caps() override;
 
  private:
   std::shared_ptr<Device> device_{nullptr};
```
3 changes: 2 additions & 1 deletion taichi/runtime/program_impls/metal/metal_program.h
```diff
@@ -105,10 +105,11 @@ class MetalProgramImpl : public ProgramImpl {
     return "1" + std::string(has_buffer_ptr ? "b" : "-");
   };
 
+  DeviceCapabilityConfig get_device_caps() override;
+
  protected:
   std::unique_ptr<KernelCompiler> make_kernel_compiler() override;
   std::unique_ptr<KernelLauncher> make_kernel_launcher() override;
-  DeviceCapabilityConfig get_device_caps() override;
 
  private:
   std::unique_ptr<metal::MetalDevice> embedded_device_{nullptr};
```
(The remaining 4 changed files are not shown.)
