Skip to content

Commit

Permalink
[TIR][MetaSchedule] Estimate TIR FLOPs
Browse files Browse the repository at this point in the history
  • Loading branch information
junrushao committed Mar 25, 2022
1 parent 8ebdf6e commit 9f75fe7
Show file tree
Hide file tree
Showing 5 changed files with 289 additions and 204 deletions.
14 changes: 14 additions & 0 deletions include/tvm/tir/analysis.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,20 @@ inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) {
}
}

/*!
* \brief Estimate the FLOPs (floating point operation count) of a TIR fragment.
* \param stmt The TIR fragment, given as a statement, to be estimated.
* \return The estimated FLOPs, as a double.
* \sa EstimateTIRFlops(const IRModule&) for the module-level overload.
*/
TVM_DLL double EstimateTIRFlops(const Stmt& stmt);

/*!
* \brief Estimate the FLOPs (floating point operation count) of the TIR functions in an IRModule.
* \param mod The IRModule to be estimated.
* \return The estimated FLOPs, as a double.
* \note NOTE(review): presumably accumulated over every PrimFunc in the module; confirm
*       against the implementation.
* \sa EstimateTIRFlops(const Stmt&) for the statement-level overload.
*/
TVM_DLL double EstimateTIRFlops(const IRModule& mod);

