Skip to content

Commit

Permalink
Update on "[aot] Build and run graph without serialization"
Browse files Browse the repository at this point in the history
This PR servces as the base PR with a minimal example of building and
running a Graph. Runtime values for graph arguments can be either
scalars or ndarrays.

For detailed proposal please see #4786.

Things handled in this PR:
- Maximize common code/runtime shared by the two workflows below:
  1. build -> compile -> run
  2. build -> compile -> serialize -> deserilize -> run
- Graph arguments are annotated with dtype and element shape for ndarray (temporary
until we have vec3 types in C++)

Things that we've discussed but not included in this PR:
- C API: I'll leave that for a unified C API PR in the future.
- bind IValues to graph: easy, will add later.

[ghstack-poisoned]
  • Loading branch information
Ailing Zhang committed May 20, 2022
1 parent cde1797 commit 2fb5316
Show file tree
Hide file tree
Showing 5 changed files with 24 additions and 19 deletions.
10 changes: 6 additions & 4 deletions taichi/aot/graph_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,9 @@ class AotModuleBuilder;
class Ndarray;
namespace aot {
// Currently only scalar and ndarray are supported.
enum ArgKind { SCALAR, NDARRAY, UNKNOWN };
enum class ArgKind { SCALAR, NDARRAY, UNKNOWN };

/*
/**
* Symbolic argument used in building `Dispatch` nodes in the `Graph`.
*/
struct Arg {
Expand All @@ -25,7 +25,7 @@ struct Arg {
TI_IO_DEF(name, dtype_name, tag, element_shape);
};

/*
/**
* Runtime value used in graph execution.
*/
struct IValue {
Expand All @@ -37,7 +37,8 @@ struct IValue {
return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::NDARRAY);
}

template <typename T>
template <typename T,
typename = std::enable_if_t<!std::is_same<T, Ndarray>::value, void>>
static IValue create(T v) {
return IValue(taichi_union_cast_with_different_sizes<uint64>(v),
ArgKind::SCALAR);
Expand All @@ -47,6 +48,7 @@ struct IValue {
IValue(uint64 val, ArgKind tag) : val(val), tag(tag) {
}
};

class TI_DLL_EXPORT Kernel {
public:
// Rule of 5 to make MSVC happy
Expand Down
1 change: 1 addition & 0 deletions taichi/backends/vulkan/aot_module_loader_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ class KernelImpl : public aot::Kernel {
VkRuntime *const runtime_;
const VkRuntime::RegisterParams params_;
};

struct TI_DLL_EXPORT AotModuleParams {
std::string module_path;
VkRuntime *runtime{nullptr};
Expand Down
16 changes: 9 additions & 7 deletions taichi/program/graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ namespace lang {

void Dispatch::compile(
std::vector<aot::CompiledDispatch> &compiled_dispatches) {
if (compiled_kernel_)
return;
compiled_kernel_ = kernel_->compile_to_aot_kernel();
aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_,
compiled_kernel_.get()};
Expand All @@ -28,21 +30,21 @@ void Sequential::append(Node *node) {
sequence_.push_back(node);
}

void Sequential::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
Node *n = owning_graph_->create_dispatch(kernel, args);
void Sequential::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
Node *n = owning_graph_->new_dispatch_node(kernel, args);
sequence_.push_back(n);
}

Graph::Graph(std::string name) : name_(name) {
seq_ = std::make_unique<Sequential>(this);
}
Node *Graph::create_dispatch(Kernel *kernel,
const std::vector<aot::Arg> &args) {
Node *Graph::new_dispatch_node(Kernel *kernel,
const std::vector<aot::Arg> &args) {
all_nodes_.push_back(std::make_unique<Dispatch>(kernel, args));
return all_nodes_.back().get();
}

Sequential *Graph::create_sequential() {
Sequential *Graph::new_sequential_node() {
all_nodes_.push_back(std::make_unique<Sequential>(this));
return static_cast<Sequential *>(all_nodes_.back().get());
}
Expand All @@ -55,8 +57,8 @@ Sequential *Graph::seq() const {
return seq_.get();
}

void Graph::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
seq()->emplace(kernel, args);
void Graph::dispatch(Kernel *kernel, const std::vector<aot::Arg> &args) {
seq()->dispatch(kernel, args);
}

void Graph::run(
Expand Down
8 changes: 4 additions & 4 deletions taichi/program/graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ class Sequential : public Node {

void append(Node *node);

void emplace(Kernel *kernel, const std::vector<aot::Arg> &args);
void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);

void compile(
std::vector<aot::CompiledDispatch> &compiled_dispatches) override;
Expand Down Expand Up @@ -86,11 +86,11 @@ class Graph {

void run(const std::unordered_map<std::string, aot::IValue> &args) const;

Node *create_dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);
Node *new_dispatch_node(Kernel *kernel, const std::vector<aot::Arg> &args);

Sequential *create_sequential();
Sequential *new_sequential_node();

void emplace(Kernel *kernel, const std::vector<aot::Arg> &args);
void dispatch(Kernel *kernel, const std::vector<aot::Arg> &args);

Sequential *seq() const;

Expand Down
8 changes: 4 additions & 4 deletions tests/cpp/program/graph_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,10 @@ TEST(GraphTest, SimpleGraphRun) {
auto seq = g->seq();
auto arr_arg = aot::Arg{
"arr", PrimitiveType::i32.to_string(), aot::ArgKind::NDARRAY, {}};
seq->emplace(ker1.get(), {arr_arg});
seq->emplace(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
seq->dispatch(ker1.get(), {arr_arg});
seq->dispatch(ker2.get(),
{arr_arg, aot::Arg{"x", PrimitiveType::i32.to_string(),
aot::ArgKind::SCALAR}});
g->compile();

auto array = Ndarray(test_prog.prog(), PrimitiveType::i32, {size});
Expand Down

0 comments on commit 2fb5316

Please sign in to comment.