diff --git a/taichi/aot/graph_data.h b/taichi/aot/graph_data.h new file mode 100644 index 00000000000000..429c3089bd100d --- /dev/null +++ b/taichi/aot/graph_data.h @@ -0,0 +1,90 @@ +#pragma once +#include +#include +#include +#include "taichi/aot/module_data.h" + +namespace taichi { +namespace lang { +class AotModuleBuilder; +class Ndarray; +namespace aot { +// Currently only scalar and ndarray are supported. +enum ArgKind { SCALAR, NDARRAY, UNKNOWN }; + +/* + * Symbolic argument used in building `Dispatch` nodes in the `Graph`. + */ +struct Arg { + std::string name; + // TODO: real element dtype = dtype + element_shape + std::string dtype_name; + ArgKind tag; + std::vector element_shape; + + TI_IO_DEF(name, dtype_name, tag, element_shape); +}; + +/* + * Runtime value used in graph execution. + */ +struct IValue { + public: + uint64 val; + ArgKind tag; + + static IValue create(const Ndarray &ndarray) { + return IValue(reinterpret_cast(&ndarray), ArgKind::NDARRAY); + } + + template + static IValue create(T v) { + return IValue(taichi_union_cast_with_different_sizes(v), + ArgKind::SCALAR); + } + + private: + IValue(uint64 val, ArgKind tag) : val(val), tag(tag) { + } +}; +class TI_DLL_EXPORT Kernel { + public: + // Rule of 5 to make MSVC happy + Kernel() = default; + virtual ~Kernel() = default; + Kernel(const Kernel &) = delete; + Kernel &operator=(const Kernel &) = delete; + Kernel(Kernel &&) = default; + Kernel &operator=(Kernel &&) = default; + + /** + * @brief Launches the kernel to the device + * + * This does not manage the device to host synchronization. + * + * @param ctx Host context + */ + virtual void launch(RuntimeContext *ctx) = 0; + + virtual void save_to_module(AotModuleBuilder *builder) { + TI_NOT_IMPLEMENTED; + } +}; + +struct CompiledDispatch { + std::string kernel_name; + std::vector symbolic_args; + Kernel *compiled_kernel{nullptr}; + + TI_IO_DEF(kernel_name, symbolic_args); +}; + +struct CompiledGraph { + std::vector dispatches; + + TI_IO_DEF(dispatches); +}; + +} // namespace aot +} // namespace lang +} // namespace taichi diff --git a/taichi/aot/module_loader.h b/taichi/aot/module_loader.h index 0551152cfae4ee..b63fd9e6762772 100644 --- a/taichi/aot/module_loader.h +++ b/taichi/aot/module_loader.h @@ -10,7 +10,7 @@ #include "taichi/aot/module_data.h" #include "taichi/backends/device.h" #include "taichi/ir/snode.h" -#include "taichi/aot/module_data.h" +#include "taichi/aot/graph_data.h" namespace taichi { namespace lang { @@ -30,26 +30,6 @@ class TI_DLL_EXPORT Field { Field &operator=(Field &&) = default; }; -class TI_DLL_EXPORT Kernel { - public: - // Rule of 5 to make MSVC happy - Kernel() = default; - virtual ~Kernel() = default; - Kernel(const Kernel &) = delete; - Kernel &operator=(const Kernel &) = delete; - Kernel(Kernel &&) = default; - Kernel &operator=(Kernel &&) = default; - - /** - * @brief Launches the kernel to the device - * - * This does not manage the device to host synchronization. - * - * @param ctx Host context - */ - virtual void launch(RuntimeContext *ctx) = 0; -}; - class TI_DLL_EXPORT KernelTemplateArg { public: using ArgUnion = std::variant; diff --git a/taichi/backends/vulkan/aot_module_loader_impl.cpp b/taichi/backends/vulkan/aot_module_loader_impl.cpp index 5f87eb4fc9ca8c..149b172ac4c2c9 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.cpp +++ b/taichi/backends/vulkan/aot_module_loader_impl.cpp @@ -9,9 +9,6 @@ namespace taichi { namespace lang { namespace vulkan { namespace { - -using KernelHandle = VkRuntime::KernelHandle; - class FieldImpl : public aot::Field { public: explicit FieldImpl(VkRuntime *runtime, const aot::CompiledFieldData &field) @@ -23,21 +20,6 @@ class FieldImpl : public aot::Field { aot::CompiledFieldData field_; }; -class KernelImpl : public aot::Kernel { - public: - explicit KernelImpl(VkRuntime *runtime, KernelHandle handle) - : runtime_(runtime), handle_(handle) { - } - - void launch(RuntimeContext *ctx) override { - runtime_->launch_kernel(handle_, ctx); - } - - private: - VkRuntime *const runtime_; - const KernelHandle handle_; -}; - class AotModuleImpl : public aot::Module { public: explicit AotModuleImpl(const AotModuleParams ¶ms) @@ -109,8 +91,7 @@ class AotModuleImpl : public aot::Module { TI_DEBUG("Failed to load kernel {}", name); return nullptr; } - auto handle = runtime_->register_taichi_kernel(kparams); - return std::make_unique(runtime_, handle); + return std::make_unique(runtime_, std::move(kparams)); } std::unique_ptr make_new_kernel_template( diff --git a/taichi/backends/vulkan/aot_module_loader_impl.h b/taichi/backends/vulkan/aot_module_loader_impl.h index b188281cb749d1..f1e2cc7d009256 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.h +++ b/taichi/backends/vulkan/aot_module_loader_impl.h @@ -16,6 +16,21 @@ namespace vulkan { class VkRuntime; +class KernelImpl : public aot::Kernel { + public: + explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &¶ms) + : runtime_(runtime), params_(std::move(params)) { + } + + void launch(RuntimeContext *ctx) override { + auto handle = runtime_->register_taichi_kernel(params_); + runtime_->launch_kernel(handle, ctx); + } + + private: + VkRuntime *const runtime_; + const VkRuntime::RegisterParams params_; +}; struct TI_DLL_EXPORT AotModuleParams { std::string module_path; VkRuntime *runtime{nullptr}; diff --git a/taichi/backends/vulkan/vulkan_program.cpp b/taichi/backends/vulkan/vulkan_program.cpp index b2d0c6a8d80a3c..cddb446ae469f5 100644 --- a/taichi/backends/vulkan/vulkan_program.cpp +++ b/taichi/backends/vulkan/vulkan_program.cpp @@ -2,6 +2,7 @@ #include "taichi/backends/vulkan/aot_module_builder_impl.h" #include "taichi/backends/vulkan/snode_tree_manager.h" +#include "taichi/backends/vulkan/aot_module_loader_impl.h" #if !defined(ANDROID) && !defined(TI_EMSCRIPTENED) #include "GLFW/glfw3.h" @@ -183,6 +184,16 @@ DeviceAllocation VulkanProgramImpl::allocate_memory_ndarray( /*export_sharing=*/false}); } +std::unique_ptr VulkanProgramImpl::make_aot_kernel( + Kernel &kernel) { + spirv::lower(&kernel); + std::vector compiled_structs; + VkRuntime::RegisterParams kparams = + run_codegen(&kernel, get_compute_device(), compiled_structs); + return std::make_unique(vulkan_runtime_.get(), + std::move(kparams)); +} + VulkanProgramImpl::~VulkanProgramImpl() { vulkan_runtime_.reset(); embedded_device_.reset(); diff --git a/taichi/backends/vulkan/vulkan_program.h b/taichi/backends/vulkan/vulkan_program.h index 8285fdf7e89470..b8d7a820fb7be0 100644 --- a/taichi/backends/vulkan/vulkan_program.h +++ b/taichi/backends/vulkan/vulkan_program.h @@ -82,6 +82,8 @@ class VulkanProgramImpl : public ProgramImpl { return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id); } + std::unique_ptr make_aot_kernel(Kernel &kernel) override; + ~VulkanProgramImpl(); private: diff --git a/taichi/program/graph.cpp b/taichi/program/graph.cpp new file mode 100644 index 00000000000000..e527c93ee80a13 --- /dev/null +++ b/taichi/program/graph.cpp @@ -0,0 +1,95 @@ +#include "taichi/program/graph.h" +#include "taichi/program/kernel.h" +#include "taichi/aot/module_builder.h" +#include "spdlog/fmt/fmt.h" + +#include + +namespace taichi { +namespace lang { + +void Dispatch::compile( + std::vector &compiled_dispatches) { + compiled_kernel_ = kernel_->compile_to_aot_kernel(); + aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, + compiled_kernel_.get()}; + compiled_dispatches.push_back(std::move(dispatch)); +} + +void Sequential::compile( + std::vector &compiled_dispatches) { + // In the future we can do more across-kernel optimization here. + for (Node *n : sequence_) { + n->compile(compiled_dispatches); + } +} + +void Sequential::append(Node *node) { + sequence_.push_back(node); +} + +void Sequential::emplace(Kernel *kernel, const std::vector &args) { + Node *n = owning_graph_->create_dispatch(kernel, args); + sequence_.push_back(n); +} + +Graph::Graph(std::string name) : name_(name) { + seq_ = std::make_unique(this); +} +Node *Graph::create_dispatch(Kernel *kernel, + const std::vector &args) { + all_nodes_.push_back(std::make_unique(kernel, args)); + return all_nodes_.back().get(); +} + +Sequential *Graph::create_sequential() { + all_nodes_.push_back(std::make_unique(this)); + return static_cast(all_nodes_.back().get()); +} + +void Graph::compile() { + seq()->compile(compiled_graph_.dispatches); +} + +Sequential *Graph::seq() const { + return seq_.get(); +} + +void Graph::emplace(Kernel *kernel, const std::vector &args) { + seq()->emplace(kernel, args); +} + +void Graph::run( + const std::unordered_map &args) const { + RuntimeContext ctx; + for (const auto &dispatch : compiled_graph_.dispatches) { + memset(&ctx, 0, sizeof(RuntimeContext)); + + TI_ASSERT(dispatch.compiled_kernel); + // Populate args metadata into RuntimeContext + const auto &symbolic_args_ = dispatch.symbolic_args; + for (int i = 0; i < symbolic_args_.size(); ++i) { + auto &symbolic_arg = symbolic_args_[i]; + auto found = args.find(symbolic_arg.name); + TI_ERROR_IF(found == args.end(), "Missing runtime value for {}", + symbolic_arg.name); + const aot::IValue &ival = found->second; + if (ival.tag == aot::ArgKind::NDARRAY) { + Ndarray *arr = reinterpret_cast(ival.val); + TI_ERROR_IF((symbolic_arg.tag != ival.tag) || + (symbolic_arg.element_shape != arr->shape), + "Mismatched shape information for argument {}", + symbolic_arg.name); + set_runtime_ctx_ndarray(&ctx, i, arr); + } else { + TI_ERROR_IF(symbolic_arg.tag != aot::ArgKind::SCALAR, + "Required a scalar for argument {}", symbolic_arg.name); + ctx.set_arg(i, ival.val); + } + } + + dispatch.compiled_kernel->launch(&ctx); + } +} +} // namespace lang +} // namespace taichi diff --git a/taichi/program/graph.h b/taichi/program/graph.h new file mode 100644 index 00000000000000..c97c023ae321c4 --- /dev/null +++ b/taichi/program/graph.h @@ -0,0 +1,113 @@ +#pragma once + +#include +#include +#include + +#include "taichi/program/ndarray.h" +#include "taichi/program/program.h" +#include "taichi/ir/type.h" +#include "taichi/aot/graph_data.h" +#include "taichi/aot/module_builder.h" + +namespace taichi { +namespace lang { +class Kernel; +class Graph; + +class Node { + public: + Node() = default; + virtual ~Node() = default; + Node(const Node &) = delete; + Node &operator=(const Node &) = delete; + Node(Node &&) = default; + Node &operator=(Node &&) = default; + + virtual void compile( + std::vector &compiled_dispatches) = 0; +}; +class Dispatch : public Node { + public: + explicit Dispatch(Kernel *kernel, const std::vector &args) + : kernel_(kernel), symbolic_args_(args) { + } + + void compile( + std::vector &compiled_dispatches) override; + + private: + mutable bool serialized_{false}; + Kernel *kernel_{nullptr}; + std::unique_ptr compiled_kernel_{nullptr}; + std::vector symbolic_args_; +}; + +class Sequential : public Node { + public: + explicit Sequential(Graph *graph) : owning_graph_(graph) { + } + + void append(Node *node); + + void emplace(Kernel *kernel, const std::vector &args); + + void compile( + std::vector &compiled_dispatches) override; + + private: + std::vector sequence_; + Graph *owning_graph_{nullptr}; +}; + +/* + * Graph class works as both builder and runner. + * + * Two typical workflows using Graph: + * - build graph -> compile -> run + * - build graph -> compile -> serialize -> deserialize -> run + * + * Thus Graph can be constructed in two ways, either as an empty object + * or from an `aot::CompiledGraph` loaded from aot module. + * + * Currently Graph only supports sequential launches without returning value + * to host. + */ +class Graph { + public: + explicit Graph(std::string name); + + explicit Graph(std::string name, const aot::CompiledGraph &compiled) + : name_(name), compiled_graph_(compiled) { + } + + // TODO: compile() can take in Arch argument + void compile(); + + void run(const std::unordered_map &args) const; + + Node *create_dispatch(Kernel *kernel, const std::vector &args); + + Sequential *create_sequential(); + + void emplace(Kernel *kernel, const std::vector &args); + + Sequential *seq() const; + + aot::CompiledGraph compiled_graph() const { + return compiled_graph_; + } + + std::string name() const { + return name_; + } + + private: + std::string name_; + std::unique_ptr seq_{nullptr}; + std::vector> all_nodes_; + aot::CompiledGraph compiled_graph_; +}; + +} // namespace lang +} // namespace taichi diff --git a/taichi/program/kernel.cpp b/taichi/program/kernel.cpp index 3965bd9b275661..19f9ee8370ec67 100644 --- a/taichi/program/kernel.cpp +++ b/taichi/program/kernel.cpp @@ -64,6 +64,10 @@ void Kernel::compile() { compiled_ = program->compile(*this); } +std::unique_ptr Kernel::compile_to_aot_kernel() { + return program->make_aot_kernel(*this); +} + void Kernel::lower(bool to_executable) { TI_ASSERT(!lowered_); TI_ASSERT(supports_lowering(arch)); diff --git a/taichi/program/kernel.h b/taichi/program/kernel.h index d29aac4399f4a7..0d98252b8c4994 100644 --- a/taichi/program/kernel.h +++ b/taichi/program/kernel.h @@ -6,6 +6,7 @@ #include "taichi/backends/arch.h" #include "taichi/program/callable.h" #include "taichi/program/ndarray.h" +#include "taichi/aot/graph_data.h" TLANG_NAMESPACE_BEGIN @@ -86,6 +87,7 @@ class TI_DLL_EXPORT Kernel : public Callable { void compile(); + std::unique_ptr compile_to_aot_kernel(); /** * Lowers |ir| to CHI IR level * diff --git a/taichi/program/ndarray.cpp b/taichi/program/ndarray.cpp index 46209a56abfc83..ae19dde3ea6238 100644 --- a/taichi/program/ndarray.cpp +++ b/taichi/program/ndarray.cpp @@ -83,10 +83,10 @@ void Ndarray::write_float(const std::vector &i, float64 val) { rw_accessors_bank_->get(this).write_float(i, val); } -void set_runtime_ctx_ndarray(RuntimeContext &ctx, +void set_runtime_ctx_ndarray(RuntimeContext *ctx, int arg_id, - Ndarray &ndarray) { - ctx.set_arg_devalloc(arg_id, ndarray.ndarray_alloc_, ndarray.shape); + Ndarray *ndarray) { + ctx->set_arg_devalloc(arg_id, ndarray->ndarray_alloc_, ndarray->shape); } } // namespace lang diff --git a/taichi/program/ndarray.h b/taichi/program/ndarray.h index 66ff63f3001c9c..31181fc55705a6 100644 --- a/taichi/program/ndarray.h +++ b/taichi/program/ndarray.h @@ -60,7 +60,6 @@ class Ndarray { // TODO: move this as a method inside RuntimeContext once Ndarray is decoupled // with Program -void set_runtime_ctx_ndarray(RuntimeContext &ctx, int arg_id, Ndarray &ndarray); - +void set_runtime_ctx_ndarray(RuntimeContext *ctx, int arg_id, Ndarray *ndarray); } // namespace lang } // namespace taichi diff --git a/taichi/program/program.h b/taichi/program/program.h index 06f00f4cbe761b..4ff8227b483a1a 100644 --- a/taichi/program/program.h +++ b/taichi/program/program.h @@ -199,6 +199,10 @@ class TI_DLL_EXPORT Program { // future. FunctionType compile(Kernel &kernel, OffloadedStmt *offloaded = nullptr); + std::unique_ptr make_aot_kernel(Kernel &kernel) { + return program_impl_->make_aot_kernel(kernel); + } + void check_runtime_error(); Kernel &get_snode_reader(SNode *snode); diff --git a/taichi/program/program_impl.h b/taichi/program/program_impl.h index c168bf1246d2aa..f0ae8d243c69a1 100644 --- a/taichi/program/program_impl.h +++ b/taichi/program/program_impl.h @@ -8,6 +8,7 @@ #include "taichi/program/snode_expr_utils.h" #include "taichi/program/kernel_profiler.h" #include "taichi/backends/device.h" +#include "taichi/aot/graph_data.h" namespace taichi { namespace lang { @@ -66,6 +67,13 @@ class ProgramImpl { */ virtual std::unique_ptr make_aot_module_builder() = 0; + /** + * Compile a taichi::lang::Kernel to taichi::lang::aot::Kernel. + */ + virtual std::unique_ptr make_aot_kernel(Kernel &kernel) { + TI_NOT_IMPLEMENTED; + } + /** * Dump Offline-cache data to disk */ diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index 3a095988128ecd..b64b49b2350fbc 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -271,8 +271,7 @@ TEST(AotSaveLoad, VulkanNdarray) { DeviceAllocation devalloc_arr_ = embedded_device->device()->allocate_memory(alloc_params); Ndarray arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}); - taichi::lang::set_runtime_ctx_ndarray(host_ctx, 0, arr); - + taichi::lang::set_runtime_ctx_ndarray(&host_ctx, 0, &arr); int src[size] = {0}; src[0] = 2; src[2] = 40; diff --git a/tests/cpp/program/graph_test.cpp b/tests/cpp/program/graph_test.cpp new file mode 100644 index 00000000000000..918f13d6a185c4 --- /dev/null +++ b/tests/cpp/program/graph_test.cpp @@ -0,0 +1,52 @@ +#include "gtest/gtest.h" +#include "taichi/ir/ir_builder.h" +#include "taichi/ir/statements.h" +#include "taichi/inc/constants.h" +#include "taichi/program/program.h" +#include "tests/cpp/program/test_program.h" +#include "taichi/program/graph.h" +#include "tests/cpp/ir/ndarray_kernel.h" +#ifdef TI_WITH_VULKAN +#include "taichi/backends/vulkan/vulkan_loader.h" +#endif + +using namespace taichi; +using namespace lang; +#ifdef TI_WITH_VULKAN +TEST(GraphTest, SimpleGraphRun) { + // Otherwise will segfault on macOS VM, + // where Vulkan is installed but no devices are present + if (!vulkan::is_vulkan_api_available()) { + return; + } + TestProgram test_prog; + test_prog.setup(Arch::vulkan); + + const int size = 10; + + auto ker1 = setup_kernel1(test_prog.prog()); + auto ker2 = setup_kernel2(test_prog.prog()); + + auto g = std::make_unique("test"); + auto seq = g->seq(); + auto arr_arg = aot::Arg{ + "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {size}}; + seq->emplace(ker1.get(), {arr_arg}); + seq->emplace(ker2.get(), + {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), + aot::ArgKind::SCALAR}}); + g->compile(); + + auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size}); + array.write_int({0}, 2); + array.write_int({2}, 40); + std::unordered_map args; + args.insert({"arr", aot::IValue::create(array)}); + args.insert({"x", aot::IValue::create(2)}); + + g->run(args); + EXPECT_EQ(array.read_int({0}), 2); + EXPECT_EQ(array.read_int({1}), 2); + EXPECT_EQ(array.read_int({2}), 42); +} +#endif