/*!
* \brief Find undefined vars in the statement.
* \param stmt The function to be checked.
Expand Down
27 changes: 22 additions & 5 deletions python/tvm/tir/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,16 @@
# under the License.
"""Wrapping existing analysis utils."""
# pylint: disable=invalid-name
from typing import Dict, List
from typing import Dict, List, Union

from tvm import Object
from tvm.tir.stmt import Block, BufferRegion
from tvm.tir.stmt import PrimExpr
from tvm.ir import IRModule
from tvm.tir.expr import Var
from . import _ffi_api
from ..function import PrimFunc
from tvm.tir.stmt import Block, BufferRegion, PrimExpr

from .. import Buffer, Stmt
from ..function import PrimFunc
from . import _ffi_api


def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool:
Expand Down Expand Up @@ -199,6 +200,22 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]:
return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member


def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float:
    """Estimate the FLOPs of a TIR statement or of an IRModule.

    Parameters
    ----------
    stmt_or_mod : Union[Stmt, IRModule]
        Either a TIR fragment or an IRModule whose FLOPs should be estimated.

    Returns
    -------
    flops : float
        The estimated FLOPs.
    """
    # Both input kinds are forwarded unchanged to the same FFI entry point.
    flops = _ffi_api.EstimateTIRFlops(stmt_or_mod)  # type: ignore # pylint: disable=no-member
    return flops


# NOTE: relay_func_type in the following two functions should be relay.FuncType; however, that
# would introduce a cyclic dependency, so we make do with Object.

Expand Down
200 changes: 1 addition & 199 deletions src/meta_schedule/measure_callback/echo_statistics.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,204 +20,6 @@

#include "../utils.h"

namespace tvm {
namespace tir {

double CountFlop(const IRModule& mod) {
struct TResult {
using TTable = std::unordered_map<int32_t, double>;

TResult() = default;

explicit TResult(const tvm::DataType& dtype) { Add(dtype); }

void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }

TResult operator+=(const TResult& rhs) {
for (const auto& kv : rhs.data_) {
data_[kv.first] += kv.second;
}
return *this;
}

TResult operator*=(int64_t rhs) {
for (auto& kv : data_) {
kv.second *= rhs;
}
return *this;
}

TResult MaxWith(const TResult& rhs) {
for (const auto& kv : rhs.data_) {
double& v = data_[kv.first];
if (v < kv.second) {
v = kv.second;
}
}
return *this;
}

struct DType {
uint8_t code : 8;
uint8_t bits : 8;
uint16_t lanes : 16;
};
static_assert(sizeof(DType) == 4, "Incorrect size of DType");

static String Int2Str(int32_t dtype) {
union {
DType dst;
int32_t src;
} converter;
converter.src = dtype;
static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"};
std::ostringstream os;
os << type_code_tab[converter.dst.code];
os << static_cast<int>(converter.dst.bits);
if (converter.dst.lanes != 1) {
os << "x" << static_cast<int>(converter.dst.lanes);
}
return os.str();
}

static int32_t DataType2Int(const tvm::DataType& dtype) {
union {
DType src;
int32_t dst;
} converter;
converter.src.code = dtype.code();
converter.src.bits = dtype.bits();
converter.src.lanes = dtype.lanes();
return converter.dst;
}

TTable data_;
};

class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>,
public StmtFunctor<TResult(const Stmt& n)> {
public:
~FlopCounter() {}

TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); }
TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); }

TResult VisitStmt_(const IfThenElseNode* branch) override {
TResult cond = VisitExpr(branch->condition);
cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
return cond;
}

TResult VisitStmt_(const BufferStoreNode* store) override {
TResult result = VisitExpr(store->value);
for (const PrimExpr& e : store->indices) {
result += VisitExpr(e);
}
return result;
}

TResult VisitStmt_(const SeqStmtNode* seq) override {
TResult result;
for (const Stmt& stmt : seq->seq) {
result += VisitStmt(stmt);
}
return result;
}

TResult VisitStmt_(const BlockRealizeNode* block) override {
return VisitStmt(block->block->body);
}

TResult VisitStmt_(const BlockNode* block) override {
TResult result;
if (block->init.defined()) {
result += VisitStmt(block->init.value());
}
result += VisitStmt(block->body);
return result;
}

TResult VisitStmt_(const ForNode* loop) override {
TResult result = VisitStmt(loop->body);
const auto* int_imm = loop->extent.as<IntImmNode>();
ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: "
<< loop->extent->GetTypeKey();
result *= int_imm->value;
return result;
}

#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \
TResult VisitExpr_(const Node* op) final { \
TResult result(op->dtype); \
result += VisitExpr(op->a); \
result += VisitExpr(op->b); \
return result; \
}
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode);
TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode);
#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY
TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); }
TResult VisitExpr_(const VarNode* op) override { return TResult(); }
TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
TResult VisitExpr_(const NotNode* op) override {
TResult result(op->dtype);
result += VisitExpr(op->a);
return result;
}
TResult VisitExpr_(const SelectNode* op) override {
TResult cond = VisitExpr(op->condition);
cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
return cond;
}
TResult VisitExpr_(const CallNode* op) override {
TResult ret;
for (const auto& x : op->args) {
ret += VisitExpr(x);
}
return ret;
}
};
FlopCounter counter;
TResult result;
for (const auto& kv : mod->functions) {
const BaseFunc& base_func = kv.second;
if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
result += counter.VisitStmt(prim_func->body);
}
}
double cnt = 0.0;
int i32 = TResult::DataType2Int(tvm::DataType::Int(32));
int i64 = TResult::DataType2Int(tvm::DataType::Int(64));
int u1 = TResult::DataType2Int(tvm::DataType::UInt(1));
for (const auto& kv : result.data_) {
if (kv.first != i32 && kv.first != i64 && kv.first != u1) {
cnt += kv.second;
}
}
return cnt;
}

} // namespace tir
} // namespace tvm

namespace tvm {
namespace meta_schedule {

Expand Down Expand Up @@ -312,7 +114,7 @@ class EchoStatisticsNode : public MeasureCallbackNode {
for (const TuneContext& task : tasks) {
task_info.push_back(TaskInfo(GetTaskName(task, task_id)));
TaskInfo& info = task_info.back();
info.flop = tir::CountFlop(task->mod.value());
info.flop = tir::EstimateTIRFlops(task->mod.value());
++task_id;
}
}
Expand Down
Loading

0 comments on commit 9f75fe7

Please sign in to comment.