Skip to content

Commit

Permalink
[aot] Build and run graph without serialization
Browse files Browse the repository at this point in the history
This PR serves as the base PR with a minimal example of building and
running a Graph. Runtime values for graph arguments can be either
scalars or ndarrays.

For detailed proposal please see taichi-dev#4786.

Things handled in this PR:
- Maximize common code/runtime shared by the two workflows below:
  1. build -> compile -> run
  2. build -> compile -> serialize -> deserialize -> run
- Graph arguments are annotated with dtype and element shape for ndarray (temporary
until we have vec3 types in C++)

Things that we've discussed but not included in this PR:
- C API: I'll leave that for a unified C API PR in the future.
- bind IValues to graph: easy, will add later.
  • Loading branch information
Ailing Zhang committed May 19, 2022
1 parent b9d8c50 commit 1a2efdd
Show file tree
Hide file tree
Showing 16 changed files with 403 additions and 48 deletions.
90 changes: 90 additions & 0 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
#pragma once
#include <vector>
#include <string>
#include <unordered_map>
#include "taichi/aot/module_data.h"

namespace taichi {
namespace lang {
class AotModuleBuilder;
class Ndarray;
namespace aot {
// Tag describing how a graph argument (symbolic or runtime) is interpreted.
// Currently only scalar and ndarray are supported.
enum ArgKind { SCALAR, NDARRAY, UNKNOWN };

/*
 * Symbolic argument used in building `Dispatch` nodes in the `Graph`.
 *
 * An Arg records only metadata — the name used to look up the runtime
 * value, the dtype name, the kind tag, and (for ndarrays) the element
 * shape. The concrete value is bound at run time via `IValue`.
 */
struct Arg {
  // Key used to find the matching runtime IValue in Graph::run().
  std::string name;
  // TODO: real element dtype = dtype + element_shape
  std::string dtype_name;
  ArgKind tag;
  // Only meaningful for NDARRAY args; checked against the bound
  // ndarray's shape in Graph::run().
  std::vector<int> element_shape;

  TI_IO_DEF(name, dtype_name, tag, element_shape);
};

/*
 * Runtime value used in graph execution.
 *
 * A tagged 64-bit payload: either a scalar bit-cast into `val`, or the
 * address of an Ndarray stored in `val` (Graph::run casts it back to
 * `Ndarray *`).
 */
struct IValue {
 public:
  uint64 val;
  ArgKind tag;

  // Wraps `ndarray` by address, without taking ownership.
  // NOTE(review): the caller must keep `ndarray` alive for as long as
  // this IValue (and any graph run using it) exists. The stored address
  // is later cast back to a mutable `Ndarray *`, silently dropping
  // const — confirm this is intended.
  static IValue create(const Ndarray &ndarray) {
    return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::NDARRAY);
  }

  // Bit-casts a scalar `v` into the 64-bit payload, tagged SCALAR.
  template <typename T>
  static IValue create(T v) {
    return IValue(taichi_union_cast_with_different_sizes<uint64>(v),
                  ArgKind::SCALAR);
  }

 private:
  IValue(uint64 val, ArgKind tag) : val(val), tag(tag) {
  }
};
/*
 * Backend-agnostic handle to a compiled kernel that can be launched on a
 * device. Move-only.
 */
class TI_DLL_EXPORT Kernel {
 public:
  // Rule of 5 to make MSVC happy
  Kernel() = default;
  virtual ~Kernel() = default;
  Kernel(const Kernel &) = delete;
  Kernel &operator=(const Kernel &) = delete;
  Kernel(Kernel &&) = default;
  Kernel &operator=(Kernel &&) = default;

  /**
   * @brief Launches the kernel to the device
   *
   * This does not manage the device to host synchronization.
   *
   * @param ctx Host context
   */
  virtual void launch(RuntimeContext *ctx) = 0;

