Skip to content

Commit

Permalink
[aot] Serialize built graph, deserialize and run.
Browse files Browse the repository at this point in the history
related: taichi-dev#4786

This PR demonstrates a minimal example of serializing a built graph,
deserializing and running it.

ghstack-source-id: e369b326734b3cffaf91262ed84806044e121c3c
Pull Request resolved: taichi-dev#5016
  • Loading branch information
Ailing Zhang authored and Ailing Zhang committed May 22, 2022
1 parent fba92cf commit 5e713b7
Show file tree
Hide file tree
Showing 15 changed files with 333 additions and 169 deletions.
48 changes: 48 additions & 0 deletions taichi/aot/graph_data.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
#include "taichi/aot/graph_data.h"
#include "taichi/program/ndarray.h"
#define TI_RUNTIME_HOST
#include "taichi/program/context.h"
#undef TI_RUNTIME_HOST

namespace taichi {
namespace lang {
namespace aot {
void CompiledGraph::run(
const std::unordered_map<std::string, IValue> &args) const {
RuntimeContext ctx;
for (const auto &dispatch : dispatches) {
memset(&ctx, 0, sizeof(RuntimeContext));

TI_ASSERT(dispatch.compiled_kernel);
// Populate args metadata into RuntimeContext
const auto &symbolic_args_ = dispatch.symbolic_args;
for (int i = 0; i < symbolic_args_.size(); ++i) {
auto &symbolic_arg = symbolic_args_[i];
auto found = args.find(symbolic_arg.name);
TI_ERROR_IF(found == args.end(), "Missing runtime value for {}",
symbolic_arg.name);
const aot::IValue &ival = found->second;
if (ival.tag == aot::ArgKind::NDARRAY) {
Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);
TI_ERROR_IF(ival.tag != aot::ArgKind::NDARRAY,
"Required a ndarray for argument {}", symbolic_arg.name);
auto ndarray_elem_shape = std::vector<int>(
arr->shape.end() - symbolic_arg.element_shape.size(),
arr->shape.end());
TI_ERROR_IF(ndarray_elem_shape != symbolic_arg.element_shape,
"Mismatched shape information for argument {}",
symbolic_arg.name);
set_runtime_ctx_ndarray(&ctx, i, arr);
} else {
TI_ERROR_IF(ival.tag != aot::ArgKind::SCALAR,
"Required a scalar for argument {}", symbolic_arg.name);
ctx.set_arg(i, ival.val);
}
}

dispatch.compiled_kernel->launch(&ctx);
}
}
} // namespace aot
} // namespace lang
} // namespace taichi
10 changes: 6 additions & 4 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,14 @@
#include <unordered_map>
#include "taichi/aot/module_data.h"

template <typename T, typename G>
T taichi_union_cast_with_different_sizes(G g);

