[TIR] Support AllocateConst nodes in TensorIR scheduling flow (#12489)
* [TIR] Support AllocConstantNode in CreatePrimFunc

* Handle AllocConstantNode in LeafBlockRemovalPlan

* Properly handle AllocConstNode in BufferAllocationLocator

* handle AllocateConst in EstimateFlops

* remove NDArray printing

* doc update

* add test

* cpplint

* Removed dependency on link-params attribute from target

* Restored NDArray printing to unbreak test
masahi authored Aug 22, 2022
1 parent 7c318d7 commit 8146a9b
Showing 19 changed files with 267 additions and 87 deletions.
4 changes: 2 additions & 2 deletions include/tvm/meta_schedule/apply_history_best.h
@@ -40,8 +40,8 @@ namespace meta_schedule {
class ApplyHistoryBestNode : public runtime::Object {
public:
/*! \brief A callback function that filters TE compute */
using FTEFilterFunc =
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor, void>&)>;
using FTEFilterFunc = runtime::TypedPackedFunc<Optional<tir::PrimFunc>(
const Array<te::Tensor, void>&, const Array<runtime::NDArray>&)>;
/*! \brief A callback function that takes a tuning record and does something with it */
using FTakeTuningRecord = runtime::TypedPackedFunc<void(const TuningRecord&)>;
using FDirectDispatch = runtime::TypedPackedFunc<Optional<IRModule>(const IRModule&)>;
10 changes: 8 additions & 2 deletions include/tvm/meta_schedule/extracted_task.h
@@ -79,16 +79,22 @@ class ExtractedTask : public runtime::ObjectRef {
/*!
* \brief The default TE task filter
* \param args The input/output arguments of the TE compute graph
* \param constants Raw data for constant tensors in args. If the size of this array is N, the last
* N tensors in args will be treated as constant tensors.
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args);
Optional<tvm::tir::PrimFunc> DefaultTaskFilter(const Array<tvm::te::Tensor, void>& args,
const Array<runtime::NDArray>& constants);

/*!
* \brief The default TE task filter, with `te.extern` allowed
* \param args The input/output arguments of the TE compute graph
* \param constants Raw data for constant tensors in args. If the size of this array is N, the last
* N tensors in args will be treated as constant tensors.
* \return NullOpt if the task is filtered out, otherwise the task in PrimFunc
*/
Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args);
Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<tvm::te::Tensor, void>& args,
const Array<runtime::NDArray>& constants);

} // namespace meta_schedule
} // namespace tvm
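To make the last-N-constants convention concrete, here is a minimal Python sketch, assuming the `meta_schedule.DefaultTaskFilter` packed function registered later in this commit; the toy matmul and all tensor names are illustrative, not part of the change:

```python
import numpy as np
import tvm
from tvm import te

# Toy compute: C = A @ W, where W will be bound to constant data.
A = te.placeholder((4, 8), name="A")
W = te.placeholder((8, 16), name="W")
k = te.reduce_axis((0, 8), name="k")
C = te.compute((4, 16), lambda i, j: te.sum(A[i, k] * W[k, j], axis=k), name="C")

weight = tvm.nd.array(np.random.rand(8, 16).astype("float32"))
task_filter = tvm.get_global_func("meta_schedule.DefaultTaskFilter")

# constants has size N = 1, so the last tensor in args (W) is constant.
func = task_filter([A, C, W], [weight])
```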
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/apply_history_best.py
@@ -40,7 +40,7 @@ class ApplyHistoryBest(Object):
----------
database : Database
The database to be queried from
te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None
The filtering function for TE computation
If it's a string, it's the name of the filtering function. Built-in functions are
- "meta_schedule.DefaultTaskFilter"
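A user-supplied callable now has to accept both arguments. A hedged sketch of a custom filter matching the updated `Callable[[List[Tensor], List[NDArray]], PrimFunc]` contract (the body is illustrative only):

```python
from typing import List, Optional

from tvm import te, tir
from tvm.runtime import NDArray

def my_te_filter(args: List[te.Tensor], constants: List[NDArray]) -> Optional[tir.PrimFunc]:
    # Illustrative policy: skip tasks that carry constant tensors,
    # lower everything else with the stock TE-to-TIR path.
    if constants:
        return None
    return te.create_prim_func(args)
```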
2 changes: 1 addition & 1 deletion python/tvm/meta_schedule/relay_integration.py
@@ -56,7 +56,7 @@ def extract_task_from_relay(
The pass config of the compiler
disabled_pass : Optional[List[str]]
The list of disabled passes of the compiler
te_filter_func : Callable[[List[tvm.te.Tensor]], bool]
te_filter_func : Callable[[List[tvm.te.Tensor], List[NDArray]], Optional[PrimFunc]]
The filter function to filter out the extracted tasks
If it's a string, it's the name of the filtering function. Built-in functions are
- "meta_schedule.DefaultTaskFilter"
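For reference, a usage sketch of task extraction with one of the built-in filters; `relay_mod` and `params` are assumed to already exist:

```python
from tvm import meta_schedule as ms

extracted_tasks = ms.extract_task_from_relay(
    relay_mod,
    target="llvm",
    params=params,
    te_filter_func="meta_schedule.DefaultTaskFilterAllowExtern",
)
```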
13 changes: 7 additions & 6 deletions python/tvm/meta_schedule/testing/utils.py
@@ -16,7 +16,7 @@
# under the License.
"""Testing utility functions in meta schedule"""
from typing import Callable, Dict, Optional, Union
from tvm.ir import IRModule
from tvm.ir import IRModule, transform
from tvm.relay import Function as RelayFunc
from tvm.runtime import NDArray
from tvm.target import Target
@@ -45,7 +45,7 @@ def apply_fixed_schedules(
schedule_fn : Callable[[ExtractedTask, Schedule], bool]
A callable that is applied for each extracted task and the corresponding default schedule.
Returns True if the given schedule should be committed to the database, False otherwise.
te_filter_func : Union[str, None, Callable[[List[Tensor]], PrimFunc]] = None
te_filter_func : Union[str, None, Callable[[List[Tensor], List[NDArray]], PrimFunc]] = None
The filtering function for TE computation
If it's a string, it's the name of the filtering function. Built-in functions are
- "meta_schedule.DefaultTaskFilter"
@@ -59,11 +59,12 @@
The database containing dummy tuning records for manually scheduled traces.
"""
target = Target(target) if isinstance(target, str) else target
config = {"relay.backend.use_meta_schedule": True}
for k, v in transform.PassContext.current().config.items():
config[k] = v

extracted_tasks = ms.extract_task_from_relay(
relay_mod,
target,
params,
te_filter_func=te_filter_func,
relay_mod, target, params, te_filter_func=te_filter_func, pass_config=config
)
database = ms.database.MemoryDatabase()
for task in extracted_tasks:
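With the config merge above, apply_fixed_schedules picks up options from an enclosing PassContext in addition to forcing `relay.backend.use_meta_schedule`. A hedged usage sketch (`relay_mod`, `params`, and `schedule_fn` assumed to exist):

```python
import tvm
from tvm.meta_schedule.testing.utils import apply_fixed_schedules

with tvm.transform.PassContext(config={"tir.disable_vectorize": True}):
    # "tir.disable_vectorize" is merged into the pass_config that
    # apply_fixed_schedules forwards to extract_task_from_relay.
    database = apply_fixed_schedules(relay_mod, "llvm", params, schedule_fn)
```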
27 changes: 19 additions & 8 deletions src/meta_schedule/extracted_task.cc
@@ -38,7 +38,9 @@ ExtractedTask::ExtractedTask(String task_name, IRModule mod, Target target,
data_ = n;
}

Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args, bool allow_extern_op) {
Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args,
const Array<runtime::NDArray>& constants,
bool allow_extern_op) {
using namespace ::tvm::te;
std::vector<Tensor> stack;
std::unordered_set<const TensorNode*> visited;
@@ -72,7 +74,7 @@ Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args, boo
return NullOpt;
}
}
PrimFunc func = te::CreatePrimFunc(args);
PrimFunc func = te::CreatePrimFuncWithConstants(args, constants);
bool dynamic_loop_extent = false;
PostOrderVisit(func->body, [&dynamic_loop_extent](const ObjectRef& obj) -> void {
if (const auto* loop = obj.as<tir::ForNode>()) {
@@ -87,12 +89,14 @@ Optional<tir::PrimFunc> DefaultTaskFilterImpl(const Array<te::Tensor>& args, boo
return func;
}

Optional<tir::PrimFunc> DefaultTaskFilter(const Array<te::Tensor>& args) {
return DefaultTaskFilterImpl(args, false);
Optional<tir::PrimFunc> DefaultTaskFilter(const Array<te::Tensor>& args,
const Array<runtime::NDArray>& constants) {
return DefaultTaskFilterImpl(args, constants, false);
}

Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<te::Tensor>& args) {
return DefaultTaskFilterImpl(args, true);
Optional<tir::PrimFunc> DefaultTaskFilterAllowExtern(const Array<te::Tensor>& args,
const Array<runtime::NDArray>& constants) {
return DefaultTaskFilterImpl(args, constants, true);
}

TVM_REGISTER_NODE_TYPE(ExtractedTaskNode);
@@ -101,8 +105,15 @@ TVM_REGISTER_GLOBAL("meta_schedule.ExtractedTask")
int weight) -> ExtractedTask {
return ExtractedTask(task_name, mod, target, dispatched, weight);
});
TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter").set_body_typed(DefaultTaskFilter);

TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilter")
.set_body_typed([](const Array<te::Tensor>& args, const Array<runtime::NDArray>& constants) {
return DefaultTaskFilter(args, constants);
});

TVM_REGISTER_GLOBAL("meta_schedule.DefaultTaskFilterAllowExtern")
.set_body_typed(DefaultTaskFilterAllowExtern);
.set_body_typed([](const Array<te::Tensor>& args, const Array<runtime::NDArray>& constants) {
return DefaultTaskFilterAllowExtern(args, constants);
});
} // namespace meta_schedule
} // namespace tvm
19 changes: 19 additions & 0 deletions src/printer/tvmscript_printer.cc
@@ -429,9 +429,14 @@ void NDArrayToTIR(::tvm::runtime::NDArray arr, std::ostream& os) {
tot_dim *= arr->shape[i];
}
T* data_ptr = reinterpret_cast<T*>(arr->data);
constexpr int NUM_PRINT = 20;
os << "[";
for (int i = 0; i < tot_dim; i++) {
os << (i != 0 ? ", " : "") << data_ptr[i];
if (i == NUM_PRINT) {
os << "...";
break;
}
}
os << "]";
}
@@ -1121,6 +1126,20 @@ Doc TVMScriptPrinter::VisitStmt_(const AllocateConstNode* alloc) {
NDArrayToTIR<int16_t>(data, ss);
} else if (alloc->dtype.bits() == 32) {
NDArrayToTIR<int32_t>(data, ss);
} else if (alloc->dtype.bits() == 64) {
NDArrayToTIR<int64_t>(data, ss);
} else {
LOG(FATAL) << "DataType not supported";
}
} else if (alloc->dtype.is_uint()) {
if (alloc->dtype.bits() == 8) {
NDArrayToTIR<uint8_t>(data, ss);
} else if (alloc->dtype.bits() == 16) {
NDArrayToTIR<uint16_t>(data, ss);
} else if (alloc->dtype.bits() == 32) {
NDArrayToTIR<uint32_t>(data, ss);
} else if (alloc->dtype.bits() == 64) {
NDArrayToTIR<uint64_t>(data, ss);
} else {
LOG(FATAL) << "DataType not supported";
}
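A pure-Python mirror of the truncation added to NDArrayToTIR, shown only for clarity (the separator in front of "..." differs trivially from the C++ stream output):

```python
NUM_PRINT = 20

def ndarray_to_tir(flat_values):
    parts = []
    for i, value in enumerate(flat_values):
        parts.append(str(value))
        if i == NUM_PRINT:  # stop after NUM_PRINT + 1 elements
            parts.append("...")
            break
    return "[" + ", ".join(parts) + "]"

print(ndarray_to_tir(range(64)))  # prints [0, 1, ..., 20, ...]
```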
10 changes: 5 additions & 5 deletions src/relay/backend/task_extraction.cc
@@ -17,6 +17,7 @@
* under the License.
*/

#include <tvm/meta_schedule/apply_history_best.h>
#include <tvm/meta_schedule/extracted_task.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
@@ -33,7 +34,7 @@ namespace backend {

Array<meta_schedule::ExtractedTask> ExtractTask(
IRModule mod, Target target, Map<String, runtime::NDArray> params,
runtime::TypedPackedFunc<Optional<tir::PrimFunc>(const Array<te::Tensor>&)> filter_func) {
meta_schedule::ApplyHistoryBestNode::FTEFilterFunc filter_func) {
using meta_schedule::ExtractedTask;
if (filter_func == nullptr) {
filter_func = tvm::meta_schedule::DefaultTaskFilter;
@@ -42,6 +43,7 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
// is_vm=true for backward compatibility
Array<Pass> pass_seqs = relay::backend::GetPassPrefix(/*is_homogenous=*/true, /*is_vm=*/true);
pass_seqs.push_back(transform::FuseOps());

mod = transform::Sequential(pass_seqs)(std::move(mod));

std::vector<ExtractedTask> tasks;
@@ -58,11 +60,9 @@ Array<meta_schedule::ExtractedTask> ExtractTask(
it->second->weight += 1;
return;
}
Array<te::Tensor> inputs_outputs{nullptr};
std::string fused_name;
std::tie(inputs_outputs, fused_name) =
auto [inputs_outputs, constants, fused_name] =
tec::LowerTECompute(relay_func, target, /*return_inputs=*/true);
if (Optional<tir::PrimFunc> prim_func = filter_func(inputs_outputs)) {
if (Optional<tir::PrimFunc> prim_func = filter_func(inputs_outputs, constants)) {
GlobalVar prim_fn_var(fused_name);
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, prim_func.value()}});
26 changes: 19 additions & 7 deletions src/relay/backend/te_compiler_cache.cc
@@ -50,7 +50,6 @@
#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../transforms/meta_schedule_layout_rewrite.h"
#include "../transforms/pass_utils.h"
#include "utils.h"

namespace tvm {
@@ -362,8 +361,14 @@ class ScheduleBuilder : public ExprVisitor {
}
if (meta_schedule_ctx_) {
Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
Array<runtime::NDArray> constants;
for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
te_args.push_back(te_tensor);
constants.push_back(const_node->data);
}

if (Optional<tir::PrimFunc> tir_func =
meta_schedule_ctx_.value()->te_filter_func(te_args)) {
meta_schedule_ctx_.value()->te_filter_func(te_args, constants)) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir_func.value()}});
if (Optional<IRModule> opt_scheduled_mod = meta_schedule_ctx_.value()->Query(
@@ -785,8 +790,8 @@ CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
return MakeShapeFunc().Create(prim_func, target, global_var_supply);
}

std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
bool return_inputs) {
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
const Function& source_func, Target target, bool return_inputs) {
LowerToTECompute lower_te_compute(target);
Array<te::Tensor> outputs = lower_te_compute.Lower(source_func);
// Following ScheduleBuilder, remove placeholder ops from outputs.
@@ -796,11 +801,18 @@ std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_
tensor_outs.push_back(tensor);
}
}

tvm::Array<runtime::NDArray> constants;
for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
tensor_outs.push_back(te_tensor);
constants.push_back(const_node->data);
}

if (return_inputs) {
return std::make_pair(Concat(lower_te_compute.fn_inputs_, tensor_outs),
lower_te_compute.candidate_name_);
return std::make_tuple(Concat(lower_te_compute.fn_inputs_, tensor_outs), constants,
lower_te_compute.candidate_name_);
}
return std::make_pair(tensor_outs, lower_te_compute.candidate_name_);
return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
7 changes: 4 additions & 3 deletions src/relay/backend/te_compiler_cache.h
@@ -37,6 +37,7 @@

#include <functional>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>

@@ -215,10 +216,10 @@ Array<IndexExpr> GetShape(const Array<IndexExpr>& shape);
* \param source_func The primitive function to be lowered.
* \param target The target we want to create schedule for.
* \param return_inputs If true, prepend input tensors to the output array of tensors.
* \return Pair of schedule and fused function name.
* \return Tuple of the lowered TE compute, constant raw data, and fused function name.
*/
std::pair<Array<te::Tensor>, std::string> LowerTECompute(const Function& source_func, Target target,
bool return_inputs = true);
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
const Function& source_func, Target target, bool return_inputs = true);

/*!
* \brief Create schedule for target.
12 changes: 12 additions & 0 deletions src/te/operation/create_primfunc.cc
@@ -17,16 +17,22 @@
* under the License.
*/

#include "create_primfunc.h"

#include <tvm/arith/analyzer.h>
#include <tvm/ir/name_supply.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/function.h>
#include <tvm/tir/stmt_functor.h>

#include <algorithm>
#include <set>
#include <unordered_map>
#include <unordered_set>
#include <utility>

#include "../../tir/ir/functor_common.h"
#include "../../tir/transforms/ir_utils.h"
#include "../schedule/graph.h"

namespace tvm {
@@ -492,6 +498,12 @@ PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list) {
return GenerateAndCompletePrimFunc(arg_list, root_stmts, &info);
}

PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants) {
PrimFunc func = CreatePrimFunc(arg_list);
return tir::BindParams(func, constants);
}

TVM_REGISTER_GLOBAL("te.CreatePrimFunc").set_body_typed(CreatePrimFunc);

} // namespace tir
8 changes: 8 additions & 0 deletions src/te/operation/create_primfunc.h
@@ -30,6 +30,14 @@ namespace tir {
/*! \brief Use Tensor Expression to create a schedulable TensorIR func. */
PrimFunc CreatePrimFunc(const Array<te::Tensor>& arg_list);

/*! \brief The same as above, but creates a PrimFunc with AllocateConstNode. If the size of the
 * constants array is N, the last N tensors in arg_list will be treated as constant tensors.
 * Constant tensors will not be part of the parameters of the created PrimFunc; instead, constants
 * will be embedded in the body as AllocateConstNode.
 */
PrimFunc CreatePrimFuncWithConstants(const Array<te::Tensor>& arg_list,
const Array<runtime::NDArray>& constants);

} // namespace tir
} // namespace tvm

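CreatePrimFuncWithConstants itself gets no Python binding in this commit, but the documented contract can be checked through the task filter. A hedged continuation of the earlier A/C/W sketch:

```python
task_filter = tvm.get_global_func("meta_schedule.DefaultTaskFilter")
func = task_filter([A, C, W], [weight])

# W is consumed as a constant: only A and C remain as parameters, and the
# weight data is embedded in the body as an allocate_const.
assert len(func.params) == 2
assert "allocate_const" in func.script()
```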
1 change: 1 addition & 0 deletions src/tir/analysis/estimate_flops.cc
@@ -169,6 +169,7 @@ class FlopEstimator : private ExprFunctor<TResult(const PrimExpr& n)>,
TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
TResult VisitStmt_(const AllocateConstNode* op) override { return VisitStmt(op->body); }

TResult VisitStmt_(const SeqStmtNode* seq) override {
TResult result;
19 changes: 17 additions & 2 deletions src/tir/schedule/transform.cc
@@ -220,9 +220,24 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
}
}
if (const auto* block = sref->StmtAs<BlockNode>()) {
if (const auto* seq = block->body.as<SeqStmtNode>()) {
auto body = block->body;
// Peel off AllocateConst nodes at the beginning of the block body.
std::vector<const AllocateConstNode*> allocs;
while (const auto* alloc = body.as<AllocateConstNode>()) {
allocs.push_back(alloc);
body = alloc->body;
}
if (const auto* seq = body.as<SeqStmtNode>()) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
n->body = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
// Re-attach AllocateConst nodes
auto new_body = new_seq;
for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
auto alloc = allocs[allocs.size() - 1 - i];
new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data,
new_body, alloc->annotations, alloc->span);
}
n->body = new_body;
*src_stmt = GetRef<Stmt>(block);
*tgt_stmt = Stmt(std::move(n));
return;
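The peel-and-re-attach strategy is easiest to see in isolation. A toy Python re-implementation with hypothetical node classes (not real TIR objects), included only to make the ordering invariant explicit:

```python
from dataclasses import dataclass
from typing import List, Union

@dataclass
class AllocConst:  # stand-in for AllocateConstNode
    data: int
    body: "Node"

@dataclass
class Seq:  # stand-in for SeqStmtNode
    stmts: List[int]

Node = Union[AllocConst, Seq]

def remove_from_seq(root: Node, victim: int) -> Node:
    allocs, body = [], root
    while isinstance(body, AllocConst):  # peel off the wrappers
        allocs.append(body)
        body = body.body
    assert isinstance(body, Seq)
    new_body: Node = Seq([s for s in body.stmts if s != victim])
    for alloc in reversed(allocs):  # re-wrap innermost first, outermost last
        new_body = AllocConst(alloc.data, new_body)
    return new_body

# remove_from_seq(AllocConst(1, AllocConst(2, Seq([10, 20, 30]))), 20)
# == AllocConst(1, AllocConst(2, Seq([10, 30])))
```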