  /**
   * @brief Saves this kernel into an AOT module being built.
   *
   * Optional: backends that do not support AOT serialization keep this
   * default, which raises TI_NOT_IMPLEMENTED.
   *
   * @param builder The module builder to save into
   */
  virtual void save_to_module(AotModuleBuilder *builder) {
    TI_NOT_IMPLEMENTED;
  }
};

/*
 * One compiled graph node: a kernel name plus the symbolic arguments it
 * was built with.
 * `compiled_kernel` is non-owning (ownership stays with the Dispatch
 * node) and is deliberately excluded from TI_IO_DEF serialization.
 */
struct CompiledDispatch {
  std::string kernel_name;
  std::vector<Arg> symbolic_args;
  Kernel *compiled_kernel{nullptr};

  TI_IO_DEF(kernel_name, symbolic_args);
};

/*
 * Flattened, serializable form of a Graph: the ordered list of dispatches
 * produced by Graph::compile().
 */
struct CompiledGraph {
  std::vector<CompiledDispatch> dispatches;

  TI_IO_DEF(dispatches);
};

} // namespace aot
} // namespace lang
} // namespace taichi
22 changes: 1 addition & 21 deletions taichi/aot/module_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
#include "taichi/aot/module_data.h"
#include "taichi/backends/device.h"
#include "taichi/ir/snode.h"
#include "taichi/aot/module_data.h"
#include "taichi/aot/graph_data.h"

