Skip to content

Commit

Permalink
Get the POC functioning (apache#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Jul 4, 2022
1 parent 7ceb2c8 commit 212daec
Show file tree
Hide file tree
Showing 9 changed files with 63 additions and 42 deletions.
32 changes: 15 additions & 17 deletions src/script/builder/builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -39,31 +39,29 @@ void Builder::EnterWithScope() {
CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: "
<< n->frames.size()
<< ". Please use a fresh new builder every time building IRs";
n->frames.push_back(IRModuleFrame());
n->result = NullOpt;
std::vector<Builder>* stack = ThreadLocalBuilderStack();
stack->push_back(*this);
}

void Builder::ExitWithScope() {
BuilderNode* n = this->get();
ICHECK_EQ(n->frames.size(), 1);
IRModuleFrame frame = Downcast<IRModuleFrame>(n->frames.back());
n->frames.pop_back();
std::vector<Builder>* stack = ThreadLocalBuilderStack();
ICHECK(!stack->empty());
stack->pop_back();
if (!frame->stmts.empty()) {
ICHECK(frame->global_vars.empty());
ICHECK(frame->functions.empty());
n->result = frame->stmts;
} else {
Map<GlobalVar, BaseFunc> func_map;
ICHECK_EQ(frame->functions.size(), frame->global_vars.size());
int m = frame->functions.size();
for (int i = 0; i < m; ++i) {
func_map.Set(frame->global_vars[i], frame->functions[i]);
}
}
// IRModuleFrame frame = Downcast<IRModuleFrame>(n->frames.back());
// n->frames.pop_back();
// if (!frame->stmts.empty()) {
// ICHECK(frame->global_vars.empty());
// ICHECK(frame->functions.empty());
// n->result = frame->stmts;
// } else {
// Map<GlobalVar, BaseFunc> func_map;
// ICHECK_EQ(frame->functions.size(), frame->global_vars.size());
// int m = frame->functions.size();
// for (int i = 0; i < m; ++i) {
// func_map.Set(frame->global_vars[i], frame->functions[i]);
// }
// }
}

Builder Builder::Current() {
Expand Down
3 changes: 2 additions & 1 deletion src/script/builder/frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,13 @@ namespace script {
namespace builder {

void FrameNode::EnterWithScope() {
LOG(INFO) << "EnterWithScope: " << this->GetTypeKey();
// Push to the current builder
Builder::Current()->frames.push_back(GetRef<Frame>(this));
}

void FrameNode::ExitWithScope() {
LOG(INFO) << "ExitWithScope: " << this->GetTypeKey();
for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) {
(*it)();
}
Expand All @@ -39,7 +41,6 @@ IRModuleFrame::IRModuleFrame() {
ObjectPtr<IRModuleFrameNode> n = make_object<IRModuleFrameNode>();
n->global_vars.clear();
n->functions.clear();
n->stmts.clear();
data_ = std::move(n);
}

Expand Down
2 changes: 0 additions & 2 deletions src/script/builder/frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,11 @@ class IRModuleFrameNode : public FrameNode {
public:
Array<GlobalVar> global_vars;
Array<BaseFunc> functions;
Array<ObjectRef> stmts;

void VisitAttrs(tvm::AttrVisitor* v) {
FrameNode::VisitAttrs(v);
v->Visit("global_vars", &global_vars);
v->Visit("functions", &functions);
v->Visit("stmts", &stmts);
}

static constexpr const char* _type_key = "script.builder.IRModuleFrame";
Expand Down
15 changes: 12 additions & 3 deletions src/script/builder/tir/base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
*/
#include "./base.h"

#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/tir/function.h>

#include "../../../printer/text_printer.h"
#include "./block_frame.h"
#include "./for_frame.h"
#include "./prim_func_frame.h"
Expand All @@ -39,8 +42,8 @@ void TestPOC() {
With<Builder> builder;
{
With<PrimFuncFrame> _{T::PrimFunc_("main")};
Buffer A = T::Buffer_({128, 128, 128}, DataType::Float(32));
Buffer B = T::Buffer_({128, 128, 128}, DataType::Float(32));
Buffer A = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32)));
Buffer B = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32)));
{
With<ForFrame> _{T::Grid({128, 128, 128})};
Var i = _()->vars[0];
Expand All @@ -50,12 +53,18 @@ void TestPOC() {
With<BlockFrame> _{T::Block_("block")};
IterVar vi = T::axis::Spatial(Range(0, 128), i);
IterVar vj = T::axis::Spatial(Range(0, 128), j);
IterVar vk = T::axis::Spatial(Range(0, 128), k);
IterVar vk = T::axis::Reduce(Range(0, 128), k);
}
LOG(INFO) << "ForFrame:\n" << _()->stmts;
}
LOG(INFO) << "PrimFuncFrame:\n" << _()->stmts;
}
PrimFunc func = builder()->Get<PrimFunc>();
LOG(INFO) << "func:\n" << AsTVMScript(func);
}

