Skip to content

Commit

Permalink
Update on "[aot] Serialize built graph, deserialize and run."
Browse files Browse the repository at this point in the history
related: #4786

This PR demonstrates a minimal example of serializing a built graph,
deserializing and running it.

[ghstack-poisoned]
  • Loading branch information
Ailing Zhang committed May 20, 2022
2 parents 0cc69e0 + 9c1c50f commit 1c60025
Show file tree
Hide file tree
Showing 9 changed files with 60 additions and 61 deletions.
14 changes: 6 additions & 8 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class AotModuleBuilder;
class Ndarray;
namespace aot {
// Currently only scalar and ndarray are supported.
enum ArgKind { SCALAR, NDARRAY, UNKNOWN };
enum class ArgKind { SCALAR, NDARRAY, UNKNOWN };

/*
/**
* Symbolic argument used in building `Dispatch` nodes in the `Graph`.
*/
struct Arg {
Expand All @@ -25,7 +25,7 @@ struct Arg {
TI_IO_DEF(name, dtype_name, tag, element_shape);
};

/*
/**
* Runtime value used in graph execution.
*/
struct IValue {
Expand All @@ -37,7 +37,8 @@ struct IValue {
return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::NDARRAY);
}

template <typename T>
template <typename T,
typename = std::enable_if_t<!std::is_same<T, Ndarray>::value, void>>
static IValue create(T v) {
return IValue(taichi_union_cast_with_different_sizes<uint64>(v),
ArgKind::SCALAR);
Expand All @@ -47,6 +48,7 @@ struct IValue {
IValue(uint64 val, ArgKind tag) : val(val), tag(tag) {
}
};

class TI_DLL_EXPORT Kernel {
public:
// Rule of 5 to make MSVC happy
Expand All @@ -65,10 +67,6 @@ class TI_DLL_EXPORT Kernel {
* @param ctx Host context
*/
virtual void launch(RuntimeContext *ctx) = 0;

virtual void save_to_module(AotModuleBuilder *builder) {
TI_NOT_IMPLEMENTED;
}
};

struct CompiledDispatch {
Expand Down
5 changes: 4 additions & 1 deletion taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -150,7 +151,9 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
}

void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
kernel->save_to_module(this);
const auto register_params = static_cast<KernelImpl *>(kernel)->params();
ti_aot_data_.kernels.push_back(register_params.kernel_attribs);
ti_aot_data_.spirv_codes.push_back(register_params.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
Expand Down
5 changes: 0 additions & 5 deletions taichi/backends/vulkan/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,6 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
void dump(const std::string &output_dir,
const std::string &filename) const override;

// FIXME: remove me once TaichiAotData is no longer backend specific.
TaichiAotData &aot_data() {
return ti_aot_data_;
}

private:
void add_per_backend(const std::string &identifier, Kernel *kernel) override;

Expand Down
29 changes: 1 addition & 28 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,38 +10,11 @@
#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/backends/vulkan/aot_module_builder_impl.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
namespace vulkan {

class VkRuntime;

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
: runtime_(runtime), params_(std::move(params)) {
}

void launch(RuntimeContext *ctx) override {
auto handle = runtime_->register_taichi_kernel(params_);
runtime_->launch_kernel(handle, ctx);
}

void save_to_module(AotModuleBuilder *builder) override {
// This hack exists because ti_aot_data_ is vulkan specific.
// We need a generic aot::ModuleData inside AotModuleBuilder.
dynamic_cast<AotModuleBuilderImpl *>(builder)->aot_data().kernels.push_back(
params_.kernel_attribs);
dynamic_cast<AotModuleBuilderImpl *>(builder)
->aot_data()
.spirv_codes.push_back(params_.task_spirv_source_codes);
}

private:
VkRuntime *const runtime_;
const VkRuntime::RegisterParams params_;
};
struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
28 changes: 28 additions & 0 deletions taichi/backends/vulkan/vulkan_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once
#include "taichi/runtime/vulkan/runtime.h"

namespace taichi {
namespace lang {
namespace vulkan {
class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
: runtime_(runtime), params_(std::move(params)) {
}

void launch(RuntimeContext *ctx) override {
auto handle = runtime_->register_taichi_kernel(params_);
runtime_->launch_kernel(handle, ctx);
}

const VkRuntime::RegisterParams &params() {
return params_;
}

private:
VkRuntime *const runtime_;
const VkRuntime::RegisterParams params_;
};
} // namespace vulkan
} // namespace lang
} // namespace taichi
16 changes: 9 additions & 7 deletions taichi/program/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace lang {

void Dispatch::compile(
std::vector<aot::CompiledDispatch> &compiled_dispatches) {
if (compiled_kernel_)
return;
compiled_kernel_ = kernel_->compile_to_aot_kernel();
aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_,
compiled_kernel_.get()};
Expand All @@ -28,21 +30,21 @@ void Sequential::append(Node *node) {
sequence_.push_back(node);
}

void Sequential::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
Node *n = owning_graph_->create_dispatch(kernel, args);
void Sequential::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
Node *n = owning_graph_->new_dispatch_node(kernel, args);
sequence_.push_back(n);
}

Graph::Graph(std::string name) : name_(name) {
seq_ = std::make_unique<Sequential>(this);
}
Node *Graph::create_dispatch(Kernel *kernel,
const std::vector<aot::Arg> &args) {
Node *Graph::new_dispatch_node(Kernel *kernel,
const std::vector<aot::Arg> &args) {
all_nodes_.push_back(std::make_unique<Dispatch>(kernel, args));
return all_nodes_.back().get();
}

Sequential *Graph::create_sequential() {
Sequential *Graph::new_sequential_node() {
all_nodes_.push_back(std::make_unique<Sequential>(this));
return static_cast<Sequential *>(all_nodes_.back().get());
}
Expand All @@ -55,8 +57,8 @@ Sequential *Graph::seq() const {
return seq_.get();
}

void Graph::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
seq()->emplace(kernel, args);
void Graph::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
seq()->dispatch(kernel, args);
}

void Graph::run(
Expand Down
8 changes: 4 additions & 4 deletions taichi/program/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Sequential : public Node {

void append(Node *node);

void emplace(Kernel *kernel, const std::vector<aot::Arg> &args);
void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);

void compile(
std::vector<aot::CompiledDispatch> &compiled_dispatches) override;
Expand Down Expand Up @@ -86,11 +86,11 @@ class Graph {

void run(const std::unordered_map<std::string, aot::IValue> &args) const;

Node *create_dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);
Node *new_dispatch_node(Kernel *kernel, const std::vector<aot::Arg> &args);

Sequential *create_sequential();
Sequential *new_sequential_node();

void emplace(Kernel *kernel, const std::vector<aot::Arg> &args);
void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);

Sequential *seq() const;

Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -312,10 +312,10 @@ TEST(AotSaveLoad, VulkanNdarray) {
auto seq = g->seq();
auto arr_arg = aot::Arg{
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}};
seq->emplace(ker1.get(), {arr_arg});
seq->emplace(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
seq->dispatch(ker1.get(), {arr_arg});
seq->dispatch(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
g->compile();

aot_builder->add_graph(g->name(), g->compiled_graph());
Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/program/graph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ TEST(GraphTest, SimpleGraphRun) {
auto seq = g->seq();
auto arr_arg = aot::Arg{
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}};
seq->emplace(ker1.get(), {arr_arg});
seq->emplace(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
seq->dispatch(ker1.get(), {arr_arg});
seq->dispatch(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
g->compile();

auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size});
Expand Down

0 comments on commit 1c60025

Please sign in to comment.