From 8bc748b2e208a234b6b210510a1374bbe9204918 Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Thu, 26 May 2022 14:58:19 -0700 Subject: [PATCH] Get the POC functioning (#36) --- src/script/builder/builder.cc | 32 +++++++++++------------ src/script/builder/frame.cc | 3 ++- src/script/builder/frame.h | 2 -- src/script/builder/tir/base.cc | 15 ++++++++--- src/script/builder/tir/base.h | 11 ++++---- src/script/builder/tir/block_frame.cc | 5 ++-- src/script/builder/tir/for_frame.cc | 5 +++- src/script/builder/tir/prim_func_frame.cc | 28 ++++++++++++++------ src/script/builder/tir/prim_func_frame.h | 4 +-- 9 files changed, 63 insertions(+), 42 deletions(-) diff --git a/src/script/builder/builder.cc b/src/script/builder/builder.cc index ed64de24396c..4c8fc70f2615 100644 --- a/src/script/builder/builder.cc +++ b/src/script/builder/builder.cc @@ -39,31 +39,29 @@ void Builder::EnterWithScope() { CHECK(n->frames.empty()) << "ValueError: There are frame(s) left in the builder: " << n->frames.size() << ". Please use a fresh new builder every time building IRs"; - n->frames.push_back(IRModuleFrame()); + n->result = NullOpt; std::vector* stack = ThreadLocalBuilderStack(); stack->push_back(*this); } void Builder::ExitWithScope() { - BuilderNode* n = this->get(); - ICHECK_EQ(n->frames.size(), 1); - IRModuleFrame frame = Downcast(n->frames.back()); - n->frames.pop_back(); std::vector* stack = ThreadLocalBuilderStack(); ICHECK(!stack->empty()); stack->pop_back(); - if (!frame->stmts.empty()) { - ICHECK(frame->global_vars.empty()); - ICHECK(frame->functions.empty()); - n->result = frame->stmts; - } else { - Map func_map; - ICHECK_EQ(frame->functions.size(), frame->global_vars.size()); - int m = frame->functions.size(); - for (int i = 0; i < m; ++i) { - func_map.Set(frame->global_vars[i], frame->functions[i]); - } - } + // IRModuleFrame frame = Downcast(n->frames.back()); + // n->frames.pop_back(); + // if (!frame->stmts.empty()) { + // ICHECK(frame->global_vars.empty()); + // ICHECK(frame->functions.empty()); + // n->result = frame->stmts; + // } else { + // Map func_map; + // ICHECK_EQ(frame->functions.size(), frame->global_vars.size()); + // int m = frame->functions.size(); + // for (int i = 0; i < m; ++i) { + // func_map.Set(frame->global_vars[i], frame->functions[i]); + // } + // } } Builder Builder::Current() { diff --git a/src/script/builder/frame.cc b/src/script/builder/frame.cc index 8db03cfbc482..4fe10c2cc630 100644 --- a/src/script/builder/frame.cc +++ b/src/script/builder/frame.cc @@ -23,11 +23,13 @@ namespace script { namespace builder { void FrameNode::EnterWithScope() { + LOG(INFO) << "EnterWithScope: " << this->GetTypeKey(); // Push to the current builder Builder::Current()->frames.push_back(GetRef(this)); } void FrameNode::ExitWithScope() { + LOG(INFO) << "ExitWithScope: " << this->GetTypeKey(); for (auto it = callbacks.rbegin(); it != callbacks.rend(); ++it) { (*it)(); } @@ -39,7 +41,6 @@ IRModuleFrame::IRModuleFrame() { ObjectPtr n = make_object(); n->global_vars.clear(); n->functions.clear(); - n->stmts.clear(); data_ = std::move(n); } diff --git a/src/script/builder/frame.h b/src/script/builder/frame.h index 04916b6b8842..bcb1d90c88f9 100644 --- a/src/script/builder/frame.h +++ b/src/script/builder/frame.h @@ -60,13 +60,11 @@ class IRModuleFrameNode : public FrameNode { public: Array global_vars; Array functions; - Array stmts; void VisitAttrs(tvm::AttrVisitor* v) { FrameNode::VisitAttrs(v); v->Visit("global_vars", &global_vars); v->Visit("functions", &functions); - v->Visit("stmts", &stmts); } static constexpr const char* _type_key = "script.builder.IRModuleFrame"; diff --git a/src/script/builder/tir/base.cc b/src/script/builder/tir/base.cc index 4f764435136a..db0c1c3bf939 100644 --- a/src/script/builder/tir/base.cc +++ b/src/script/builder/tir/base.cc @@ -18,8 +18,11 @@ */ #include "./base.h" +#include #include +#include +#include "../../../printer/text_printer.h" #include "./block_frame.h" #include "./for_frame.h" #include "./prim_func_frame.h" @@ -39,8 +42,8 @@ void TestPOC() { With builder; { With _{T::PrimFunc_("main")}; - Buffer A = T::Buffer_({128, 128, 128}, DataType::Float(32)); - Buffer B = T::Buffer_({128, 128, 128}, DataType::Float(32)); + Buffer A = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32))); + Buffer B = T::Arg(T::Buffer_({128, 128, 128}, DataType::Float(32))); { With _{T::Grid({128, 128, 128})}; Var i = _()->vars[0]; @@ -50,12 +53,18 @@ void TestPOC() { With _{T::Block_("block")}; IterVar vi = T::axis::Spatial(Range(0, 128), i); IterVar vj = T::axis::Spatial(Range(0, 128), j); - IterVar vk = T::axis::Spatial(Range(0, 128), k); + IterVar vk = T::axis::Reduce(Range(0, 128), k); } + LOG(INFO) << "ForFrame:\n" << _()->stmts; } + LOG(INFO) << "PrimFuncFrame:\n" << _()->stmts; } + PrimFunc func = builder()->Get(); + LOG(INFO) << "func:\n" << AsTVMScript(func); } +TVM_REGISTER_GLOBAL("test_poc").set_body_typed(TestPOC); + } // namespace tir } // namespace builder } // namespace script diff --git a/src/script/builder/tir/base.h b/src/script/builder/tir/base.h index a56826eb0718..25107bf95d75 100644 --- a/src/script/builder/tir/base.h +++ b/src/script/builder/tir/base.h @@ -54,14 +54,13 @@ class TIRFrame : public Frame { inline void AddToParent(tvm::tir::Stmt stmt) { Builder builder = Builder::Current(); - ICHECK(!builder->frames.empty()); - Frame frame = builder->frames.back(); - if (const auto* tir_frame = frame.as()) { + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = stmt; + } else if (const auto* tir_frame = builder->frames.back().as()) { GetRef(tir_frame)->stmts.push_back(stmt); - } else if (const auto* mod_frame = frame.as()) { - GetRef(mod_frame)->stmts.push_back(stmt); } else { - LOG(FATAL) << "TypeError: Unsupported frame type: " << frame; + LOG(FATAL) << "TypeError: Unsupported frame type: " << builder->frames.back(); } } diff --git a/src/script/builder/tir/block_frame.cc b/src/script/builder/tir/block_frame.cc index 94adfead92e7..379bc0dd113b 100644 --- a/src/script/builder/tir/block_frame.cc +++ b/src/script/builder/tir/block_frame.cc @@ -42,6 +42,7 @@ BlockFrame Block_(String name) { void BlockFrameNode::ExitWithScope() { using namespace tvm::tir; + TIRFrameNode::ExitWithScope(); AddToParent(BlockRealize(iter_values, // predicate.value_or(Bool(true)), Block(iter_vars, // @@ -82,7 +83,7 @@ tvm::tir::IterVar Spatial(Range dom, PrimExpr binding, DataType dtype) { tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) { using namespace tvm::tir; - ICHECK(dom.defined()) << "Spatial axis must have a domain"; + ICHECK(dom.defined()) << "Reduction axis must have a domain"; int bits = std::max({dom->min.dtype().bits(), dom->extent.dtype().bits(), dtype.bits()}); return PushBlockVar(IterVar(/*dom=*/dom, // /*var=*/Var("_", dtype.with_bits(bits)), // @@ -93,7 +94,7 @@ tvm::tir::IterVar Reduce(Range dom, PrimExpr binding, DataType dtype) { Array Remap(String kinds, Array bindings, DataType dtype) { using namespace tvm::tir; - Array results; + Array results; ICHECK_EQ(kinds.size(), bindings.size()); int n = bindings.size(); results.reserve(n); diff --git a/src/script/builder/tir/for_frame.cc b/src/script/builder/tir/for_frame.cc index 82b879f7d494..f22d818cc673 100644 --- a/src/script/builder/tir/for_frame.cc +++ b/src/script/builder/tir/for_frame.cc @@ -23,7 +23,10 @@ namespace script { namespace builder { namespace tir { -void ForFrameNode::ExitWithScope() { AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); } +void ForFrameNode::ExitWithScope() { + TIRFrameNode::ExitWithScope(); + AddToParent(f_make_for_loop(vars, doms, AsStmt(stmts))); +} #define TVM_SCRIPT_BUILDER_TIR_FOR_CREATE(Method, Kind) \ ForFrame Method(PrimExpr min, PrimExpr extent, Map attrs) { \ diff --git a/src/script/builder/tir/prim_func_frame.cc b/src/script/builder/tir/prim_func_frame.cc index 74371def630f..d052624a6123 100644 --- a/src/script/builder/tir/prim_func_frame.cc +++ b/src/script/builder/tir/prim_func_frame.cc @@ -28,12 +28,22 @@ namespace tir { void PrimFuncFrameNode::ExitWithScope() { using namespace tvm::tir; - IRModuleFrame frame = Builder::Current()->FindFrame().value(); - frame->global_vars.push_back(GlobalVar(name)); - frame->functions.push_back(PrimFunc(/*params=*/args, - /*body=*/AsStmt(stmts), - /*ret_type=*/ret_type, - /*buffer_map=*/buffer_map)); + TIRFrameNode::ExitWithScope(); + Builder builder = Builder::Current(); + PrimFunc func(/*params=*/args, + /*body=*/AsStmt(stmts), + /*ret_type=*/ret_type, + /*buffer_map=*/buffer_map); + if (builder->frames.empty()) { + ICHECK(!builder->result.defined()) << "ValueError: Builder.result has already been set"; + builder->result = func; + } else if (Optional opt_frame = builder->FindFrame()) { + IRModuleFrame frame = opt_frame.value(); + frame->global_vars.push_back(GlobalVar(name)); + frame->functions.push_back(func); + } else { + LOG(FATAL) << "ValueError: Cannot find where to insert PrimFunc"; + } } PrimFuncFrame PrimFunc_(String name) { @@ -45,17 +55,19 @@ PrimFuncFrame PrimFunc_(String name) { return PrimFuncFrame(n); } -void Arg(tvm::tir::Var var) { +tvm::tir::Var Arg(tvm::tir::Var var) { PrimFuncFrame frame = Builder::Current()->FindFrame().value(); frame->args.push_back(var); + return var; } -void Arg(tvm::tir::Buffer buffer) { +tvm::tir::Buffer Arg(tvm::tir::Buffer buffer) { using namespace tvm::tir; PrimFuncFrame frame = Builder::Current()->FindFrame().value(); Var handle(buffer->name + "_handle", DataType::Handle()); frame->args.push_back(handle); frame->buffer_map.Set(handle, buffer); + return buffer; } TVM_REGISTER_NODE_TYPE(PrimFuncFrameNode); diff --git a/src/script/builder/tir/prim_func_frame.h b/src/script/builder/tir/prim_func_frame.h index 721bfb88dc6f..7da51dbafbbe 100644 --- a/src/script/builder/tir/prim_func_frame.h +++ b/src/script/builder/tir/prim_func_frame.h @@ -54,8 +54,8 @@ class PrimFuncFrame : public TIRFrame { }; PrimFuncFrame PrimFunc_(String name); -void Arg(tvm::tir::Var var); -void Arg(tvm::tir::Buffer buffer); +tvm::tir::Var Arg(tvm::tir::Var var); +tvm::tir::Buffer Arg(tvm::tir::Buffer buffer); } // namespace tir } // namespace builder