Skip to content

Commit

Permalink
[aot] Serialize built graph, deserialize and run.
Browse files Browse the repository at this point in the history
related: taichi-dev#4786

This PR demonstrates a minimal example of serializing a built graph,
deserializing and running it.

ghstack-source-id: e369b326734b3cffaf91262ed84806044e121c3c
Pull Request resolved: taichi-dev#5016
  • Loading branch information
Ailing Zhang authored and Ailing Zhang committed May 22, 2022
1 parent fba92cf commit 303525b
Show file tree
Hide file tree
Showing 10 changed files with 193 additions and 25 deletions.
4 changes: 0 additions & 4 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,6 @@ class TI_DLL_EXPORT Kernel {
* @param ctx Host context
*/
virtual void launch(RuntimeContext *ctx) = 0;

/**
 * Hook that receives the AOT module builder this kernel would be saved into.
 * The default body only invokes TI_NOT_IMPLEMENTED; it is virtual, so a
 * subclass may supply a real implementation.
 *
 * @param builder AOT module builder passed by the caller (unused here)
 */
virtual void save_to_module(AotModuleBuilder *builder) {
TI_NOT_IMPLEMENTED;
}
};

struct CompiledDispatch {
Expand Down
17 changes: 17 additions & 0 deletions taichi/aot/module_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,5 +52,22 @@ void AotModuleBuilder::load(const std::string &output_dir) {
TI_ERROR("Aot loader not supported");
}

/**
 * Writes every graph registered on this builder to
 * `<output_dir>/graphs.tcb` using write_to_binary_file.
 *
 * @param output_dir Directory that receives the graphs.tcb file
 */
void AotModuleBuilder::dump_graph(std::string output_dir) const {
  // The file name is fixed; only the directory is caller-controlled.
  const auto serialized_graphs_path =
      fmt::format("{}/graphs.tcb", output_dir);
  write_to_binary_file(graphs_, serialized_graphs_path);
}

/**
 * Registers a compiled graph under `name`.
 *
 * Reports an error through TI_ERROR when a graph with the same name has
 * already been added. Each dispatch's compiled kernel is handed to
 * add_compiled_kernel() before the graph itself is moved into storage.
 *
 * @param name  Key under which the graph is stored
 * @param graph Compiled graph; moved into this builder
 */
void AotModuleBuilder::add_graph(const std::string &name,
                                 aot::CompiledGraph &&graph) {
  const bool already_registered = graphs_.find(name) != graphs_.end();
  if (already_registered) {
    TI_ERROR("Graph {} already exists", name);
  }

  // Kernels are registered on their own, ahead of storing the graph.
  for (const auto &entry : graph.dispatches) {
    add_compiled_kernel(entry.compiled_kernel);
  }

  graphs_[name] = std::move(graph);
}
} // namespace lang
} // namespace taichi
12 changes: 12 additions & 0 deletions taichi/aot/module_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "taichi/backends/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -37,6 +38,10 @@ class AotModuleBuilder {
virtual void dump(const std::string &output_dir,
const std::string &filename) const = 0;

/**
 * Writes the graphs held by this builder to `graphs.tcb` inside
 * `output_dir`.
 */
void dump_graph(std::string output_dir) const;

/**
 * Stores `graph` under `name`. Adding a name that already exists is
 * reported as an error.
 */
void add_graph(const std::string &name, aot::CompiledGraph &&graph);

protected:
/**
* Intended to be overridden by each backend's implementation.
Expand All @@ -62,13 +67,20 @@ class AotModuleBuilder {
TI_NOT_IMPLEMENTED;
}

/**
 * Backend hook invoked by add_graph() for the compiled kernel of every
 * dispatch in a graph. The default body only invokes TI_NOT_IMPLEMENTED.
 *
 * @param kernel Compiled kernel taken from a graph dispatch
 */