namespace taichi {
namespace lang {
class AotModuleBuilder;
class Ndarray;
struct RuntimeContext;
namespace aot {
// Currently only scalar and ndarray are supported.
enum class ArgKind { SCALAR, NDARRAY, UNKNOWN };
Expand Down Expand Up @@ -67,10 +71,6 @@ class TI_DLL_EXPORT Kernel {
* @param ctx Host context
*/
virtual void launch(RuntimeContext *ctx) = 0;

virtual void save_to_module(AotModuleBuilder *builder) {
TI_NOT_IMPLEMENTED;
}
};

struct CompiledDispatch {
Expand All @@ -84,6 +84,8 @@ struct CompiledDispatch {
struct CompiledGraph {
std::vector<CompiledDispatch> dispatches;

void run(const std::unordered_map<std::string, IValue> &args) const;

TI_IO_DEF(dispatches);
};

Expand Down
17 changes: 17 additions & 0 deletions taichi/aot/module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,22 @@ void AotModuleBuilder::load(const std::string &output_dir) {
TI_ERROR("Aot loader not supported");
}

void AotModuleBuilder::dump_graph(std::string output_dir) const {
const std::string graph_file = fmt::format("{}/graphs.tcb", output_dir);

write_to_binary_file(graphs_, graph_file);
}

void AotModuleBuilder::add_graph(const std::string &name,
const aot::CompiledGraph &graph) {
if (graphs_.count(name) != 0) {
TI_ERROR("Graph {} already exists", name);
}
// Handle adding kernels separately.
for (const auto &dispatch : graph.dispatches) {
add_compiled_kernel(dispatch.compiled_kernel);
}
graphs_[name] = graph;
}
} // namespace lang
} // namespace taichi
12 changes: 12 additions & 0 deletions taichi/aot/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "taichi/backends/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -37,6 +38,10 @@ class AotModuleBuilder {
virtual void dump(const std::string &output_dir,
const std::string &filename) const = 0;

void dump_graph(std::string output_dir) const;

void add_graph(const std::string &name, const aot::CompiledGraph &graph);

protected:
/**
* Intended to be overriden by each backend's implementation.
Expand All @@ -62,13 +67,20 @@ class AotModuleBuilder {
TI_NOT_IMPLEMENTED;
}

virtual void add_compiled_kernel(aot::Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

virtual void add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

static bool all_fields_are_dense_in_container(const SNode *container);

private:
std::unordered_map<std::string, aot::CompiledGraph> graphs_;
};

} // namespace lang
Expand Down
7 changes: 6 additions & 1 deletion taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace taichi {
namespace lang {

struct RuntimeContext;

class Graph;
namespace aot {

class TI_DLL_EXPORT Field {
Expand Down Expand Up @@ -90,11 +90,16 @@ class TI_DLL_EXPORT Module {
KernelTemplate *get_kernel_template(const std::string &name);
Field *get_field(const std::string &name);

virtual std::unique_ptr<aot::CompiledGraph> get_graph(std::string name) {
TI_NOT_IMPLEMENTED;
}

protected:
virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0;
virtual std::unique_ptr<KernelTemplate> make_new_kernel_template(
const std::string &name) = 0;
virtual std::unique_ptr<Field> make_new_field(const std::string &name) = 0;
std::unordered_map<std::string, CompiledGraph> graphs_;

private:
std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
Expand Down
9 changes: 9 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -135,6 +136,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

const std::string json_path = fmt::format("{}/metadata.json", output_dir);
converted.dump_json(json_path);

dump_graph(output_dir);
}

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Expand All @@ -147,6 +150,12 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
const auto register_params = static_cast<KernelImpl *>(kernel)->params();
ti_aot_data_.kernels.push_back(register_params.kernel_attribs);
ti_aot_data_.spirv_codes.push_back(register_params.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

void add_compiled_kernel(aot::Kernel *kernel) override;

std::string write_spv_file(const std::string &output_dir,
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;
Expand Down
18 changes: 18 additions & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <type_traits>

#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -39,6 +40,23 @@ class AotModuleImpl : public aot::Module {
}
ti_aot_data_.spirv_codes.push_back(spirv_sources_codes);
}

const std::string graph_path =
fmt::format("{}/graphs.tcb", params.module_path);
read_from_binary_file(graphs_, graph_path);
}

std::unique_ptr<aot::CompiledGraph> get_graph(std::string name) override {
TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name);
auto &compiled_graph = graphs_[name];
std::vector<aot::CompiledDispatch> dispatches;
for (auto &dispatch : compiled_graph.dispatches) {
// dispatch.compiled_kernel = get_kernel(dispatch.kernel_name);
dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args,
get_kernel(dispatch.kernel_name)});
}
aot::CompiledGraph graph{dispatches};
return std::make_unique<aot::CompiledGraph>(std::move(graph));
}

size_t get_root_size() const override {
Expand Down
23 changes: 3 additions & 20 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,14 @@
#include "taichi/backends/vulkan/aot_utils.h"
#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/codegen/spirv/kernel_utils.h"

#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/backends/vulkan/aot_module_builder_impl.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
namespace vulkan {

class VkRuntime;

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
: runtime_(runtime), params_(std::move(params)) {
}

void launch(RuntimeContext *ctx) override {
auto handle = runtime_->register_taichi_kernel(params_);
runtime_->launch_kernel(handle, ctx);
}

private:
VkRuntime *const runtime_;
const VkRuntime::RegisterParams params_;
};

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
29 changes: 29 additions & 0 deletions taichi/backends/vulkan/vulkan_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#pragma once
#include "taichi/runtime/vulkan/runtime.h"

namespace taichi {
namespace lang {
namespace vulkan {
class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
: runtime_(runtime), params_(std::move(params)) {
handle_ = runtime_->register_taichi_kernel(params_);
}

void launch(RuntimeContext *ctx) override {
runtime_->launch_kernel(handle_, ctx);
}

const VkRuntime::RegisterParams &params() {
return params_;
}

private:
VkRuntime *const runtime_;
VkRuntime::KernelHandle handle_;
const VkRuntime::RegisterParams params_;
};
} // namespace vulkan
} // namespace lang
} // namespace taichi
Loading

0 comments on commit 5e713b7

Please sign in to comment.