-
Notifications
You must be signed in to change notification settings - Fork 2.3k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[aot] Serialize built graph, deserialize and run. #5016
Closed
Closed
Changes from 3 commits
Commits
Show all changes
5 commits
Select commit
Hold shift + click to select a range
014593e
[aot] Serialize built graph, deserialize and run.
0cc69e0
Update on "[aot] Serialize built graph, deserialize and run."
1c60025
Update on "[aot] Serialize built graph, deserialize and run."
785ed6f
Update on "[aot] Serialize built graph, deserialize and run."
dc4499b
Update on "[aot] Serialize built graph, deserialize and run."
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
#pragma once | ||
#include "taichi/runtime/vulkan/runtime.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
namespace vulkan { | ||
class KernelImpl : public aot::Kernel { | ||
public: | ||
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &¶ms) | ||
: runtime_(runtime), params_(std::move(params)) { | ||
} | ||
|
||
void launch(RuntimeContext *ctx) override { | ||
auto handle = runtime_->register_taichi_kernel(params_); | ||
runtime_->launch_kernel(handle, ctx); | ||
} | ||
|
||
const VkRuntime::RegisterParams ¶ms() { | ||
return params_; | ||
} | ||
|
||
private: | ||
VkRuntime *const runtime_; | ||
const VkRuntime::RegisterParams params_; | ||
}; | ||
} // namespace vulkan | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -5,6 +5,7 @@ | |
#include "taichi/program/program.h" | ||
#include "tests/cpp/ir/ndarray_kernel.h" | ||
#include "tests/cpp/program/test_program.h" | ||
#include "taichi/program/graph.h" | ||
#ifdef TI_WITH_VULKAN | ||
#include "taichi/backends/vulkan/aot_module_loader_impl.h" | ||
#include "taichi/backends/device.h" | ||
|
@@ -299,4 +300,104 @@ TEST(AotSaveLoad, VulkanNdarray) { | |
// Deallocate | ||
embedded_device->device()->dealloc_memory(devalloc_arr_); | ||
} | ||
|
||
[[maybe_unused]] static void save_graph() { | ||
TestProgram test_prog; | ||
test_prog.setup(Arch::vulkan); | ||
auto aot_builder = test_prog.prog()->make_aot_module_builder(Arch::vulkan); | ||
auto ker1 = setup_kernel1(test_prog.prog()); | ||
auto ker2 = setup_kernel2(test_prog.prog()); | ||
|
||
auto g = std::make_unique<Graph>("test"); | ||
auto seq = g->seq(); | ||
auto arr_arg = aot::Arg{ | ||
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}}; | ||
seq->dispatch(ker1.get(), {arr_arg}); | ||
seq->dispatch(ker2.get(), | ||
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(), | ||
aot::ArgKind::SCALAR}}); | ||
g->compile(); | ||
|
||
aot_builder->add_graph(g->name(), g->compiled_graph()); | ||
aot_builder->dump(".", ""); | ||
} | ||
|
||
TEST(AotLoadGraph, Vulkan) { | ||
// Otherwise will segfault on macOS VM, | ||
// where Vulkan is installed but no devices are present | ||
if (!vulkan::is_vulkan_api_available()) { | ||
return; | ||
} | ||
|
||
save_graph(); | ||
|
||
// API based on proposal https://github.com/taichi-dev/taichi/issues/3642 | ||
// Initialize Vulkan program | ||
taichi::uint64 *result_buffer{nullptr}; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, I feel like it would be great if these tests can made the same for all the archs supporting AOT! (cc @turbo0628 @jim19930609 ) A more realistic result probably is that we have different setup routines for each arch. But the graph part stays the same. |
||
taichi::lang::RuntimeContext host_ctx; | ||
auto memory_pool = | ||
std::make_unique<taichi::lang::MemoryPool>(Arch::vulkan, nullptr); | ||
result_buffer = (taichi::uint64 *)memory_pool->allocate( | ||
sizeof(taichi::uint64) * taichi_result_buffer_entries, 8); | ||
host_ctx.result_buffer = result_buffer; | ||
|
||
// Create Taichi Device for computation | ||
lang::vulkan::VulkanDeviceCreator::Params evd_params; | ||
evd_params.api_version = | ||
taichi::lang::vulkan::VulkanEnvSettings::kApiVersion(); | ||
auto embedded_device = | ||
std::make_unique<taichi::lang::vulkan::VulkanDeviceCreator>(evd_params); | ||
taichi::lang::vulkan::VulkanDevice *device_ = | ||
static_cast<taichi::lang::vulkan::VulkanDevice *>( | ||
embedded_device->device()); | ||
// Create Vulkan runtime | ||
vulkan::VkRuntime::Params params; | ||
params.host_result_buffer = result_buffer; | ||
params.device = device_; | ||
auto vulkan_runtime = | ||
std::make_unique<taichi::lang::vulkan::VkRuntime>(std::move(params)); | ||
|
||
// Run AOT module loader | ||
vulkan::AotModuleParams mod_params; | ||
mod_params.module_path = "."; | ||
mod_params.runtime = vulkan_runtime.get(); | ||
|
||
std::unique_ptr<aot::Module> vk_module = | ||
aot::Module::load(Arch::vulkan, mod_params); | ||
EXPECT_TRUE(vk_module); | ||
|
||
// Retrieve kernels/fields/etc from AOT module | ||
auto root_size = vk_module->get_root_size(); | ||
EXPECT_EQ(root_size, 0); | ||
vulkan_runtime->add_root_buffer(root_size); | ||
|
||
auto graph = vk_module->get_graph("test"); | ||
|
||
const int size = 10; | ||
taichi::lang::Device::AllocParams alloc_params; | ||
alloc_params.host_write = true; | ||
alloc_params.size = size * sizeof(int); | ||
alloc_params.usage = taichi::lang::AllocUsage::Storage; | ||
DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params); | ||
|
||
int src[size] = {0}; | ||
src[0] = 2; | ||
src[2] = 40; | ||
write_devalloc(vulkan_runtime.get(), devalloc_arr_, src, sizeof(src)); | ||
|
||
std::unordered_map<std::string, aot::IValue> args; | ||
auto arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size}); | ||
args.insert({"arr", aot::IValue::create(arr)}); | ||
args.insert({"x", aot::IValue::create<int>(2)}); | ||
graph->run(args); | ||
vulkan_runtime->synchronize(); | ||
|
||
int dst[size] = {1}; | ||
load_devalloc(vulkan_runtime.get(), devalloc_arr_, dst, sizeof(dst)); | ||
|
||
EXPECT_EQ(dst[0], 2); | ||
EXPECT_EQ(dst[1], 2); | ||
EXPECT_EQ(dst[2], 42); | ||
device_->dealloc_memory(devalloc_arr_); | ||
} | ||
#endif |
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OOC, why we need a separate
dump_graph
, instead of dumping everything indump
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@k-ye Since
dump_graph
is backend independent so I wanted to keep it in the base class. Note it's called in thedump
in every backend indeed we didn't call them separately.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
doesnt need to be public?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
aha nice catch! it was a legacy public api used in the first prototype. Removed in #5024! :D