Skip to content

Commit

Permalink
[aot] Add a generic set of AOT structs (#3973)
Browse files Browse the repository at this point in the history
Co-authored-by: Taichi Gardener <taichigardener@gmail.com>
  • Loading branch information
k-ye and taichi-gardener authored Jan 8, 2022
1 parent 990dbaa commit d5f4951
Show file tree
Hide file tree
Showing 3 changed files with 152 additions and 12 deletions.
64 changes: 64 additions & 0 deletions taichi/aot/module_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,70 @@ struct CompiledFieldData {
column_num);
};

// Generic (backend-independent) record for one offloaded task of a compiled
// kernel. Every member below is listed in TI_IO_DEF, in declaration order.
struct CompiledOffloadedTask {
  // Task type, kept as a string.
  std::string type;
  // Task name.
  std::string name;
  // Path of the task's source.
  // Open question from the original author: do we need to inline the source
  // code instead?
  std::string source_path;
  // Defaults to 0.
  int gpu_block_size = 0;

  TI_IO_DEF(type, name, source_path, gpu_block_size);
};

// Generic description of a scalar kernel argument.
struct ScalarArg {
  // Name of the argument's data type.
  std::string dtype_name;
  // Offset of this argument inside the args buffer, in bytes. Defaults to 0.
  size_t offset_in_args_buf = 0;

  TI_IO_DEF(dtype_name, offset_in_args_buf);
};

// Generic description of an array kernel argument.
struct ArrayArg {
  // Name of the element data type.
  std::string dtype_name;
  // Defaults to 0.
  std::size_t field_dim = 0;
  // An empty |element_shape| means the element is a scalar.
  std::vector<int> element_shape;
  // Offset of the shape data inside the args buffer, in bytes.
  std::size_t shape_offset_in_args_buf = 0;
  // Binding index when targeting Vulkan/OpenGL/Metal.
  int bind_index = 0;

  TI_IO_DEF(dtype_name,
            field_dim,
            element_shape,
            shape_offset_in_args_buf,
            bind_index);
};

// Generic record for one compiled Taichi kernel: its offloaded tasks plus the
// argument/return bookkeeping. All members are listed in TI_IO_DEF.
struct CompiledTaichiKernel {
  // Offloaded tasks that make up this kernel.
  std::vector<CompiledOffloadedTask> tasks;
  // Argument and return-value counts; both default to 0.
  int args_count = 0;
  int rets_count = 0;
  // Sizes of the args and rets buffers; both default to 0.
  size_t args_buffer_size = 0;
  size_t rets_buffer_size = 0;

  // Per-argument metadata, keyed by an int id (the OpenGL converter in this
  // commit uses the source map's |arg_id| as the key).
  std::unordered_map<int, ScalarArg> scalar_args;
  std::unordered_map<int, ArrayArg> arr_args;

  TI_IO_DEF(tasks,
            args_count,
            rets_count,
            args_buffer_size,
            rets_buffer_size,
            scalar_args,
            arr_args);
};

// Top-level, backend-independent AOT module description: the compiled
// kernels, the kernel templates, the field metadata and the root buffer size.
// All members are listed in TI_IO_DEF.
struct ModuleData {
  // Compiled kernels, keyed by a string.
  std::unordered_map<std::string, CompiledTaichiKernel> kernels;
  // Compiled kernel templates, keyed by a string.
  std::unordered_map<std::string, CompiledTaichiKernel> kernel_tmpls;
  // Metadata of the compiled fields.
  std::vector<aot::CompiledFieldData> fields;

  // Explicitly zero-initialized, like every other scalar member of the
  // sibling structs: without the initializer a default-initialized
  // `ModuleData m;` would carry an indeterminate value that could then be
  // serialized.
  size_t root_buffer_size{0};

  TI_IO_DEF(kernels, kernel_tmpls, fields, root_buffer_size);
};

} // namespace aot
} // namespace lang
} // namespace taichi
99 changes: 87 additions & 12 deletions taichi/backends/opengl/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#include "taichi/backends/opengl/aot_module_builder_impl.h"