virtual void add_compiled_kernel(aot::Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

/**
 * Backend hook taking an identifier, a key and a kernel. The default body
 * only invokes TI_NOT_IMPLEMENTED.
 * NOTE(review): presumably `identifier` names a kernel template and `key`
 * selects one instantiation — confirm against the callers.
 */
virtual void add_per_backend_tmpl(const std::string &identifier,
const std::string &key,
Kernel *kernel) {
TI_NOT_IMPLEMENTED;
}

static bool all_fields_are_dense_in_container(const SNode *container);

private:
std::unordered_map<std::string, aot::CompiledGraph> graphs_;
};

} // namespace lang
Expand Down
7 changes: 6 additions & 1 deletion taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace taichi {
namespace lang {

struct RuntimeContext;

class Graph;
namespace aot {

class TI_DLL_EXPORT Field {
Expand Down Expand Up @@ -90,11 +90,16 @@ class TI_DLL_EXPORT Module {
KernelTemplate *get_kernel_template(const std::string &name);
Field *get_field(const std::string &name);

/**
 * Returns the graph registered under `name`.
 * The base implementation only invokes TI_NOT_IMPLEMENTED; modules that
 * support graphs override it.
 *
 * @param name Name of the graph to retrieve
 */
virtual std::unique_ptr<Graph> get_graph(std::string name) {
TI_NOT_IMPLEMENTED;
}

protected:
virtual std::unique_ptr<Kernel> make_new_kernel(const std::string &name) = 0;
virtual std::unique_ptr<KernelTemplate> make_new_kernel_template(
const std::string &name) = 0;
virtual std::unique_ptr<Field> make_new_field(const std::string &name) = 0;
std::unordered_map<std::string, CompiledGraph> graphs_;

private:
std::unordered_map<std::string, std::unique_ptr<Kernel>> loaded_kernels_;
Expand Down
9 changes: 9 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "taichi/aot/module_data.h"
#include "taichi/codegen/spirv/spirv_codegen.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -135,6 +136,8 @@ void AotModuleBuilderImpl::dump(const std::string &output_dir,

const std::string json_path = fmt::format("{}/metadata.json", output_dir);
converted.dump_json(json_path);

dump_graph(output_dir);
}

void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
Expand All @@ -147,6 +150,12 @@ void AotModuleBuilderImpl::add_per_backend(const std::string &identifier,
ti_aot_data_.spirv_codes.push_back(compiled.task_spirv_source_codes);
}

/**
 * Appends the attributes and SPIR-V code of an already-compiled kernel to
 * the AOT data collected by this builder.
 *
 * @param kernel Compiled kernel; it is downcast to KernelImpl without a
 *               runtime check, so it must actually be a KernelImpl.
 */
void AotModuleBuilderImpl::add_compiled_kernel(aot::Kernel *kernel) {
  // params() returns a const reference; bind to it rather than copying the
  // whole parameter struct (kernel attributes plus SPIR-V code) just to read
  // two members.
  const auto &register_params = static_cast<KernelImpl *>(kernel)->params();
  ti_aot_data_.kernels.push_back(register_params.kernel_attribs);
  ti_aot_data_.spirv_codes.push_back(register_params.task_spirv_source_codes);
}

void AotModuleBuilderImpl::add_field_per_backend(const std::string &identifier,
const SNode *rep_snode,
bool is_scalar,
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/vulkan/aot_module_builder_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ class AotModuleBuilderImpl : public AotModuleBuilder {
const std::string &key,
Kernel *kernel) override;

void add_compiled_kernel(aot::Kernel *kernel) override;

std::string write_spv_file(const std::string &output_dir,
const TaskAttributes &k,
const std::vector<uint32_t> &source_code) const;
Expand Down
15 changes: 15 additions & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
#include <type_traits>

#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/program/graph.h"

namespace taichi {
namespace lang {
Expand Down Expand Up @@ -39,6 +40,20 @@ class AotModuleImpl : public aot::Module {
}
ti_aot_data_.spirv_codes.push_back(spirv_sources_codes);
}

const std::string graph_path =
fmt::format("{}/graphs.tcb", params.module_path);
read_from_binary_file(graphs_, graph_path);
}

std::unique_ptr<Graph> get_graph(std::string name) override {
TI_ERROR_IF(graphs_.count(name) == 0, "Cannot find graph {}", name);
auto &compiled_graph = graphs_[name];
for (auto &dispatch : compiled_graph.dispatches) {
dispatch.compiled_kernel = get_kernel(dispatch.kernel_name);
}
auto graph = std::make_unique<Graph>(name, compiled_graph);
return graph;
}

size_t get_root_size() const override {
Expand Down
23 changes: 3 additions & 20 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,31 +7,14 @@
#include "taichi/backends/vulkan/aot_utils.h"
#include "taichi/runtime/vulkan/runtime.h"
#include "taichi/codegen/spirv/kernel_utils.h"

#include "taichi/aot/module_builder.h"
#include "taichi/aot/module_loader.h"
#include "taichi/backends/vulkan/aot_module_builder_impl.h"
#include "taichi/backends/vulkan/vulkan_graph_data.h"

namespace taichi {
namespace lang {
namespace vulkan {

class VkRuntime;

/**
 * aot::Kernel implementation that launches through a VkRuntime.
 * Keeps the registration parameters it was constructed with and a raw
 * pointer to the runtime.
 */
class KernelImpl : public aot::Kernel {
 public:
  /**
   * @param runtime Runtime used for registration and launch
   * @param params  Registration parameters; moved into this object
   */
  explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
      : runtime_(runtime), params_(std::move(params)) {
  }

  /**
   * Registers the stored parameters with the runtime, then launches the
   * resulting handle with `ctx`.
   */
  void launch(RuntimeContext *ctx) override {
    auto kernel_handle = runtime_->register_taichi_kernel(params_);
    runtime_->launch_kernel(kernel_handle, ctx);
  }

 private:
  VkRuntime *const runtime_;
  const VkRuntime::RegisterParams params_;
};

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
28 changes: 28 additions & 0 deletions taichi/backends/vulkan/vulkan_graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#pragma once
#include "taichi/runtime/vulkan/runtime.h"

namespace taichi {
namespace lang {
namespace vulkan {
/**
 * aot::Kernel implementation that launches through a VkRuntime.
 * Keeps the registration parameters it was constructed with and a raw
 * pointer to the runtime.
 */
class KernelImpl : public aot::Kernel {
 public:
  /**
   * @param runtime Runtime used for registration and launch
   * @param params  Registration parameters; moved into this object
   */
  explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
      : runtime_(runtime), params_(std::move(params)) {
  }

  /**
   * Registers the stored parameters with the runtime, then launches the
   * resulting handle with `ctx`.
   * NOTE(review): registration happens on every launch — confirm that the
   * runtime tolerates repeated registration of the same parameters.
   */
  void launch(RuntimeContext *ctx) override {
    auto handle = runtime_->register_taichi_kernel(params_);
    runtime_->launch_kernel(handle, ctx);
  }

  /**
   * Read-only view of the registration parameters. Const-qualified so it
   * can also be called through a const KernelImpl.
   */
  const VkRuntime::RegisterParams &params() const {
    return params_;
  }

 private:
  VkRuntime *const runtime_;
  const VkRuntime::RegisterParams params_;
};
} // namespace vulkan
} // namespace lang
} // namespace taichi
101 changes: 101 additions & 0 deletions tests/cpp/aot/aot_save_load_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "taichi/program/program.h"
#include "tests/cpp/ir/ndarray_kernel.h"
#include "tests/cpp/program/test_program.h"
#include "taichi/program/graph.h"
#ifdef TI_WITH_VULKAN
#include "taichi/backends/vulkan/aot_module_loader_impl.h"
#include "taichi/backends/device.h"
Expand Down Expand Up @@ -299,4 +300,104 @@ TEST(AotSaveLoad, VulkanNdarray) {
// Deallocate
embedded_device->device()->dealloc_memory(devalloc_arr_);
}

[[maybe_unused]] static void save_graph() {
TestProgram test_prog;
test_prog.setup(Arch::vulkan);
auto aot_builder = test_prog.prog()->make_aot_module_builder(Arch::vulkan);
auto ker1 = setup_kernel1(test_prog.prog());
auto ker2 = setup_kernel2(test_prog.prog());

auto g = std::make_unique<Graph>("test");
auto seq = g->seq();
auto arr_arg = aot::Arg{
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}};
seq->dispatch(ker1.get(), {arr_arg});
seq->dispatch(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
g->compile();

aot_builder->add_graph(g->name(), g->compiled_graph());
aot_builder->dump(".", "");
}

// Round-trip test: save a graph with save_graph(), load the module back from
// the current directory, run the "test" graph and check the ndarray contents.
TEST(AotLoadGraph, Vulkan) {
  // Otherwise will segfault on macOS VM,
  // where Vulkan is installed but no devices are present
  if (!vulkan::is_vulkan_api_available()) {
    return;
  }

  save_graph();

  // API based on proposal https://github.com/taichi-dev/taichi/issues/3642
  // Initialize Vulkan program
  taichi::uint64 *result_buffer{nullptr};
  taichi::lang::RuntimeContext host_ctx;
  auto memory_pool =
      std::make_unique<taichi::lang::MemoryPool>(Arch::vulkan, nullptr);
  result_buffer = (taichi::uint64 *)memory_pool->allocate(
      sizeof(taichi::uint64) * taichi_result_buffer_entries, 8);
  host_ctx.result_buffer = result_buffer;

  // Create Taichi Device for computation
  lang::vulkan::VulkanDeviceCreator::Params evd_params;
  evd_params.api_version =
      taichi::lang::vulkan::VulkanEnvSettings::kApiVersion();
  auto embedded_device =
      std::make_unique<taichi::lang::vulkan::VulkanDeviceCreator>(evd_params);
  taichi::lang::vulkan::VulkanDevice *device_ =
      static_cast<taichi::lang::vulkan::VulkanDevice *>(
          embedded_device->device());
  // Create Vulkan runtime
  vulkan::VkRuntime::Params params;
  params.host_result_buffer = result_buffer;
  params.device = device_;
  auto vulkan_runtime =
      std::make_unique<taichi::lang::vulkan::VkRuntime>(std::move(params));

  // Run AOT module loader
  vulkan::AotModuleParams mod_params;
  mod_params.module_path = ".";
  mod_params.runtime = vulkan_runtime.get();

  std::unique_ptr<aot::Module> vk_module =
      aot::Module::load(Arch::vulkan, mod_params);
  // Fatal assertion: everything below dereferences vk_module, so a failed
  // load must stop the test here instead of crashing on a null pointer.
  ASSERT_TRUE(vk_module);

  // Retrieve kernels/fields/etc from AOT module
  auto root_size = vk_module->get_root_size();
  EXPECT_EQ(root_size, 0);
  vulkan_runtime->add_root_buffer(root_size);

  auto graph = vk_module->get_graph("test");

  const int size = 10;
  taichi::lang::Device::AllocParams alloc_params;
  alloc_params.host_write = true;
  alloc_params.size = size * sizeof(int);
  alloc_params.usage = taichi::lang::AllocUsage::Storage;
  DeviceAllocation devalloc_arr_ = device_->allocate_memory(alloc_params);

  // Input data: all zeros except src[0] = 2 and src[2] = 40.
  int src[size] = {0};
  src[0] = 2;
  src[2] = 40;
  write_devalloc(vulkan_runtime.get(), devalloc_arr_, src, sizeof(src));

  // Bind the graph arguments by the names used when the graph was built.
  std::unordered_map<std::string, aot::IValue> args;
  auto arr = Ndarray(devalloc_arr_, PrimitiveType::i32, {size});
  args.insert({"arr", aot::IValue::create(arr)});
  args.insert({"x", aot::IValue::create<int>(2)});
  graph->run(args);
  vulkan_runtime->synchronize();

  int dst[size] = {1};
  load_devalloc(vulkan_runtime.get(), devalloc_arr_, dst, sizeof(dst));

  EXPECT_EQ(dst[0], 2);
  EXPECT_EQ(dst[1], 2);
  EXPECT_EQ(dst[2], 42);
  device_->dealloc_memory(devalloc_arr_);
}
#endif

0 comments on commit 303525b

Please sign in to comment.