From 9c1c50f4099a01b7cb3367a54ff0b4be9fa50740 Mon Sep 17 00:00:00 2001 From: Ailing Zhang Date: Fri, 20 May 2022 19:19:08 +0800 Subject: [PATCH] Update base for Update on "[aot] Serialize built graph, deserialize and run." related: #4786 This PR demonstrates a minimal example of serializing a built graph, deserializing and running it. [ghstack-poisoned] --- taichi/aot/graph_data.h | 10 ++++++---- taichi/backends/vulkan/aot_module_loader_impl.h | 1 + taichi/program/graph.cpp | 16 +++++++++------- taichi/program/graph.h | 8 ++++---- tests/cpp/program/graph_test.cpp | 8 ++++---- 5 files changed, 24 insertions(+), 19 deletions(-) diff --git a/taichi/aot/graph_data.h b/taichi/aot/graph_data.h index 429c3089bd100..5da03b4d89d41 100644 --- a/taichi/aot/graph_data.h +++ b/taichi/aot/graph_data.h @@ -10,9 +10,9 @@ class AotModuleBuilder; class Ndarray; namespace aot { // Currently only scalar and ndarray are supported. -enum ArgKind { SCALAR, NDARRAY, UNKNOWN }; +enum class ArgKind { SCALAR, NDARRAY, UNKNOWN }; -/* +/** * Symbolic argument used in building `Dispatch` nodes in the `Graph`. */ struct Arg { @@ -25,7 +25,7 @@ struct Arg { TI_IO_DEF(name, dtype_name, tag, element_shape); }; -/* +/** * Runtime value used in graph execution. */ struct IValue { @@ -37,7 +37,8 @@ struct IValue { return IValue(reinterpret_cast(&ndarray), ArgKind::NDARRAY); } - template + template ::value, void>> static IValue create(T v) { return IValue(taichi_union_cast_with_different_sizes(v), ArgKind::SCALAR); @@ -47,6 +48,7 @@ struct IValue { IValue(uint64 val, ArgKind tag) : val(val), tag(tag) { } }; + class TI_DLL_EXPORT Kernel { public: // Rule of 5 to make MSVC happy diff --git a/taichi/backends/vulkan/aot_module_loader_impl.h b/taichi/backends/vulkan/aot_module_loader_impl.h index f1e2cc7d00925..23990b56a5c68 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.h +++ b/taichi/backends/vulkan/aot_module_loader_impl.h @@ -31,6 +31,7 @@ class KernelImpl : public aot::Kernel { VkRuntime *const runtime_; const VkRuntime::RegisterParams params_; }; + struct TI_DLL_EXPORT AotModuleParams { std::string module_path; VkRuntime *runtime{nullptr}; diff --git a/taichi/program/graph.cpp b/taichi/program/graph.cpp index 690cdef026912..3fa1c703c7064 100644 --- a/taichi/program/graph.cpp +++ b/taichi/program/graph.cpp @@ -10,6 +10,8 @@ namespace lang { void Dispatch::compile( std::vector &compiled_dispatches) { + if (compiled_kernel_) + return; compiled_kernel_ = kernel_->compile_to_aot_kernel(); aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, compiled_kernel_.get()}; @@ -28,21 +30,21 @@ void Sequential::append(Node *node) { sequence_.push_back(node); } -void Sequential::emplace(Kernel *kernel, const std::vector &args) { - Node *n = owning_graph_->create_dispatch(kernel, args); +void Sequential::dispatch(Kernel *kernel, const std::vector &args) { + Node *n = owning_graph_->new_dispatch_node(kernel, args); sequence_.push_back(n); } Graph::Graph(std::string name) : name_(name) { seq_ = std::make_unique(this); } -Node *Graph::create_dispatch(Kernel *kernel, - const std::vector &args) { +Node *Graph::new_dispatch_node(Kernel *kernel, + const std::vector &args) { all_nodes_.push_back(std::make_unique(kernel, args)); return all_nodes_.back().get(); } -Sequential *Graph::create_sequential() { +Sequential *Graph::new_sequential_node() { all_nodes_.push_back(std::make_unique(this)); return static_cast(all_nodes_.back().get()); } @@ -55,8 +57,8 @@ Sequential *Graph::seq() const { return seq_.get(); } -void Graph::emplace(Kernel *kernel, const std::vector &args) { - seq()->emplace(kernel, args); +void Graph::dispatch(Kernel *kernel, const std::vector &args) { + seq()->dispatch(kernel, args); } void Graph::run( diff --git a/taichi/program/graph.h b/taichi/program/graph.h index c97c023ae321c..1999b9b015e74 100644 --- a/taichi/program/graph.h +++ b/taichi/program/graph.h @@ -50,7 +50,7 @@ class Sequential : public Node { void append(Node *node); - void emplace(Kernel *kernel, const std::vector &args); + void dispatch(Kernel *kernel, const std::vector &args); void compile( std::vector &compiled_dispatches) override; @@ -86,11 +86,11 @@ class Graph { void run(const std::unordered_map &args) const; - Node *create_dispatch(Kernel *kernel, const std::vector &args); + Node *new_dispatch_node(Kernel *kernel, const std::vector &args); - Sequential *create_sequential(); + Sequential *new_sequential_node(); - void emplace(Kernel *kernel, const std::vector &args); + void dispatch(Kernel *kernel, const std::vector &args); Sequential *seq() const; diff --git a/tests/cpp/program/graph_test.cpp b/tests/cpp/program/graph_test.cpp index 5d7819714e6e6..2ac688f63cad8 100644 --- a/tests/cpp/program/graph_test.cpp +++ b/tests/cpp/program/graph_test.cpp @@ -31,10 +31,10 @@ TEST(GraphTest, SimpleGraphRun) { auto seq = g->seq(); auto arr_arg = aot::Arg{ "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}}; - seq->emplace(ker1.get(), {arr_arg}); - seq->emplace(ker2.get(), - {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), - aot::ArgKind::SCALAR}}); + seq->dispatch(ker1.get(), {arr_arg}); + seq->dispatch(ker2.get(), + {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), + aot::ArgKind::SCALAR}}); g->compile(); auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size});