forked from taichi-dev/taichi
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[aot] Build and run graph without serialization
This PR serves as the base PR with a minimal example of building and running a Graph. Runtime values for graph arguments can be either scalars or ndarrays. For the detailed proposal please see taichi-dev#4786. Things handled in this PR: - Maximize common code/runtime shared by the two workflows below: 1. build -> compile -> run 2. build -> compile -> serialize -> deserialize -> run - Graph arguments are annotated with dtype and element shape for ndarray (temporary until we have vec3 types in C++) Things that we've discussed but not included in this PR: - C API: I'll leave that for a unified C API PR in the future. - bind IValues to graph: easy, will add later.
- Loading branch information
Ailing Zhang
committed
May 19, 2022
1 parent
b9d8c50
commit 1a2efdd
Showing
16 changed files
with
403 additions
and
48 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,90 @@ | ||
#pragma once | ||
#include <vector> | ||
#include <string> | ||
#include <unordered_map> | ||
#include "taichi/aot/module_data.h" | ||
|
||
namespace taichi { | ||
namespace lang { | ||
class AotModuleBuilder; | ||
class Ndarray; | ||
namespace aot { | ||
// Kind of a graph argument. Scalars and ndarrays are the only kinds
// supported for now.
enum ArgKind {
  SCALAR,
  NDARRAY,
  UNKNOWN,
};
|
||
/* | ||
* Symbolic argument used in building `Dispatch` nodes in the `Graph`. | ||
*/ | ||
struct Arg { | ||
std::string name; | ||
// TODO: real element dtype = dtype + element_shape | ||
std::string dtype_name; | ||
ArgKind tag; | ||
std::vector<int> element_shape; | ||
|
||
TI_IO_DEF(name, dtype_name, tag, element_shape); | ||
}; | ||
|
||
/* | ||
* Runtime value used in graph execution. | ||
*/ | ||
struct IValue { | ||
public: | ||
uint64 val; | ||
ArgKind tag; | ||
|
||
static IValue create(const Ndarray &ndarray) { | ||
return IValue(reinterpret_cast<intptr_t>(&ndarray), ArgKind::NDARRAY); | ||
} | ||
|
||
template <typename T> | ||
static IValue create(T v) { | ||
return IValue(taichi_union_cast_with_different_sizes<uint64>(v), | ||
ArgKind::SCALAR); | ||
} | ||
|
||
private: | ||
IValue(uint64 val, ArgKind tag) : val(val), tag(tag) { | ||
} | ||
}; | ||
class TI_DLL_EXPORT Kernel { | ||
public: | ||
// Rule of 5 to make MSVC happy | ||
Kernel() = default; | ||
virtual ~Kernel() = default; | ||
Kernel(const Kernel &) = delete; | ||
Kernel &operator=(const Kernel &) = delete; | ||
Kernel(Kernel &&) = default; | ||
Kernel &operator=(Kernel &&) = default; | ||
|
||
/** | ||
* @brief Launches the kernel to the device | ||
* | ||
* This does not manage the device to host synchronization. | ||
* | ||
* @param ctx Host context | ||
*/ | ||
virtual void launch(RuntimeContext *ctx) = 0; | ||
|
||
virtual void save_to_module(AotModuleBuilder *builder) { | ||
TI_NOT_IMPLEMENTED; | ||
} | ||
}; | ||
|
||
struct CompiledDispatch { | ||
std::string kernel_name; | ||
std::vector<Arg> symbolic_args; | ||
Kernel *compiled_kernel{nullptr}; | ||
|
||
TI_IO_DEF(kernel_name, symbolic_args); | ||
}; | ||
|
||
// The compiled form of a graph: the list of its compiled dispatches. | ||
struct CompiledGraph { | ||
std::vector<CompiledDispatch> dispatches; | ||
|
||
TI_IO_DEF(dispatches); | ||
}; | ||
|
||
} // namespace aot | ||
} // namespace lang | ||
} // namespace taichi |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
#include "taichi/program/graph.h" | ||
#include "taichi/program/kernel.h" | ||
#include "taichi/aot/module_builder.h" | ||
#include "spdlog/fmt/fmt.h" | ||
|
||
#include <fstream> | ||
|
||
namespace taichi { | ||
namespace lang { | ||
|
||
void Dispatch::compile( | ||
std::vector<aot::CompiledDispatch> &compiled_dispatches) { | ||
compiled_kernel_ = kernel_->compile_to_aot_kernel(); | ||
aot::CompiledDispatch dispatch{kernel_->get_name(), symbolic_args_, | ||
compiled_kernel_.get()}; | ||
compiled_dispatches.push_back(std::move(dispatch)); | ||
} | ||
|
||
void Sequential::compile( | ||
std::vector<aot::CompiledDispatch> &compiled_dispatches) { | ||
// In the future we can do more across-kernel optimization here. | ||
for (Node *n : sequence_) { | ||
n->compile(compiled_dispatches); | ||
} | ||
} | ||
|
||
// Adds `node` at the end of the sequence. Only the pointer is stored.
void Sequential::append(Node *node) {
  sequence_.emplace_back(node);
}
|
||
// Asks the owning graph to create a dispatch node for `kernel` with `args`
// and adds that node at the end of the sequence.
void Sequential::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
  sequence_.push_back(owning_graph_->create_dispatch(kernel, args));
}
|
||
// Creates a graph called `name` together with its root sequence.
// `name` is taken by value, so it is moved into the member instead of being
// copied a second time.
Graph::Graph(std::string name) : name_(std::move(name)) {
  seq_ = std::make_unique<Sequential>(this);
}
// Creates a dispatch node for `kernel` with `args`, hands ownership to
// `all_nodes_`, and returns a non-owning pointer to the new node.
Node *Graph::create_dispatch(Kernel *kernel,
                             const std::vector<aot::Arg> &args) {
  auto dispatch = std::make_unique<Dispatch>(kernel, args);
  Node *created = dispatch.get();
  all_nodes_.push_back(std::move(dispatch));
  return created;
}
|
||
// Creates a new sequence node tied to this graph, hands ownership to
// `all_nodes_`, and returns a non-owning pointer to it.
Sequential *Graph::create_sequential() {
  auto sequential = std::make_unique<Sequential>(this);
  Sequential *created = sequential.get();
  all_nodes_.push_back(std::move(sequential));
  return created;
}
|
||
// Compiles the root sequence and stores the resulting dispatches in
// `compiled_graph_`.
void Graph::compile() {
  // Start from an empty list so that calling compile() again does not append
  // a second copy of every dispatch (run() would then launch each kernel
  // twice).
  // NOTE(review): Dispatch::compile reassigns `compiled_kernel_`, so entries
  // left over from an earlier compile() presumably point at released
  // kernels - confirm against the Dispatch declaration.
  compiled_graph_.dispatches.clear();
  seq()->compile(compiled_graph_.dispatches);
}
|
||
// Returns a non-owning pointer to the sequence held in `seq_`. | ||
Sequential *Graph::seq() const { | ||
return seq_.get(); | ||
} | ||
|
||
// Convenience wrapper: adds a dispatch of `kernel` with `args` to the
// graph's root sequence.
void Graph::emplace(Kernel *kernel, const std::vector<aot::Arg> &args) {
  seq_->emplace(kernel, args);
}
|
||
void Graph::run( | ||
const std::unordered_map<std::string, aot::IValue> &args) const { | ||
RuntimeContext ctx; | ||
for (const auto &dispatch : compiled_graph_.dispatches) { | ||
memset(&ctx, 0, sizeof(RuntimeContext)); | ||
|
||
TI_ASSERT(dispatch.compiled_kernel); | ||
// Populate args metadata into RuntimeContext | ||
const auto &symbolic_args_ = dispatch.symbolic_args; | ||
for (int i = 0; i < symbolic_args_.size(); ++i) { | ||
auto &symbolic_arg = symbolic_args_[i]; | ||
auto found = args.find(symbolic_arg.name); | ||
TI_ERROR_IF(found == args.end(), "Missing runtime value for {}", | ||
symbolic_arg.name); | ||
const aot::IValue &ival = found->second; | ||
if (ival.tag == aot::ArgKind::NDARRAY) { | ||
Ndarray *arr = reinterpret_cast<Ndarray *>(ival.val); | ||
TI_ERROR_IF((symbolic_arg.tag != ival.tag) || | ||
(symbolic_arg.element_shape != arr->shape), | ||
"Mismatched shape information for argument {}", | ||
symbolic_arg.name); | ||
set_runtime_ctx_ndarray(&ctx, i, arr); | ||
} else { | ||
TI_ERROR_IF(symbolic_arg.tag != aot::ArgKind::SCALAR, | ||
"Required a scalar for argument {}", symbolic_arg.name); | ||
ctx.set_arg(i, ival.val); | ||
} | ||
} | ||
|
||
dispatch.compiled_kernel->launch(&ctx); | ||
} | ||
} | ||
} // namespace lang | ||
} // namespace taichi |
Oops, something went wrong.