#include "taichi/aot/module_data.h"
#include "taichi/backends/opengl/opengl_utils.h"

#if !defined(TI_PLATFORM_WINDOWS)
Expand All @@ -8,16 +10,80 @@
namespace taichi {
namespace lang {
namespace opengl {
namespace {

// Constructs the OpenGL AOT module builder.
// Stores |compiled_structs| and |allow_nv_shader_extension| in the
// corresponding members, then seeds |aot_data_.root_buffer_size| from
// |compiled_structs_.root_size|.
AotModuleBuilderImpl::AotModuleBuilderImpl(
StructCompiledResult &compiled_structs,
bool allow_nv_shader_extension)
: compiled_structs_(compiled_structs),
allow_nv_shader_extension_(allow_nv_shader_extension) {
aot_data_.root_buffer_size = compiled_structs_.root_size;
}
// Translates the OpenGL-specific AOT structures (opengl::*) into the generic
// ones declared in namespace aot. The only public entry point is the static
// convert(); the per-type visit() overloads are implementation details.
class AotDataConverter {
 public:
  // Returns the generic aot::ModuleData equivalent of |in|.
  static aot::ModuleData convert(const opengl::AotData &in) {
    AotDataConverter converter{};
    return converter.visit(in);
  }

 private:
  explicit AotDataConverter() = default;

  // Converts the whole module: fields and root buffer size are copied as-is,
  // kernels and kernel templates are converted entry by entry.
  aot::ModuleData visit(const opengl::AotData &in) const {
    aot::ModuleData out{};
    for (const auto &[kernel_name, kernel] : in.kernels) {
      out.kernels[kernel_name] = visit(kernel);
    }
    for (const auto &[tmpl_name, tmpl] : in.kernel_tmpls) {
      out.kernel_tmpls[tmpl_name] = visit(tmpl);
    }
    out.fields = in.fields;
    out.root_buffer_size = in.root_buffer_size;
    return out;
  }

  // Converts one kernel: its tasks, the argument/return counters and buffer
  // sizes, and the scalar/array argument tables.
  aot::CompiledTaichiKernel visit(
      const opengl::CompiledTaichiKernel &in) const {
    aot::CompiledTaichiKernel out{};

    out.args_count = in.arg_count;
    out.rets_count = in.ret_count;
    out.args_buffer_size = in.args_buf_size;
    out.rets_buffer_size = in.ret_buf_size;

    out.tasks.reserve(in.tasks.size());
    for (const auto &task : in.tasks) {
      out.tasks.push_back(visit(task));
    }

    for (const auto &[arg_id, scalar] : in.scalar_args) {
      out.scalar_args[arg_id] = visit(scalar);
    }
    for (const auto &[arg_id, arr] : in.arr_args) {
      // The binding index is not part of the array argument itself; it is
      // looked up in the kernel's |used.arr_arg_to_bind_idx| table. at()
      // throws if |arg_id| has no entry there.
      aot::ArrayArg converted = visit(arr);
      converted.bind_index = in.used.arr_arg_to_bind_idx.at(arg_id);
      out.arr_args[arg_id] = converted;
    }
    return out;
  }

  // Converts one offloaded task. The task type is stored as the string
  // returned by offloaded_task_type_name().
  aot::CompiledOffloadedTask visit(
      const opengl::CompiledOffloadedTask &in) const {
    aot::CompiledOffloadedTask out{};
    out.name = in.name;
    out.type = offloaded_task_type_name(in.type);
    out.gpu_block_size = in.workgroup_size;
    out.source_path = in.src;
    return out;
  }

  // Converts one scalar argument.
  aot::ScalarArg visit(const opengl::ScalarArg &in) const {
    aot::ScalarArg out{};
    out.offset_in_args_buf = in.offset_in_bytes_in_args_buf;
    out.dtype_name = in.dtype_name;
    return out;
  }