TVM_REGISTER_GLOBAL("test_poc").set_body_typed(TestPOC);

} // namespace tir
} // namespace builder
} // namespace script
Expand Down
11 changes: 5 additions & 6 deletions src/script/builder/tir/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,14 +54,13 @@ class TIRFrame : public Frame {

inline void AddToParent(tvm::tir::Stmt stmt) {
Builder builder = Builder::Current();
ICHECK(!builder->frames.empty());
Frame frame = builder->frames.back();
if (const auto* tir_frame = frame.as<TIRFrameNode>()) {
if (builder->frames.empty()) {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = stmt;
} else if (const auto* tir_frame = builder->frames.back().as<TIRFrameNode>()) {
GetRef<TIRFrame>(tir_frame)->stmts.push_back(stmt);
} else if (const auto* mod_frame = frame.as<IRModuleFrameNode>()) {
GetRef<IRModuleFrame>(mod_frame)->stmts.push_back(stmt);
} else {
LOG(FATAL) << "TypeError: Unsupported frame type: " << frame;
LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back();
}
}

Expand Down
5 changes: 3 additions & 2 deletions src/script/builder/tir/block_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ BlockFrame Block_(String name) {

void BlockFrameNode::ExitWithScope() {
using namespace tvm::tir;
TIRFrameNode::ExitWithScope();
AddToParent(BlockRealize(iter_values, //
predicate.value_or(Bool(true)),
Block(iter_vars, //
Expand Down Expand Up @@ -82,7 +83,7 @@ tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) {

tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) {
using namespace tvm::tir;
ICHECK(dom.defined()) << "Spatial axis must have a domain";
ICHECK(dom.defined()) << "Reduction axis must have a domain";
int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()});
return PushBlockVar(IterVar(/*dom=*/dom, //
/*var=*/Var("_", dtype.with_bits(bits)), //
Expand All @@ -93,7 +94,7 @@ tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) {

Array<tvm::tir::IterVar> Remap(String kinds, Array<PrimExpr> bindings, DataType dtype) {
using namespace tvm::tir;
Array<tvm::tir::IterVar> results;
Array<IterVar> results;
ICHECK_EQ(kinds.size(), bindings.size());
int n = bindings.size();
results.reserve(n);
Expand Down
5 changes: 4 additions & 1 deletion src/script/builder/tir/for_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,10 @@ namespace script {
namespace builder {
namespace tir {

void ForFrameNode::ExitWithScope() { AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); }
void ForFrameNode::ExitWithScope() {
TIRFrameNode::ExitWithScope();
AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts)));
}

#define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \
ForFrame Method(PrimExpr min, PrimExpr extent, Map<String, ObjectRef> attrs) { \
Expand Down
28 changes: 20 additions & 8 deletions src/script/builder/tir/prim_func_frame.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,22 @@ namespace tir {

void PrimFuncFrameNode::ExitWithScope() {
using namespace tvm::tir;
IRModuleFrame frame = Builder::Current()->FindFrame<IRModuleFrame>().value();
frame->global_vars.push_back(GlobalVar(name));
frame->functions.push_back(PrimFunc(/*params=*/args,
/*body=*/AsStmt(stmts),
/*ret_type=*/ret_type,
/*buffer_map=*/buffer_map));
TIRFrameNode::ExitWithScope();
Builder builder = Builder::Current();
PrimFunc func(/*params=*/args,
/*body=*/AsStmt(stmts),
/*ret_type=*/ret_type,
/*buffer_map=*/buffer_map);
if (builder->frames.empty()) {
ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set";
builder->result = func;
} else if (Optional<IRModuleFrame> opt_frame = builder->FindFrame<IRModuleFrame>()) {
IRModuleFrame frame = opt_frame.value();
frame->global_vars.push_back(GlobalVar(name));
frame->functions.push_back(func);
} else {
LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc";
}
}

PrimFuncFrame PrimFunc_(String name) {
Expand All @@ -45,17 +55,19 @@ PrimFuncFrame PrimFunc_(String name) {
return PrimFuncFrame(n);
}

void Arg(tvm::tir::Var var) {
tvm::tir::Var Arg(tvm::tir::Var var) {
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
frame->args.push_back(var);
return var;
}

void Arg(tvm::tir::Buffer buffer) {
tvm::tir::Buffer Arg(tvm::tir::Buffer buffer) {
using namespace tvm::tir;
PrimFuncFrame frame = Builder::Current()->FindFrame<PrimFuncFrame>().value();
Var handle(buffer->name + "_handle", DataType::Handle());
frame->args.push_back(handle);
frame->buffer_map.Set(handle, buffer);
return buffer;
}

TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode);
Expand Down
4 changes: 2 additions & 2 deletions src/script/builder/tir/prim_func_frame.h
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,8 @@ class PrimFuncFrame : public TIRFrame {
};

PrimFuncFrame PrimFunc_(String name);
void Arg(tvm::tir::Var var);
void Arg(tvm::tir::Buffer buffer);
tvm::tir::Var Arg(tvm::tir::Var var);
tvm::tir::Buffer Arg(tvm::tir::Buffer buffer);

} // namespace tir
} // namespace builder
Expand Down

0 comments on commit 212daec

Please sign in to comment.