diff --git a/taichi/aot/graph_data.cpp b/taichi/aot/graph_data.cpp new file mode 100644 index 0000000000000..6b988356f6831 --- /dev/null +++ b/taichi/aot/graph_data.cpp @@ -0,0 +1,48 @@ +#include "taichi/aot/graph_data.h" +#include "taichi/program/ndarray.h" +#define TI_RUNTIME_HOST +#include "taichi/program/context.h" +#undef TI_RUNTIME_HOST + +namespace taichi { +namespace lang { +namespace aot { +void CompiledGraph::run( + const std::unordered_map &args) const { + RuntimeContext ctx; + for (const auto &dispatch : dispatches) { + memset(&ctx, 0, sizeof(RuntimeContext)); + + TI_ASSERT(dispatch.compiled_kernel); + // Populate args metadata into RuntimeContext + const auto &symbolic_args_ = dispatch.symbolic_args; + for (int i = 0; i < symbolic_args_.size(); ++i) { + auto &symbolic_arg = symbolic_args_[i]; + auto found = args.find(symbolic_arg.name); + TI_ERROR_IF(found == args.end(), "Missing runtime value for {}", + symbolic_arg.name); + const aot::IValue &ival = found->second; + if (ival.tag == aot::ArgKind::NDARRAY) { + Ndarray *arr = reinterpret_cast(ival.val); + TI_ERROR_IF(ival.tag != aot::ArgKind::NDARRAY, + "Required a ndarray for argument {}", symbolic_arg.name); + auto ndarray_elem_shape = std::vector( + arr->shape.end() - symbolic_arg.element_shape.size(), + arr->shape.end()); + TI_ERROR_IF(ndarray_elem_shape != symbolic_arg.element_shape, + "Mismatched shape information for argument {}", + symbolic_arg.name); + set_runtime_ctx_ndarray(&ctx, i, arr); + } else { + TI_ERROR_IF(ival.tag != aot::ArgKind::SCALAR, + "Required a scalar for argument {}", symbolic_arg.name); + ctx.set_arg(i, ival.val); + } + } + + dispatch.compiled_kernel->launch(&ctx); + } +} +} // namespace aot +} // namespace lang +} // namespace taichi diff --git a/taichi/aot/graph_data.h b/taichi/aot/graph_data.h index 5da03b4d89d41..12e96062f59d6 100644 --- a/taichi/aot/graph_data.h +++ b/taichi/aot/graph_data.h @@ -4,10 +4,14 @@ #include #include "taichi/aot/module_data.h" +template +T taichi_union_cast_with_different_sizes(G g); + namespace taichi { namespace lang { class AotModuleBuilder; class Ndarray; +struct RuntimeContext; namespace aot { // Currently only scalar and ndarray are supported. enum class ArgKind { SCALAR, NDARRAY, UNKNOWN }; @@ -67,10 +71,6 @@ class TI_DLL_EXPORT Kernel { * @param ctx Host context */ virtual void launch(RuntimeContext *ctx) = 0; - - virtual void save_to_module(AotModuleBuilder *builder) { - TI_NOT_IMPLEMENTED; - } }; struct CompiledDispatch { @@ -84,6 +84,8 @@ struct CompiledDispatch { struct CompiledGraph { std::vector dispatches; + void run(const std::unordered_map &args) const; + TI_IO_DEF(dispatches); }; diff --git a/taichi/aot/module_builder.cpp b/taichi/aot/module_builder.cpp index b194d2ee384c6..d5f7668a0a430 100644 --- a/taichi/aot/module_builder.cpp +++ b/taichi/aot/module_builder.cpp @@ -52,5 +52,21 @@ void AotModuleBuilder::load(const std::string &output_dir) { TI_ERROR("Aot loader not supported"); } +void AotModuleBuilder::dump_graph(std::string output_dir) const { + const std::string graph_file = fmt::format("{}/graphs.tcb", output_dir); + write_to_binary_file(graphs_, graph_file); +} + +void AotModuleBuilder::add_graph(const std::string &name, + const aot::CompiledGraph &graph) { + if (graphs_.count(name) != 0) { + TI_ERROR("Graph {} already exists", name); + } + // Handle adding kernels separately. + for (const auto &dispatch : graph.dispatches) { + add_compiled_kernel(dispatch.compiled_kernel); + } + graphs_[name] = graph; +} } // namespace lang } // namespace taichi diff --git a/taichi/aot/module_builder.h b/taichi/aot/module_builder.h index 29c4869d509b3..d891e05a1f63f 100644 --- a/taichi/aot/module_builder.h +++ b/taichi/aot/module_builder.h @@ -7,6 +7,7 @@ #include "taichi/backends/device.h" #include "taichi/ir/snode.h" #include "taichi/aot/module_data.h" +#include "taichi/aot/graph_data.h" namespace taichi { namespace lang { @@ -37,6 +38,10 @@ class AotModuleBuilder { virtual void dump(const std::string &output_dir, const std::string &filename) const = 0; + void dump_graph(std::string output_dir) const; + + void add_graph(const std::string &name, const aot::CompiledGraph &graph); + protected: /** * Intended to be overriden by each backend's implementation. @@ -62,6 +67,10 @@ class AotModuleBuilder { TI_NOT_IMPLEMENTED; } + virtual void add_compiled_kernel(aot::Kernel *kernel) { + TI_NOT_IMPLEMENTED; + } + virtual void add_per_backend_tmpl(const std::string &identifier, const std::string &key, Kernel *kernel) { @@ -69,6 +78,9 @@ class AotModuleBuilder { } static bool all_fields_are_dense_in_container(const SNode *container); + + private: + std::unordered_map graphs_; }; } // namespace lang diff --git a/taichi/aot/module_loader.h b/taichi/aot/module_loader.h index b63fd9e676277..caf4c857ce156 100644 --- a/taichi/aot/module_loader.h +++ b/taichi/aot/module_loader.h @@ -16,7 +16,7 @@ namespace taichi { namespace lang { struct RuntimeContext; - +class Graph; namespace aot { class TI_DLL_EXPORT Field { @@ -90,11 +90,16 @@ class TI_DLL_EXPORT Module { KernelTemplate *get_kernel_template(const std::string &name); Field *get_field(const std::string &name); + virtual std::unique_ptr get_graph(std::string name) { + TI_NOT_IMPLEMENTED; + } + protected: virtual std::unique_ptr make_new_kernel(const std::string &name) = 0; virtual std::unique_ptr make_new_kernel_template( const std::string &name) = 0; virtual std::unique_ptr make_new_field(const std::string &name) = 0; + std::unordered_map graphs_; private: std::unordered_map> loaded_kernels_; diff --git a/taichi/backends/vulkan/aot_module_builder_impl.cpp b/taichi/backends/vulkan/aot_module_builder_impl.cpp index 60cc5f7aa266f..ed03800098bd1 100644 --- a/taichi/backends/vulkan/aot_module_builder_impl.cpp +++ b/taichi/backends/vulkan/aot_module_builder_impl.cpp @@ -5,6 +5,7 @@ #include "taichi/aot/module_data.h" #include "taichi/codegen/spirv/spirv_codegen.h" +#include "taichi/backends/vulkan/vulkan_graph_data.h" namespace taichi { namespace lang { @@ -135,6 +136,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir, const std::string json_path = fmt::format("{}/metadata.json", output_dir); converted.dump_json(json_path); + + dump_graph(output_dir); } void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, @@ -147,6 +150,12 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier, ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes); } +void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) { + const auto register_params = static_cast(kernel)->params(); + ti_aot_data_.kernels.push_back(register_params.kernel_attribs); + ti_aot_data_.spirv_codes.push_back(register_params.task_spirv_source_codes); +} + void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier, const SNode *rep_snode, bool is_scalar, diff --git a/taichi/backends/vulkan/aot_module_builder_impl.h b/taichi/backends/vulkan/aot_module_builder_impl.h index 0accfcc203343..40dc4157c06f4 100644 --- a/taichi/backends/vulkan/aot_module_builder_impl.h +++ b/taichi/backends/vulkan/aot_module_builder_impl.h @@ -36,6 +36,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder { const std::string &key, Kernel *kernel) override; + void add_compiled_kernel(aot::Kernel *kernel) override; + std::string write_spv_file(const std::string &output_dir, const TaskAttributes &k, const std::vector &source_code) const; diff --git a/taichi/backends/vulkan/aot_module_loader_impl.cpp b/taichi/backends/vulkan/aot_module_loader_impl.cpp index 149b172ac4c2c..4ea34de89fc7c 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.cpp +++ b/taichi/backends/vulkan/aot_module_loader_impl.cpp @@ -4,6 +4,7 @@ #include #include "taichi/runtime/vulkan/runtime.h" +#include "taichi/aot/graph_data.h" namespace taichi { namespace lang { @@ -39,6 +40,21 @@ class AotModuleImpl : public aot::Module { } ti_aot_data_.spirv_codes.push_back(spirv_sources_codes); } + + const std::string graph_path = + fmt::format("{}/graphs.tcb", params.module_path); + read_from_binary_file(graphs_, graph_path); + } + + std::unique_ptr get_graph(std::string name) override { + TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name); + std::vector dispatches; + for (auto &dispatch : graphs_[name].dispatches) { + dispatches.push_back({dispatch.kernel_name, dispatch.symbolic_args, + get_kernel(dispatch.kernel_name)}); + } + aot::CompiledGraph graph{dispatches}; + return std::make_unique(std::move(graph)); } size_t get_root_size() const override { diff --git a/taichi/backends/vulkan/aot_module_loader_impl.h b/taichi/backends/vulkan/aot_module_loader_impl.h index 23990b56a5c68..16230e1411306 100644 --- a/taichi/backends/vulkan/aot_module_loader_impl.h +++ b/taichi/backends/vulkan/aot_module_loader_impl.h @@ -7,31 +7,14 @@ #include "taichi/backends/vulkan/aot_utils.h" #include "taichi/runtime/vulkan/runtime.h" #include "taichi/codegen/spirv/kernel_utils.h" - +#include "taichi/aot/module_builder.h" #include "taichi/aot/module_loader.h" +#include "taichi/backends/vulkan/aot_module_builder_impl.h" +#include "taichi/backends/vulkan/vulkan_graph_data.h" namespace taichi { namespace lang { namespace vulkan { - -class VkRuntime; - -class KernelImpl : public aot::Kernel { - public: - explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &¶ms) - : runtime_(runtime), params_(std::move(params)) { - } - - void launch(RuntimeContext *ctx) override { - auto handle = runtime_->register_taichi_kernel(params_); - runtime_->launch_kernel(handle, ctx); - } - - private: - VkRuntime *const runtime_; - const VkRuntime::RegisterParams params_; -}; - struct TI_DLL_EXPORT AotModuleParams { std::string module_path; VkRuntime *runtime{nullptr}; diff --git a/taichi/backends/vulkan/vulkan_graph_data.h b/taichi/backends/vulkan/vulkan_graph_data.h new file mode 100644 index 0000000000000..6fa3cafc1e3e0 --- /dev/null +++ b/taichi/backends/vulkan/vulkan_graph_data.h @@ -0,0 +1,29 @@ +#pragma once +#include "taichi/runtime/vulkan/runtime.h" + +namespace taichi { +namespace lang { +namespace vulkan { +class KernelImpl : public aot::Kernel { + public: + explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &¶ms) + : runtime_(runtime), params_(std::move(params)) { + handle_ = runtime_->register_taichi_kernel(params_); + } + + void launch(RuntimeContext *ctx) override { + runtime_->launch_kernel(handle_, ctx); + } + + const VkRuntime::RegisterParams ¶ms() { + return params_; + } + + private: + VkRuntime *const runtime_; + VkRuntime::KernelHandle handle_; + const VkRuntime::RegisterParams params_; +}; +} // namespace vulkan +} // namespace lang +} // namespace taichi diff --git a/taichi/program/graph.cpp b/taichi/program/graph.cpp deleted file mode 100644 index 3fa1c703c7064..0000000000000 --- a/taichi/program/graph.cpp +++ /dev/null @@ -1,101 +0,0 @@ -#include "taichi/program/graph.h" -#include "taichi/program/kernel.h" -#include "taichi/aot/module_builder.h" -#include "spdlog/fmt/fmt.h" - -#include - -namespace taichi { -namespace lang { - -void Dispatch::compile( - std::vector &compiled_dispatches) { - if (compiled_kernel_) - return; - compiled_kernel_ = kernel_->compile_to_aot_kernel(); - aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, - compiled_kernel_.get()}; - compiled_dispatches.push_back(std::move(dispatch)); -} - -void Sequential::compile( - std::vector &compiled_dispatches) { - // In the future we can do more across-kernel optimization here. - for (Node *n : sequence_) { - n->compile(compiled_dispatches); - } -} - -void Sequential::append(Node *node) { - sequence_.push_back(node); -} - -void Sequential::dispatch(Kernel *kernel, const std::vector &args) { - Node *n = owning_graph_->new_dispatch_node(kernel, args); - sequence_.push_back(n); -} - -Graph::Graph(std::string name) : name_(name) { - seq_ = std::make_unique(this); -} -Node *Graph::new_dispatch_node(Kernel *kernel, - const std::vector &args) { - all_nodes_.push_back(std::make_unique(kernel, args)); - return all_nodes_.back().get(); -} - -Sequential *Graph::new_sequential_node() { - all_nodes_.push_back(std::make_unique(this)); - return static_cast(all_nodes_.back().get()); -} - -void Graph::compile() { - seq()->compile(compiled_graph_.dispatches); -} - -Sequential *Graph::seq() const { - return seq_.get(); -} - -void Graph::dispatch(Kernel *kernel, const std::vector &args) { - seq()->dispatch(kernel, args); -} - -void Graph::run( - const std::unordered_map &args) const { - RuntimeContext ctx; - for (const auto &dispatch : compiled_graph_.dispatches) { - memset(&ctx, 0, sizeof(RuntimeContext)); - - TI_ASSERT(dispatch.compiled_kernel); - // Populate args metadata into RuntimeContext - const auto &symbolic_args_ = dispatch.symbolic_args; - for (int i = 0; i < symbolic_args_.size(); ++i) { - auto &symbolic_arg = symbolic_args_[i]; - auto found = args.find(symbolic_arg.name); - TI_ERROR_IF(found == args.end(), "Missing runtime value for {}", - symbolic_arg.name); - const aot::IValue &ival = found->second; - if (ival.tag == aot::ArgKind::NDARRAY) { - Ndarray *arr = reinterpret_cast(ival.val); - TI_ERROR_IF(ival.tag != aot::ArgKind::NDARRAY, - "Required a ndarray for argument {}", symbolic_arg.name); - auto ndarray_elem_shape = std::vector( - arr->shape.end() - symbolic_arg.element_shape.size(), - arr->shape.end()); - TI_ERROR_IF(ndarray_elem_shape != symbolic_arg.element_shape, - "Mismatched shape information for argument {}", - symbolic_arg.name); - set_runtime_ctx_ndarray(&ctx, i, arr); - } else { - TI_ERROR_IF(ival.tag != aot::ArgKind::SCALAR, - "Required a scalar for argument {}", symbolic_arg.name); - ctx.set_arg(i, ival.val); - } - } - - dispatch.compiled_kernel->launch(&ctx); - } -} -} // namespace lang -} // namespace taichi diff --git a/taichi/program/graph_builder.cpp b/taichi/program/graph_builder.cpp new file mode 100644 index 0000000000000..76c579d70e74b --- /dev/null +++ b/taichi/program/graph_builder.cpp @@ -0,0 +1,65 @@ +#include "taichi/program/graph_builder.h" +#include "taichi/program/ndarray.h" +#include "taichi/program/program.h" + +namespace taichi { +namespace lang { +void Dispatch::compile( + std::vector &compiled_dispatches) { + if (!compiled_kernel_) { + compiled_kernel_ = kernel_->compile_to_aot_kernel(); + } + aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, + compiled_kernel_.get()}; + compiled_dispatches.push_back(std::move(dispatch)); +} + +void Sequential::compile( + std::vector &compiled_dispatches) { + // In the future we can do more across-kernel optimization here. + for (Node *n : sequence_) { + n->compile(compiled_dispatches); + } +} + +void Sequential::append(Node *node) { + sequence_.push_back(node); +} + +void Sequential::dispatch(Kernel *kernel, const std::vector &args) { + Node *n = owning_graph_->new_dispatch_node(kernel, args); + sequence_.push_back(n); +} + +GraphBuilder::GraphBuilder() { + seq_ = std::make_unique(this); +} + +Node *GraphBuilder::new_dispatch_node(Kernel *kernel, + const std::vector &args) { + all_nodes_.push_back(std::make_unique(kernel, args)); + return all_nodes_.back().get(); +} + +Sequential *GraphBuilder::new_sequential_node() { + all_nodes_.push_back(std::make_unique(this)); + return static_cast(all_nodes_.back().get()); +} + +std::unique_ptr GraphBuilder::compile() { + std::vector dispatches; + seq()->compile(dispatches); + aot::CompiledGraph graph{dispatches}; + return std::make_unique(std::move(graph)); +} + +Sequential *GraphBuilder::seq() const { + return seq_.get(); +} + +void GraphBuilder::dispatch(Kernel *kernel, const std::vector &args) { + seq()->dispatch(kernel, args); +} + +} // namespace lang +} // namespace taichi diff --git a/taichi/program/graph.h b/taichi/program/graph_builder.h similarity index 58% rename from taichi/program/graph.h rename to taichi/program/graph_builder.h index 1999b9b015e74..129e5adc7b94b 100644 --- a/taichi/program/graph.h +++ b/taichi/program/graph_builder.h @@ -2,18 +2,14 @@ #include #include -#include -#include "taichi/program/ndarray.h" -#include "taichi/program/program.h" #include "taichi/ir/type.h" #include "taichi/aot/graph_data.h" -#include "taichi/aot/module_builder.h" namespace taichi { namespace lang { class Kernel; -class Graph; +class GraphBuilder; class Node { public: @@ -27,6 +23,7 @@ class Node { virtual void compile( std::vector &compiled_dispatches) = 0; }; + class Dispatch : public Node { public: explicit Dispatch(Kernel *kernel, const std::vector &args) @@ -45,7 +42,7 @@ class Dispatch : public Node { class Sequential : public Node { public: - explicit Sequential(Graph *graph) : owning_graph_(graph) { + explicit Sequential(GraphBuilder *graph) : owning_graph_(graph) { } void append(Node *node); @@ -57,34 +54,15 @@ class Sequential : public Node { private: std::vector sequence_; - Graph *owning_graph_{nullptr}; + GraphBuilder *owning_graph_{nullptr}; }; -/* - * Graph class works as both builder and runner. - * - * Two typical workflows using Graph: - * - build graph -> compile -> run - * - build graph -> compile -> serialize -> deserialize -> run - * - * Thus Graph can be constructed in two ways, either as an empty object - * or from an `aot::CompiledGraph` loaded from aot module. - * - * Currently Graph only supports sequential launches without returning value - * to host. - */ -class Graph { +class GraphBuilder { public: - explicit Graph(std::string name); - - explicit Graph(std::string name, const aot::CompiledGraph &compiled) - : name_(name), compiled_graph_(compiled) { - } + explicit GraphBuilder(); // TODO: compile() can take in Arch argument - void compile(); - - void run(const std::unordered_map &args) const; + std::unique_ptr compile(); Node *new_dispatch_node(Kernel *kernel, const std::vector &args); @@ -94,19 +72,9 @@ class Graph { Sequential *seq() const; - aot::CompiledGraph compiled_graph() const { - return compiled_graph_; - } - - std::string name() const { - return name_; - } - private: - std::string name_; std::unique_ptr seq_{nullptr}; std::vector> all_nodes_; - aot::CompiledGraph compiled_graph_; }; } // namespace lang diff --git a/tests/cpp/aot/aot_save_load_test.cpp b/tests/cpp/aot/aot_save_load_test.cpp index b64b49b2350fb..4ed35ada03585 100644 --- a/tests/cpp/aot/aot_save_load_test.cpp +++ b/tests/cpp/aot/aot_save_load_test.cpp @@ -5,6 +5,8 @@ #include "taichi/program/program.h" #include "tests/cpp/ir/ndarray_kernel.h" #include "tests/cpp/program/test_program.h" +#include "taichi/aot/graph_data.h" +#include "taichi/program/graph_builder.h" #ifdef TI_WITH_VULKAN #include "taichi/backends/vulkan/aot_module_loader_impl.h" #include "taichi/backends/device.h" @@ -299,4 +301,104 @@ TEST(AotSaveLoad, VulkanNdarray) { // Deallocate embedded_device->device()->dealloc_memory(devalloc_arr_); } + +[[maybe_unused]] static void save_graph() { + TestProgram test_prog; + test_prog.setup(Arch::vulkan); + auto aot_builder = test_prog.prog()->make_aot_module_builder(Arch::vulkan); + auto ker1 = setup_kernel1(test_prog.prog()); + auto ker2 = setup_kernel2(test_prog.prog()); + + auto g_builder = std::make_unique(); + auto seq = g_builder->seq(); + auto arr_arg = aot::Arg{ + "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}}; + seq->dispatch(ker1.get(), {arr_arg}); + seq->dispatch(ker2.get(), + {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), + aot::ArgKind::SCALAR}}); + auto graph = g_builder->compile(); + + aot_builder->add_graph("test", *graph); + aot_builder->dump(".", ""); +} + +TEST(AotLoadGraph, Vulkan) { + // Otherwise will segfault on macOS VM, + // where Vulkan is installed but no devices are present + if (!vulkan::is_vulkan_api_available()) { + return; + } + + save_graph(); + + // API based on proposal https://github.com/taichi-dev/taichi/issues/3642 + // Initialize Vulkan program + taichi::uint64 *result_buffer{nullptr}; + taichi::lang::RuntimeContext host_ctx; + auto memory_pool = + std::make_unique(Arch::vulkan, nullptr); + result_buffer = (taichi::uint64 *)memory_pool->allocate( + sizeof(taichi::uint64) * taichi_result_buffer_entries, 8); + host_ctx.result_buffer = result_buffer; + + // Create Taichi Device for computation + lang::vulkan::VulkanDeviceCreator::Params evd_params; + evd_params.api_version = + taichi::lang::vulkan::VulkanEnvSettings::kApiVersion(); + auto embedded_device = + std::make_unique(evd_params); + taichi::lang::vulkan::VulkanDevice *device_ = + static_cast( + embedded_device->device()); + // Create Vulkan runtime + vulkan::VkRuntime::Params params; + params.host_result_buffer = result_buffer; + params.device = device_; + auto vulkan_runtime = + std::make_unique(std::move(params)); + + // Run AOT module loader + vulkan::AotModuleParams mod_params; + mod_params.module_path = "."; + mod_params.runtime = vulkan_runtime.get(); + + std::unique_ptr vk_module = + aot::Module::load(Arch::vulkan, mod_params); + EXPECT_TRUE(vk_module); + + // Retrieve kernels/fields/etc from AOT module + auto root_size = vk_module->get_root_size(); + EXPECT_EQ(root_size, 0); + vulkan_runtime->add_root_buffer(root_size); + + auto graph = vk_module->get_graph("test"); + + const int size = 10; + taichi::lang::Device::AllocParams alloc_params; + alloc_params.host_write = true; + alloc_params.size = size * sizeof(int); + alloc_params.usage = taichi::lang::AllocUsage::Storage; + DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params); + + int src[size] = {0}; + src[0] = 2; + src[2] = 40; + write_devalloc(vulkan_runtime.get(), devalloc_arr_, src, sizeof(src)); + + std::unordered_map args; + auto arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}); + args.insert({"arr", aot::IValue::create(arr)}); + args.insert({"x", aot::IValue::create(2)}); + graph->run(args); + vulkan_runtime->synchronize(); + + int dst[size] = {1}; + load_devalloc(vulkan_runtime.get(), devalloc_arr_, dst, sizeof(dst)); + + EXPECT_EQ(dst[0], 2); + EXPECT_EQ(dst[1], 2); + EXPECT_EQ(dst[2], 42); + device_->dealloc_memory(devalloc_arr_); +} #endif diff --git a/tests/cpp/program/graph_test.cpp b/tests/cpp/program/graph_test.cpp index 2ac688f63cad8..c6b36b1a64fd2 100644 --- a/tests/cpp/program/graph_test.cpp +++ b/tests/cpp/program/graph_test.cpp @@ -4,8 +4,9 @@ #include "taichi/inc/constants.h" #include "taichi/program/program.h" #include "tests/cpp/program/test_program.h" -#include "taichi/program/graph.h" +#include "taichi/aot/graph_data.h" #include "tests/cpp/ir/ndarray_kernel.h" +#include "taichi/program/graph_builder.h" #ifdef TI_WITH_VULKAN #include "taichi/backends/vulkan/vulkan_loader.h" #endif @@ -27,15 +28,16 @@ TEST(GraphTest, SimpleGraphRun) { auto ker1 = setup_kernel1(test_prog.prog()); auto ker2 = setup_kernel2(test_prog.prog()); - auto g = std::make_unique("test"); - auto seq = g->seq(); + auto g_builder = std::make_unique(); + auto seq = g_builder->seq(); auto arr_arg = aot::Arg{ "arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}}; seq->dispatch(ker1.get(), {arr_arg}); seq->dispatch(ker2.get(), {arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), aot::ArgKind::SCALAR}}); - g->compile(); + + auto g = g_builder->compile(); auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size}); array.write_int({0}, 2);