Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[aot] Serialize built graph, deserialize and run. #5016

Closed
wants to merge 5 commits into from
Closed
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ class TI_DLL_EXPORT Kernel {
* @param ctx Host context
*/
virtual void launch(RuntimeContext *ctx) = 0;

/**
 * Hook for saving this kernel into an AOT module through `builder`.
 *
 * The base implementation only expands TI_NOT_IMPLEMENTED and never touches
 * `builder`; presumably that macro reports an error, so subclasses that
 * support saving are expected to override this (confirm against the macro's
 * definition).
 *
 * @param builder Module builder that would receive the kernel
 */
virtual void save_to_module(AotModuleBuilder *builder) {
TI_NOT_IMPLEMENTED;
}
};

struct CompiledDispatch {
Expand Down
17 changes: 17 additions & 0 deletions taichi/aot/module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,22 @@ void AotModuleBuilder::load(const std::string &output_dir) {
TI_ERROR("Aot loader not supported");
}

/**
 * Serializes every graph registered through add_graph() into the file
 * `<output_dir>/graphs.tcb`.
 *
 * @param output_dir Directory that receives the graph file
 */
void AotModuleBuilder::dump_graph(std::string output_dir) const {
  const auto graphs_path = fmt::format("{}/graphs.tcb", output_dir);
  write_to_binary_file(graphs_, graphs_path);
}

/**
 * Registers a compiled graph under `name` so that dump_graph() can later
 * serialize it.
 *
 * Each kernel referenced by the graph's dispatches is first handed to
 * add_compiled_kernel(), because kernels are stored separately from the
 * graph itself. Reports an error through TI_ERROR when `name` is already
 * taken.
 *
 * @param name  Key the graph is stored under
 * @param graph Compiled graph; moved into the builder
 */
void AotModuleBuilder::add_graph(const std::string &name,
                                 aot::CompiledGraph &&graph) {
  const bool name_taken = graphs_.find(name) != graphs_.end();
  if (name_taken) {
    TI_ERROR("Graph {} already exists", name);
  }
  // Kernels live outside the serialized graph, so register them one by one.
  for (const auto &entry : graph.dispatches) {
    add_compiled_kernel(entry.compiled_kernel);
  }
  graphs_.insert_or_assign(name, std::move(graph));
}
} // namespace lang
} // namespace taichi
12 changes: 12 additions & 0 deletions taichi/aot/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "taichi/backends/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -37,6 +38,10 @@ class AotModuleBuilder {
virtual void dump(const std::string &output_dir,
const std::string &filename) const = 0;

void dump_graph(std::string output_dir) const;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Out of curiosity, why do we need a separate dump_graph instead of dumping everything in dump?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@k-ye Since dump_graph is backend-independent, I wanted to keep it in the base class. Note that it is called from dump in every backend, so the two are never called separately.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This doesn't need to be public, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Aha, nice catch! It was a legacy public API used in the first prototype. Removed in #5024! :D


void add_graph(const std::string &name, aot::CompiledGraph &&graph);

protected:
/**
* Intended to be overridden by each backend's implementation.
Expand All @@ -59,11 +64,18 @@ class AotModuleBuilder {
TI_NOT_IMPLEMENTED;
}

/**
 * Registers an already-compiled kernel with the backend-specific module
 * data. add_graph() calls this once for every dispatch of the graph being
 * added.
 *
 * The default implementation only expands TI_NOT_IMPLEMENTED, so backends
 * that support graphs are expected to override it.
 *
 * @param kernel Compiled kernel referenced by a graph dispatch
 */
virtual void add_compiled_kernel(aot::Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

virtual void add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) = 0;

static bool all_fields_are_dense_in_container(const SNode *container);

