diff --git a/include/tvm/tir/analysis.h b/include/tvm/tir/analysis.h index 51bdb18d22174..8306cb173e0af 100644 --- a/include/tvm/tir/analysis.h +++ b/include/tvm/tir/analysis.h @@ -72,6 +72,20 @@ inline void VisitPrimFuncs(const IRModule& mod, FLambda fvisit) { } } +/*! + * \brief Estimate the FLOPs of a TIR fragment. + * \param stmt The TIR fragment to be estimated. + * \return The estimated FLOPs. + */ +TVM_DLL double EstimateTIRFlops(const Stmt& stmt); + +/*! + * \brief Estimate the FLOPs of TIRs in an IRModule. + * \param mod The IRModule to be estimated. + * \return The estimated FLOPs. + */ +TVM_DLL double EstimateTIRFlops(const IRModule& mod); + /*! * \brief Find undefined vars in the statement. * \param stmt The function to be checked. diff --git a/python/tvm/tir/analysis/analysis.py b/python/tvm/tir/analysis/analysis.py index c2338dd9b6111..0e91f88413135 100644 --- a/python/tvm/tir/analysis/analysis.py +++ b/python/tvm/tir/analysis/analysis.py @@ -16,15 +16,16 @@ # under the License. """Wrapping existing analysis utils.""" # pylint: disable=invalid-name -from typing import Dict, List +from typing import Dict, List, Union from tvm import Object -from tvm.tir.stmt import Block, BufferRegion -from tvm.tir.stmt import PrimExpr +from tvm.ir import IRModule from tvm.tir.expr import Var -from . import _ffi_api -from ..function import PrimFunc +from tvm.tir.stmt import Block, BufferRegion, PrimExpr + from .. import Buffer, Stmt +from ..function import PrimFunc +from . import _ffi_api def expr_deep_equal(lhs: PrimExpr, rhs: PrimExpr) -> bool: @@ -199,6 +200,22 @@ def detect_buffer_access_lca(func: PrimFunc) -> Dict[Buffer, Stmt]: return _ffi_api.detect_buffer_access_lca(func) # type: ignore # pylint: disable=no-member +def estimate_tir_flops(stmt_or_mod: Union[Stmt, IRModule]) -> float: + """Estimate the FLOPs of a TIR fragment. + + Parameters + ---------- + stmt_or_mod: Union[Stmt, IRModule] + The TIR fragment or IRModule to be estimated. + + Returns + ------- + flops: float + The estimated FLOPs. + """ + return _ffi_api.EstimateTIRFlops(stmt_or_mod) # type: ignore # pylint: disable=no-member + + # NOTE: relay_func_type in the following two functions should be relay.FuncType however that would # introduce a cycling dependency. We make do with Object. diff --git a/src/meta_schedule/measure_callback/echo_statistics.cc b/src/meta_schedule/measure_callback/echo_statistics.cc index 1504e77c299f4..ae7a4826c947e 100644 --- a/src/meta_schedule/measure_callback/echo_statistics.cc +++ b/src/meta_schedule/measure_callback/echo_statistics.cc @@ -20,204 +20,6 @@ #include "../utils.h" -namespace tvm { -namespace tir { - -double CountFlop(const IRModule& mod) { - struct TResult { - using TTable = std::unordered_map; - - TResult() = default; - - explicit TResult(const tvm::DataType& dtype) { Add(dtype); } - - void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } - - TResult operator+=(const TResult& rhs) { - for (const auto& kv : rhs.data_) { - data_[kv.first] += kv.second; - } - return *this; - } - - TResult operator*=(int64_t rhs) { - for (auto& kv : data_) { - kv.second *= rhs; - } - return *this; - } - - TResult MaxWith(const TResult& rhs) { - for (const auto& kv : rhs.data_) { - double& v = data_[kv.first]; - if (v < kv.second) { - v = kv.second; - } - } - return *this; - } - - struct DType { - uint8_t code : 8; - uint8_t bits : 8; - uint16_t lanes : 16; - }; - static_assert(sizeof(DType) == 4, "Incorrect size of DType"); - - static String Int2Str(int32_t dtype) { - union { - DType dst; - int32_t src; - } converter; - converter.src = dtype; - static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"}; - std::ostringstream os; - os << type_code_tab[converter.dst.code]; - os << static_cast(converter.dst.bits); - if (converter.dst.lanes != 1) { - os << "x" << static_cast(converter.dst.lanes); - } - return os.str(); - } - - static int32_t DataType2Int(const tvm::DataType& dtype) { - union { - DType src; - int32_t dst; - } converter; - converter.src.code = dtype.code(); - converter.src.bits = dtype.bits(); - converter.src.lanes = dtype.lanes(); - return converter.dst; - } - - TTable data_; - }; - - class FlopCounter : public ExprFunctor, - public StmtFunctor { - public: - ~FlopCounter() {} - - TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } - TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } - - TResult VisitStmt_(const IfThenElseNode* branch) override { - TResult cond = VisitExpr(branch->condition); - cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); - return cond; - } - - TResult VisitStmt_(const BufferStoreNode* store) override { - TResult result = VisitExpr(store->value); - for (const PrimExpr& e : store->indices) { - result += VisitExpr(e); - } - return result; - } - - TResult VisitStmt_(const SeqStmtNode* seq) override { - TResult result; - for (const Stmt& stmt : seq->seq) { - result += VisitStmt(stmt); - } - return result; - } - - TResult VisitStmt_(const BlockRealizeNode* block) override { - return VisitStmt(block->block->body); - } - - TResult VisitStmt_(const BlockNode* block) override { - TResult result; - if (block->init.defined()) { - result += VisitStmt(block->init.value()); - } - result += VisitStmt(block->body); - return result; - } - - TResult VisitStmt_(const ForNode* loop) override { - TResult result = VisitStmt(loop->body); - const auto* int_imm = loop->extent.as(); - ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " - << loop->extent->GetTypeKey(); - result *= int_imm->value; - return result; - } - -#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \ - TResult VisitExpr_(const Node* op) final { \ - TResult result(op->dtype); \ - result += VisitExpr(op->a); \ - result += VisitExpr(op->b); \ - return result; \ - } - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode); - TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode); -#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY - TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } - TResult VisitExpr_(const VarNode* op) override { return TResult(); } - TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } - TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } - TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } - TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } - TResult VisitExpr_(const NotNode* op) override { - TResult result(op->dtype); - result += VisitExpr(op->a); - return result; - } - TResult VisitExpr_(const SelectNode* op) override { - TResult cond = VisitExpr(op->condition); - cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); - return cond; - } - TResult VisitExpr_(const CallNode* op) override { - TResult ret; - for (const auto& x : op->args) { - ret += VisitExpr(x); - } - return ret; - } - }; - FlopCounter counter; - TResult result; - for (const auto& kv : mod->functions) { - const BaseFunc& base_func = kv.second; - if (const auto* prim_func = base_func.as()) { - result += counter.VisitStmt(prim_func->body); - } - } - double cnt = 0.0; - int i32 = TResult::DataType2Int(tvm::DataType::Int(32)); - int i64 = TResult::DataType2Int(tvm::DataType::Int(64)); - int u1 = TResult::DataType2Int(tvm::DataType::UInt(1)); - for (const auto& kv : result.data_) { - if (kv.first != i32 && kv.first != i64 && kv.first != u1) { - cnt += kv.second; - } - } - return cnt; -} - -} // namespace tir -} // namespace tvm - namespace tvm { namespace meta_schedule { @@ -312,7 +114,7 @@ class EchoStatisticsNode : public MeasureCallbackNode { for (const TuneContext& task : tasks) { task_info.push_back(TaskInfo(GetTaskName(task, task_id))); TaskInfo& info = task_info.back(); - info.flop = tir::CountFlop(task->mod.value()); + info.flop = tir::EstimateTIRFlops(task->mod.value()); ++task_id; } } diff --git a/src/tir/analysis/estimate_flops.cc b/src/tir/analysis/estimate_flops.cc new file mode 100644 index 0000000000000..839969f985e71 --- /dev/null +++ b/src/tir/analysis/estimate_flops.cc @@ -0,0 +1,201 @@ +#include +#include + +namespace tvm { +namespace tir { + +int32_t DataType2Int(const tvm::DataType& dtype) { + static_assert(sizeof(DLDataType) == sizeof(int32_t), "Incorrect size of DLDataType"); + union { + DLDataType src; + int32_t dst; + } converter; + converter.src.code = dtype.code(); + converter.src.bits = dtype.bits(); + converter.src.lanes = dtype.lanes(); + return converter.dst; +} + +String Int2DataTypeStr(int32_t dtype) { + union { + DLDataType dst; + int32_t src; + } converter; + converter.src = dtype; + static std::string type_code_tab[] = {"int", "uint", "float", "handle", "bfloat"}; + std::ostringstream os; + os << type_code_tab[converter.dst.code]; + os << static_cast(converter.dst.bits); + if (converter.dst.lanes != 1) { + os << "x" << static_cast(converter.dst.lanes); + } + return os.str(); +} + +struct TResult { + TResult() = default; + + void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; } + + TResult operator+=(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + data_[kv.first] += kv.second; + } + return *this; + } + + TResult operator*=(int64_t rhs) { + for (auto& kv : data_) { + kv.second *= rhs; + } + return *this; + } + + TResult MaxWith(const TResult& rhs) { + for (const auto& kv : rhs.data_) { + double& v = data_[kv.first]; + if (v < kv.second) { + v = kv.second; + } + } + return *this; + } + + std::unordered_map data_; +}; + +class FlopEstimator : public ExprFunctor, + public StmtFunctor { + public: + TResult VisitExpr(const PrimExpr& expr) override { return ExprFunctor::VisitExpr(expr); } + TResult VisitStmt(const Stmt& stmt) override { return StmtFunctor::VisitStmt(stmt); } + +#define TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(Node) \ + TResult VisitExpr_(const Node* op) final { \ + TResult result = VisitExpr(op->a); \ + result += VisitExpr(op->b); \ + result.Add(op->dtype); \ + return result; \ + } + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(AddNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(SubNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MulNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(DivNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(ModNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorDivNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(FloorModNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MinNode); + TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY(MaxNode); +#undef TVM_TIR_ESTIMATE_FLOP_VISIT_BINARY + TResult VisitExpr_(const EQNode* op) override { return TResult(); } + TResult VisitExpr_(const NENode* op) override { return TResult(); } + TResult VisitExpr_(const LTNode* op) override { return TResult(); } + TResult VisitExpr_(const LENode* op) override { return TResult(); } + TResult VisitExpr_(const GTNode* op) override { return TResult(); } + TResult VisitExpr_(const GENode* op) override { return TResult(); } + + TResult VisitExpr_(const NotNode* op) override { return VisitExpr(op->a); } + TResult VisitExpr_(const AndNode* op) final { + TResult result = VisitExpr(op->a); + result += VisitExpr(op->b); + return result; + } + TResult VisitExpr_(const OrNode* op) final { + TResult result = VisitExpr(op->a); + result += VisitExpr(op->b); + return result; + } + + TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); } + TResult VisitStmt_(const BufferStoreNode* store) override { return VisitExpr(store->value); } + TResult VisitStmt_(const BlockRealizeNode* block) override { + return VisitStmt(block->block->body); + } + TResult VisitStmt_(const BlockNode* block) override { + TResult result; + if (block->init.defined()) { + result += VisitStmt(block->init.value()); + } + result += VisitStmt(block->body); + return result; + } + TResult VisitStmt_(const ForNode* loop) override { + TResult result = VisitStmt(loop->body); + const auto* int_imm = loop->extent.as(); + ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm, but gets: " + << loop->extent->GetTypeKey(); + result *= int_imm->value; + return result; + } + + TResult VisitStmt_(const IfThenElseNode* branch) override { + TResult cond = VisitExpr(branch->condition); + cond += VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case)); + return cond; + } + + TResult VisitExpr_(const SelectNode* op) override { + TResult cond = VisitExpr(op->condition); + cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value)); + return cond; + } + + TResult VisitExpr_(const VarNode* op) override { return TResult(); } + TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); } + TResult VisitExpr_(const IntImmNode* op) override { return TResult(); } + TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); } + TResult VisitExpr_(const CastNode* op) override { return VisitExpr(op->value); } + + TResult VisitStmt_(const SeqStmtNode* seq) override { + TResult result; + for (const Stmt& stmt : seq->seq) { + result += VisitStmt(stmt); + } + return result; + } + + TResult VisitExpr_(const CallNode* op) override { + TResult ret; + for (const auto& x : op->args) { + ret += VisitExpr(x); + } + return ret; + } +}; + +double PostprocessResults(const TResult& result) { + double cnt = 0.0; + for (const auto& kv : result.data_) { + cnt += kv.second; + } + return cnt; +} + +double EstimateTIRFlops(const Stmt& stmt) { + FlopEstimator counter; + return PostprocessResults(counter(stmt)); +} + +double EstimateTIRFlops(const IRModule& mod) { + FlopEstimator counter; + TResult result; + VisitPrimFuncs(mod, [&result, &counter](const PrimFuncNode* f) { + result += counter.VisitStmt(f->body); // + }); + return PostprocessResults(result); +} + +TVM_REGISTER_GLOBAL("tir.analysis.EstimateTIRFlops").set_body_typed([](ObjectRef obj) -> double { + if (const auto* mod = obj.as()) { + return EstimateTIRFlops(GetRef(mod)); + } else if (const auto* stmt = obj.as()) { + return EstimateTIRFlops(GetRef(stmt)); + } else { + LOG(FATAL) << "TypeError: Expect the input to be either IRModule or Stmt, but gets: " + << obj->GetTypeKey(); + throw; + } +}); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py new file mode 100644 index 0000000000000..a516f07473f04 --- /dev/null +++ b/tests/python/unittest/test_tir_analysis_estimate_tir_flops.py @@ -0,0 +1,51 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import sys + +import pytest +from tvm.ir import IRModule +from tvm.meta_schedule.testing.te_workload import create_te_workload +from tvm.tir.analysis import estimate_tir_flops + + +@pytest.mark.parametrize( + "workload, flops", + [ + ("C1D", 6291456), + ("C2D", 236027904), + ("C3D", 13217562624), + ("CAP", 75497472), + ("DEP", 7225344), + ("DIL", 223552896), + ("GMM", 4194304), + ("GRP", 28901376), + ("T2D", 268435456), + ("C2d-BN-RELU", 239239168), + ("TBG", 25165824), + ("NRM", 131072), + ("SFM", 262144), + ], +) +def test_te_workload(workload, flops): + te_workload = create_te_workload(workload, 0) + mod = IRModule({"main": te_workload}) + assert float(flops) == estimate_tir_flops(mod) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))