  // Converts one array argument. |bind_index| is left at its default here;
  // the kernel-level visit() fills it in.
  aot::ArrayArg visit(const opengl::CompiledArrayArg &in) const {
    aot::ArrayArg out{};
    out.dtype_name = in.dtype_name;
    out.element_shape = in.element_shape;
    out.field_dim = in.field_dim;
    out.shape_offset_in_args_buf = in.shape_offset_in_bytes_in_args_buf;
    return out;
  }
};

namespace {
void write_glsl_file(const std::string &output_dir, CompiledOffloadedTask &t) {
const std::string glsl_path = fmt::format("{}/{}.glsl", output_dir, t.name);
std::ofstream fs{glsl_path};
Expand All @@ -28,28 +94,37 @@ void write_glsl_file(const std::string &output_dir, CompiledOffloadedTask &t) {

} // namespace

// Constructs the OpenGL AOT module builder.
// Stores |compiled_structs| and |allow_nv_shader_extension| in the
// corresponding members, then seeds |aot_data_.root_buffer_size| from
// |compiled_structs_.root_size|.
AotModuleBuilderImpl::AotModuleBuilderImpl(
StructCompiledResult &compiled_structs,
bool allow_nv_shader_extension)
: compiled_structs_(compiled_structs),
allow_nv_shader_extension_(allow_nv_shader_extension) {
aot_data_.root_buffer_size = compiled_structs_.root_size;
}

// Writes this module's AOT artifacts into |output_dir|.
//
// |filename| is not used on this backend; a warning is emitted when it is
// non-empty. The outputs visible here are:
//   - "{output_dir}/metadata.tcb", written from |aot_data_| with
//     write_to_binary_file();
//   - one write_glsl_file() call per task of every kernel and kernel
//     template, performed on a copy of |aot_data_|;
//   - "{output_dir}/metadata.json", the JSON serialization of that copy
//     under the key "aot_data".
//
// NOTE(review): this listing appears to interleave the pre-change
// (|new_aot_data|) and post-change (|aot_data_copy|) lines of a diff, which is
// why the copy, the loop headers and the serialize call each show up twice.
// Confirm against the actual file.
void AotModuleBuilderImpl::dump(const std::string &output_dir,
const std::string &filename) const {
TI_WARN_IF(!filename.empty(),
"Filename prefix is ignored on opengl backend.");
// TODO(#3334): Convert |aot_data_| with AotDataConverter
const std::string bin_path = fmt::format("{}/metadata.tcb", output_dir);
write_to_binary_file(aot_data_, bin_path);
// Json format doesn't support multiple line strings.
// The tasks are therefore processed on a copy, passed by non-const reference
// to write_glsl_file(); presumably that call rewrites the task in place —
// TODO confirm, its body is not fully visible here.
AotData new_aot_data = aot_data_;
for (auto &k : new_aot_data.kernels) {
AotData aot_data_copy = aot_data_;
for (auto &k : aot_data_copy.kernels) {
for (auto &t : k.second.tasks) {
write_glsl_file(output_dir, t);
}
}
for (auto &k : new_aot_data.kernel_tmpls) {
for (auto &k : aot_data_copy.kernel_tmpls) {
for (auto &t : k.second.tasks) {
write_glsl_file(output_dir, t);
}
}

// Emit the text form of the (copied) metadata.
const std::string txt_path = fmt::format("{}/metadata.json", output_dir);
TextSerializer ts;
ts.serialize_to_json("aot_data", new_aot_data);
ts.serialize_to_json("aot_data", aot_data_copy);
ts.write_to_file(txt_path);
}

Expand Down
1 change: 1 addition & 0 deletions taichi/program/aot_module.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "taichi/aot/module_data.h"
#include "taichi/backends/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"

namespace taichi {
namespace lang {
Expand Down

0 comments on commit d5f4951

Please sign in to comment.