private:
std::unordered_map<std::string, aot::CompiledGraph> graphs_;
};

} // namespace lang
Expand Down
7 changes: 6 additions & 1 deletion taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace taichi {
namespace lang {

struct RuntimeContext;

class Graph;
namespace aot {

class TI_DLL_EXPORT Field {
Expand Down Expand Up @@ -90,11 +90,16 @@ class TI_DLL_EXPORT Module {
KernelTemplate *get_kernel_template(const std::string &name);
Field *get_field(const std::string &name);

/**
 * Looks up the graph stored under `name` and returns a runnable Graph.
 *
 * The default implementation only expands TI_NOT_IMPLEMENTED; backends that
 * support graphs override it.
 * NOTE(review): there is no return statement after the macro, so this relies
 * on TI_NOT_IMPLEMENTED never returning — confirm.
 *
 * @param name Name the graph was registered under
 */
virtual std::unique_ptr<Graph> get_graph(std::string name) {
TI_NOT_IMPLEMENTED;
}

protected:
virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0;
virtual std::unique_ptr<KernelTemplate> make_new_kernel_template(
const std::string &name) = 0;
virtual std::unique_ptr<Field> make_new_field(const std::string &name) = 0;
std::unordered_map<std::string, CompiledGraph> graphs_;

private:
std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
Expand Down
9 changes: 9 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -135,6 +136,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

const std::string json_path = fmt::format("{}/metadata.json", output_dir);
converted.dump_json(json_path);

dump_graph(output_dir);
}

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Expand All @@ -147,6 +150,12 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

/**
 * Records an already-compiled Vulkan kernel in the module's AOT data.
 *
 * Appends the kernel's attributes to `ti_aot_data_.kernels` and its per-task
 * SPIR-V sources to `ti_aot_data_.spirv_codes`.
 *
 * @param kernel Kernel to record. It is downcast with an unchecked
 *               static_cast, so it must actually be a vulkan KernelImpl.
 */
void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
  // KernelImpl::params() returns a const reference; bind to it instead of
  // copying the whole RegisterParams (SPIR-V sources included) into a
  // temporary that is only read from.
  const auto &register_params = static_cast<KernelImpl *>(kernel)->params();
  ti_aot_data_.kernels.push_back(register_params.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(register_params.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

void add_compiled_kernel(aot::Kernel *kernel) override;

std::string write_spv_file(const std::string &output_dir,
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;
Expand Down
15 changes: 15 additions & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <type_traits>

#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/program/graph.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -39,6 +40,20 @@ class AotModuleImpl : public aot::Module {
}
ti_aot_data_.spirv_codes.push_back(spirv_sources_codes);
}

const std::string graph_path =
fmt::format("{}/graphs.tcb", params.module_path);
read_from_binary_file(graphs_, graph_path);
}

std::unique_ptr<Graph> get_graph(std::string name) override {
TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name);
auto &compiled_graph = graphs_[name];
for (auto &dispatch : compiled_graph.dispatches) {
dispatch.compiled_kernel = get_kernel(dispatch.kernel_name);
}
auto graph = std::make_unique<Graph>(name, compiled_graph);
return graph;
}

size_t get_root_size() const override {
Expand Down
23 changes: 3 additions & 20 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,14 @@
#include "taichi/backends/vulkan/aot_utils.h"
#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/codegen/spirv/kernel_utils.h"

#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/backends/vulkan/aot_module_builder_impl.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
namespace vulkan {

class VkRuntime;

/**
 * aot::Kernel implementation that runs on a VkRuntime.
 *
 * Holds a non-owning pointer to the runtime together with the registration
 * parameters it was constructed with.
 */
class KernelImpl : public aot::Kernel {
public:
/**
 * @param runtime Runtime used to register and launch the kernel; stored as
 *                a raw pointer and not owned by this object
 * @param params  Registration parameters, moved into this object
 */
explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
: runtime_(runtime), params_(std::move(params)) {
}

/**
 * Registers the stored parameters with the runtime and launches the
 * resulting handle with `ctx`. Registration happens on every call.
 */
void launch(RuntimeContext *ctx) override {
auto handle = runtime_->register_taichi_kernel(params_);
runtime_->launch_kernel(handle, ctx);
}

private:
// Non-owning runtime pointer.
VkRuntime *const runtime_;
// Parameters passed to register_taichi_kernel() on each launch.
const VkRuntime::RegisterParams params_;
};

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
28 changes: 28 additions & 0 deletions taichi/backends/vulkan/vulkan_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once
#include "taichi/runtime/vulkan/runtime.h"

namespace taichi {
namespace lang {
namespace vulkan {
/**
 * aot::Kernel implementation that runs on a VkRuntime.
 *
 * Holds a non-owning pointer to the runtime together with the registration
 * parameters it was constructed with, and exposes those parameters
 * read-only through params().
 */
class KernelImpl : public aot::Kernel {
 public:
  /**
   * @param runtime Runtime used to register and launch the kernel; stored as
   *                a raw pointer and not owned by this object
   * @param params  Registration parameters, moved into this object
   */
  explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
      : runtime_(runtime), params_(std::move(params)) {
  }

  /**
   * Registers the stored parameters with the runtime and launches the
   * resulting handle with `ctx`.
   *
   * NOTE(review): registration happens on every call; consider caching the
   * handle if repeated launches are expected.
   */
  void launch(RuntimeContext *ctx) override {
    auto handle = runtime_->register_taichi_kernel(params_);
    runtime_->launch_kernel(handle, ctx);
  }

  /**
   * Read-only access to the registration parameters. Const-qualified so it
   * can also be called through a const KernelImpl.
   */
  const VkRuntime::RegisterParams &params() const {
    return params_;
  }

 private:
  // Non-owning runtime pointer.
  VkRuntime *const runtime_;
  // Parameters passed to register_taichi_kernel() on each launch.
  const VkRuntime::RegisterParams params_;
};
} // namespace vulkan
} // namespace lang
} // namespace taichi
101 changes: 101 additions & 0 deletions tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/program/program.h"
#include "tests/cpp/ir/ndarray_kernel.h"
#include "tests/cpp/program/test_program.h"
#include "taichi/program/graph.h"
#ifdef TI_WITH_VULKAN
#include "taichi/backends/vulkan/aot_module_loader_impl.h"
#include "taichi/backends/device.h"
Expand Down Expand Up @@ -299,4 +300,104 @@ TEST(AotSaveLoad, VulkanNdarray) {
// Deallocate
embedded_device->device()->dealloc_memory(devalloc_arr_);
}

[[maybe_unused]] static void save_graph() {
TestProgram test_prog;
test_prog.setup(Arch::vulkan);
auto aot_builder = test_prog.prog()->make_aot_module_builder(Arch::vulkan);
auto ker1 = setup_kernel1(test_prog.prog());
auto ker2 = setup_kernel2(test_prog.prog());

auto g = std::make_unique<Graph>("test");
auto seq = g->seq();
auto arr_arg = aot::Arg{
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}};
seq->dispatch(ker1.get(), {arr_arg});
seq->dispatch(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
g->compile();

aot_builder->add_graph(g->name(), g->compiled_graph());
aot_builder->dump(".", "");
}

// End-to-end check of graph serialization: save_graph() dumps an AOT module
// containing the graph "test", which is then loaded back through the Vulkan
// AOT module loader, run against a device ndarray, and verified.
TEST(AotLoadGraph, Vulkan) {
// Otherwise will segfault on macOS VM,
// where Vulkan is installed but no devices are present
if (!vulkan::is_vulkan_api_available()) {
return;
}

// Dumps the module into the current directory (see mod_params.module_path).
save_graph();

// API based on proposal https://github.com/taichi-dev/taichi/issues/3642
// Initialize Vulkan program
taichi::uint64 *result_buffer{nullptr};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, I feel like it would be great if these tests can made the same for all the archs supporting AOT! (cc @turbo0628 @jim19930609 ) A more realistic result probably is that we have different setup routines for each arch. But the graph part stays the same.

taichi::lang::RuntimeContext host_ctx;
auto memory_pool =
std::make_unique<taichi::lang::MemoryPool>(Arch::vulkan, nullptr);
result_buffer = (taichi::uint64 *)memory_pool->allocate(
sizeof(taichi::uint64) * taichi_result_buffer_entries, 8);
host_ctx.result_buffer = result_buffer;

// Create Taichi Device for computation
lang::vulkan::VulkanDeviceCreator::Params evd_params;
evd_params.api_version =
taichi::lang::vulkan::VulkanEnvSettings::kApiVersion();
auto embedded_device =
std::make_unique<taichi::lang::vulkan::VulkanDeviceCreator>(evd_params);
taichi::lang::vulkan::VulkanDevice *device_ =
static_cast<taichi::lang::vulkan::VulkanDevice *>(
embedded_device->device());
// Create Vulkan runtime
vulkan::VkRuntime::Params params;
params.host_result_buffer = result_buffer;
params.device = device_;
auto vulkan_runtime =
std::make_unique<taichi::lang::vulkan::VkRuntime>(std::move(params));

// Run AOT module loader
vulkan::AotModuleParams mod_params;
// Same directory that save_graph() dumped the module into.
mod_params.module_path = ".";
mod_params.runtime = vulkan_runtime.get();

std::unique_ptr<aot::Module> vk_module =
aot::Module::load(Arch::vulkan, mod_params);
EXPECT_TRUE(vk_module);

// Retrieve kernels/fields/etc from AOT module
auto root_size = vk_module->get_root_size();
EXPECT_EQ(root_size, 0);
vulkan_runtime->add_root_buffer(root_size);

// "test" is the graph name used by save_graph().
auto graph = vk_module->get_graph("test");

// Allocate a host-writable storage buffer of `size` ints for the ndarray.
const int size = 10;
taichi::lang::Device::AllocParams alloc_params;
alloc_params.host_write = true;
alloc_params.size = size * sizeof(int);
alloc_params.usage = taichi::lang::AllocUsage::Storage;
DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params);

// Initial contents: src[0] = 2, src[2] = 40, everything else 0.
int src[size] = {0};
src[0] = 2;
src[2] = 40;
write_devalloc(vulkan_runtime.get(), devalloc_arr_, src, sizeof(src));

// Argument names must match the ones declared in save_graph() ("arr", "x").
std::unordered_map<std::string, aot::IValue> args;
auto arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size});
args.insert({"arr", aot::IValue::create(arr)});
args.insert({"x", aot::IValue::create<int>(2)});
graph->run(args);
vulkan_runtime->synchronize();

// Only dst[0] starts as 1; the remaining elements are zero-initialized.
// All of it is overwritten by the read-back below.
int dst[size] = {1};
load_devalloc(vulkan_runtime.get(), devalloc_arr_, dst, sizeof(dst));

// Expected array contents after both dispatches ran with x = 2.
// NOTE(review): the kernel bodies live in ndarray_kernel.h and are not
// visible here — confirm these values against them.
EXPECT_EQ(dst[0], 2);
EXPECT_EQ(dst[1], 2);
EXPECT_EQ(dst[2], 42);
device_->dealloc_memory(devalloc_arr_);
}
#endif