namespace taichi {
namespace lang {
Expand All @@ -30,26 +30,6 @@ class TI_DLL_EXPORT Field {
Field &operator=(Field &&) = default;
};

class TI_DLL_EXPORT Kernel {
public:
// Rule of 5 to make MSVC happy
Kernel() = default;
virtual ~Kernel() = default;
Kernel(const Kernel &) = delete;
Kernel &operator=(const Kernel &) = delete;
Kernel(Kernel &&) = default;
Kernel &operator=(Kernel &&) = default;

/**
* @brief Launches the kernel to the device
*
* This does not manage the device to host synchronization.
*
* @param ctx Host context
*/
virtual void launch(RuntimeContext *ctx) = 0;
};

class TI_DLL_EXPORT KernelTemplateArg {
public:
using ArgUnion = std::variant<bool, int64_t, uint64_t, const Field *>;
Expand Down
21 changes: 1 addition & 20 deletions taichi/backends/vulkan/aot_module_loader_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,6 @@ namespace taichi {
namespace lang {
namespace vulkan {
namespace {

using KernelHandle = VkRuntime::KernelHandle;

class FieldImpl : public aot::Field {
public:
explicit FieldImpl(VkRuntime *runtime, const aot::CompiledFieldData &field)
Expand All @@ -23,21 +20,6 @@ class FieldImpl : public aot::Field {
aot::CompiledFieldData field_;
};

class KernelImpl : public aot::Kernel {
public:
explicit KernelImpl(VkRuntime *runtime, KernelHandle handle)
: runtime_(runtime), handle_(handle) {
}

void launch(RuntimeContext *ctx) override {
runtime_->launch_kernel(handle_, ctx);
}

private:
VkRuntime *const runtime_;
const KernelHandle handle_;
};

class AotModuleImpl : public aot::Module {
public:
explicit AotModuleImpl(const AotModuleParams &params)
Expand Down Expand Up @@ -109,8 +91,7 @@ class AotModuleImpl : public aot::Module {
TI_DEBUG("Failed to load kernel {}", name);
return nullptr;
}
auto handle = runtime_->register_taichi_kernel(kparams);
return std::make_unique<KernelImpl>(runtime_, handle);
return std::make_unique<KernelImpl>(runtime_, std::move(kparams));
}

std::unique_ptr<aot::KernelTemplate> make_new_kernel_template(
Expand Down
15 changes: 15 additions & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,21 @@ namespace vulkan {

class VkRuntime;

// Vulkan implementation of aot::Kernel that stores codegen output
// (RegisterParams) and defers pipeline registration until launch().
// NOTE(review): registration happens on *every* launch call; if a kernel
// is launched repeatedly this re-registers it each time — confirm the
// runtime deduplicates, or consider caching the handle after the first
// registration.
class KernelImpl : public aot::Kernel {
 public:
  // Takes ownership of `params`. `runtime` is non-owning and must outlive
  // this kernel.
  explicit KernelImpl(VkRuntime *runtime, VkRuntime::RegisterParams &&params)
      : runtime_(runtime), params_(std::move(params)) {
  }

  // Registers the stored params with the runtime and launches the
  // resulting kernel handle. No device-to-host synchronization.
  void launch(RuntimeContext *ctx) override {
    auto handle = runtime_->register_taichi_kernel(params_);
    runtime_->launch_kernel(handle, ctx);
  }

 private:
  VkRuntime *const runtime_;
  const VkRuntime::RegisterParams params_;
};
struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
11 changes: 11 additions & 0 deletions taichi/backends/vulkan/vulkan_program.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

#include "taichi/backends/vulkan/aot_module_builder_impl.h"
#include "taichi/backends/vulkan/snode_tree_manager.h"
#include "taichi/backends/vulkan/aot_module_loader_impl.h"

#if !defined(ANDROID) && !defined(TI_EMSCRIPTENED)
#include "GLFW/glfw3.h"
Expand Down Expand Up @@ -183,6 +184,16 @@ DeviceAllocation VulkanProgramImpl::allocate_memory_ndarray(
/*export_sharing=*/false});
}

// Compiles `kernel` into a backend-agnostic aot::Kernel for the Vulkan
// backend: lowers the kernel IR, runs SPIR-V codegen to produce
// VkRuntime::RegisterParams, and wraps them in a KernelImpl.
// NOTE(review): `compiled_structs` is passed empty — presumably graph
// kernels here are compiled without SNode trees; confirm this holds once
// fields are supported.
std::unique_ptr<aot::Kernel> VulkanProgramImpl::make_aot_kernel(
    Kernel &kernel) {
  spirv::lower(&kernel);
  std::vector<CompiledSNodeStructs> compiled_structs;
  VkRuntime::RegisterParams kparams =
      run_codegen(&kernel, get_compute_device(), compiled_structs);
  return std::make_unique<KernelImpl>(vulkan_runtime_.get(),
                                      std::move(kparams));
}

VulkanProgramImpl::~VulkanProgramImpl() {
vulkan_runtime_.reset();
embedded_device_.reset();
Expand Down
2 changes: 2 additions & 0 deletions taichi/backends/vulkan/vulkan_program.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class VulkanProgramImpl : public ProgramImpl {
return snode_tree_mgr_->get_snode_tree_device_ptr(tree_id);
}

std::unique_ptr<aot::Kernel> make_aot_kernel(Kernel &kernel) override;

~VulkanProgramImpl();

private:
Expand Down
95 changes: 95 additions & 0 deletions taichi/program/graph.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
#include "taichi/program/graph.h"
#include "taichi/program/kernel.h"
#include "taichi/aot/module_builder.h"
#include "spdlog/fmt/fmt.h"

#include <fstream>

namespace taichi {
namespace lang {

// Compiles the kernel bound to this dispatch into an aot::Kernel and
// appends the resulting CompiledDispatch (kernel name, symbolic args, and
// a non-owning pointer to the compiled kernel) to `compiled_dispatches`.
// Ownership of the compiled kernel stays with this node in
// `compiled_kernel_`.
void Dispatch::compile(
    std::vector<aot::CompiledDispatch> &compiled_dispatches) {
  compiled_kernel_ = kernel_->compile_to_aot_kernel();
  aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_,
                                 compiled_kernel_.get()};
  compiled_dispatches.push_back(std::move(dispatch));
}

// Compiles each child node in order, accumulating their dispatches into
// `compiled_dispatches`.
void Sequential::compile(
    std::vector<aot::CompiledDispatch> &compiled_dispatches) {
  // In the future we can do more across-kernel optimization here.
  for (Node *n : sequence_) {
    n->compile(compiled_dispatches);
  }
}

// Appends an existing node (owned by the enclosing graph) to this
// sequence; the sequence holds only a non-owning pointer.
void Sequential::append(Node *node) {
  sequence_.push_back(node);
}

// Convenience: creates a Dispatch node for (kernel, args) via the owning
// graph and appends it to this sequence.
void Sequential::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
  Node *n = owning_graph_->create_dispatch(kernel, args);
  sequence_.push_back(n);
}

// Constructs a named graph whose root is an empty Sequential node.
Graph::Graph(std::string name) : name_(name) {
  seq_ = std::make_unique<Sequential>(this);
}
// Creates a Dispatch node bound to `kernel` with the given symbolic
// arguments. The graph retains ownership; the returned pointer is
// non-owning and valid for the lifetime of the graph.
Node *Graph::create_dispatch(Kernel *kernel,
                             const std::vector<aot::Arg> &args) {
  auto dispatch_node = std::make_unique<Dispatch>(kernel, args);
  Node *raw_node = dispatch_node.get();
  all_nodes_.push_back(std::move(dispatch_node));
  return raw_node;
}

// Creates an empty Sequential node owned by this graph and returns a
// non-owning pointer to it.
Sequential *Graph::create_sequential() {
  auto seq_node = std::make_unique<Sequential>(this);
  Sequential *raw_seq = seq_node.get();
  all_nodes_.push_back(std::move(seq_node));
  return raw_seq;
}

// Compiles every node reachable from the root sequence, filling
// compiled_graph_.dispatches in execution order.
void Graph::compile() {
  seq()->compile(compiled_graph_.dispatches);
}

// Returns the root Sequential node (non-owning; owned by this graph).
Sequential *Graph::seq() const {
  return seq_.get();
}

// Appends a Dispatch of (kernel, args) to the root sequence.
void Graph::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
  seq()->emplace(kernel, args);
}

/*
 * Runs every compiled dispatch in this graph, in order.
 *
 * `args` maps a symbolic argument name to its runtime value; every
 * symbolic argument of every dispatch must be present. Ndarray values are
 * checked against the annotated tag and element shape; scalars are passed
 * through as 64-bit payloads. Errors (via TI_ERROR_IF) on a missing or
 * mismatched argument.
 *
 * This does not manage device-to-host synchronization (see
 * aot::Kernel::launch).
 */
void Graph::run(
    const std::unordered_map<std::string, aot::IValue> &args) const {
  RuntimeContext ctx;
  for (const auto &dispatch : compiled_graph_.dispatches) {
    // Each dispatch starts from a freshly zeroed context.
    memset(&ctx, 0, sizeof(RuntimeContext));

    TI_ASSERT(dispatch.compiled_kernel);
    // Populate args metadata into RuntimeContext.
    // (Local renamed from `symbolic_args_`: the trailing underscore is the
    // member-naming convention and this is a local binding.)
    const auto &symbolic_args = dispatch.symbolic_args;
    // Cast the size once to avoid a signed/unsigned comparison; `i` stays
    // int because the runtime-context APIs take int indices.
    const int num_args = static_cast<int>(symbolic_args.size());
    for (int i = 0; i < num_args; ++i) {
      const auto &symbolic_arg = symbolic_args[i];
      const auto found = args.find(symbolic_arg.name);
      TI_ERROR_IF(found == args.end(), "Missing runtime value for {}",
                  symbolic_arg.name);
      const aot::IValue &ival = found->second;
      if (ival.tag == aot::ArgKind::NDARRAY) {
        // `val` holds the Ndarray's address (see IValue::create).
        Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val);
        // NOTE(review): this compares the ndarray's full `shape` against
        // the annotated `element_shape` — confirm these are meant to be
        // the same quantity.
        TI_ERROR_IF((symbolic_arg.tag != ival.tag) ||
                        (symbolic_arg.element_shape != arr->shape),
                    "Mismatched shape information for argument {}",
                    symbolic_arg.name);
        set_runtime_ctx_ndarray(&ctx, i, arr);
      } else {
        TI_ERROR_IF(symbolic_arg.tag != aot::ArgKind::SCALAR,
                    "Required a scalar for argument {}", symbolic_arg.name);
        ctx.set_arg(i, ival.val);
      }
    }

    dispatch.compiled_kernel->launch(&ctx);
  }
}
} // namespace lang
} // namespace taichi
Loading

0 comments on commit 1a2efdd

Please sign in to comment.