diff --git a/include/tvm/arithmetic.h b/include/tvm/arithmetic.h index 1c56629a46d6f..d3f3e15344cc9 100644 --- a/include/tvm/arithmetic.h +++ b/include/tvm/arithmetic.h @@ -75,8 +75,7 @@ class ConstIntBoundAnalyzer { * \return the result of the analysis. */ ConstIntBound operator()(const Expr& expr); - /*! \brief reset and clear all internal states. */ - void Reset(); + /*! * \brief Update constant int bound information of var. * @@ -87,6 +86,13 @@ class ConstIntBoundAnalyzer { void Update(const Var& var, const ConstIntBound& info, bool override = false); + /*! + * \brief Bind variable to a range. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const Var& var, const Range& range); private: friend class Analyzer; @@ -244,7 +250,17 @@ class Analyzer { * \param var The variable. * \param expr The expression we bind to. */ - void Bind(const Var& var, const Expr& expr); + void Bind(const VarExpr& var, const Expr& expr); + /*! + * \brief Notify all the sub-analyzers that var + * is created and binded to a range. + * + * Each var can only be binded once. + * + * \param var The variable. + * \param range The range we bind to. + */ + void Bind(const VarExpr& var, const Range& range); /*! * \brief Whether can we proof expr >= val. @@ -513,23 +529,6 @@ IntSet DeduceBound(Expr v, Expr cond, */ Domain DomainTouched(Stmt body, const Tensor &tensor, bool consider_calls, bool consider_provides); -// Temporary entry for modular -// TODO(tqchen) use Analyzer. -struct ModularEntry { - int64_t coeff{1}; - int64_t base{0}; -}; - -/*! - * \brief Evaluate the expression with modular analysis - * \param e The expression to be evaluated. - * \param mod_map Map of modular statistics of known variables. - * \return The ModularEntry covering all possible value of e. - */ -ModularEntry EvalModular( - const Expr& e, - const std::unordered_map& mod_map); - // implementation inline const IntSetNode* IntSet::operator->() const { return static_cast(node_.get()); diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc index 004af09861519..cba70370f5b6d 100644 --- a/src/api/api_arith.cc +++ b/src/api/api_arith.cc @@ -100,7 +100,12 @@ TVM_REGISTER_API("arith._CreateAnalyzer") }); } else if (name == "bind") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { - self->Bind(args[0], args[1]); + auto& sptr = args[1].node_sptr(); + if (sptr->is_type()) { + self->Bind(args[0], args[1].operator Range()); + } else { + self->Bind(args[0], args[1].operator Expr()); + } }); } else if (name == "enter_constraint_context") { return PackedFunc([self](TVMArgs args, TVMRetValue *ret) { diff --git a/src/arithmetic/analyzer.cc b/src/arithmetic/analyzer.cc index 382b9f8b40de7..236a21ba71f50 100644 --- a/src/arithmetic/analyzer.cc +++ b/src/arithmetic/analyzer.cc @@ -12,11 +12,18 @@ Analyzer::Analyzer() modular_set(this) { } -void Analyzer::Bind(const Var& var, const Expr& expr) { +void Analyzer::Bind(const VarExpr& v, const Expr& expr) { + Var var(v.node_); this->const_int_bound.Update(var, this->const_int_bound(expr)); this->modular_set.Update(var, this->modular_set(expr)); } +void Analyzer::Bind(const VarExpr& v, const Range& range) { + Var var(v.node_); + this->const_int_bound.Bind(var, range); + // skip modular_set +} + ConstraintContext::ConstraintContext(Analyzer* analyzer, const Expr& constraint) { // entering the scope. auto f0 = analyzer->const_int_bound.EnterConstraint(constraint); diff --git a/src/arithmetic/const_int_bound.cc b/src/arithmetic/const_int_bound.cc index 4b1c7623e1925..c83be8933b55a 100644 --- a/src/arithmetic/const_int_bound.cc +++ b/src/arithmetic/const_int_bound.cc @@ -42,13 +42,28 @@ struct ConstIntBoundAnalyzer::Entry { class ConstIntBoundAnalyzer::Impl : public ExprFunctor { public: + void Bind(const Var& var, const Range& range) { + Entry a = VisitExpr(range->min); + Entry b = VisitExpr(range->extent); + Entry ret; + ret.min_value = a.min_value; + ret.max_value = InfAwareAdd(a.max_value, InfAwareAdd(b.max_value, -1)); + Update(var, ret, false); + } + void Update(const Var& var, - const ConstIntBound& info, + const Entry& info, bool override) { if (!override) { CHECK(!var_map_.count(var)); } - var_map_[var] = MakeBound(info->min_value, info->max_value); + var_map_[var] = info; + } + + void Update(const Var& var, + const ConstIntBound& info, + bool override) { + Update(var, MakeBound(info->min_value, info->max_value), override); } // Override visitor behaviors @@ -358,6 +373,10 @@ void ConstIntBoundAnalyzer::Update(const Var& var, impl_->Update(var, info, override); } +void ConstIntBoundAnalyzer::Bind(const Var& var, const Range& range) { + impl_->Bind(var, range); +} + std::function ConstIntBoundAnalyzer::EnterConstraint(const Expr& constraint) { return nullptr; } diff --git a/src/arithmetic/modular_set.cc b/src/arithmetic/modular_set.cc index f79d0cba50689..8da6e91fc7fa0 100644 --- a/src/arithmetic/modular_set.cc +++ b/src/arithmetic/modular_set.cc @@ -340,22 +340,5 @@ ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; } - -ModularEntry EvalModular( - const Expr& e, - const std::unordered_map& mod_map) { - Analyzer ana; - for (const auto& kv : mod_map) { - auto v = kv.second; - ana.modular_set.Update( - GetRef(kv.first), ModularSetNode::make(v.coeff, v.base)); - } - auto mod = ana.modular_set(e); - ModularEntry ret; - ret.coeff = mod->coeff; - ret.base = mod->base; - return ret; -} - } // namespace arith } // namespace tvm diff --git a/src/codegen/codegen_common.h b/src/codegen/codegen_common.h deleted file mode 100644 index 5e76af12e5834..0000000000000 --- a/src/codegen/codegen_common.h +++ /dev/null @@ -1,59 +0,0 @@ -/*! - * Copyright (c) 2018 by Contributors - * \file codegen_common.h - * \brief Common utility for codegen. - */ -#ifndef TVM_CODEGEN_CODEGEN_COMMON_H_ -#define TVM_CODEGEN_CODEGEN_COMMON_H_ - -#include -#include "../arithmetic/compute_expr.h" - -namespace tvm { -namespace codegen { - -/*! - * \brief Visit AssertStmt recursively, update align_map from condition. - * \param op The AssertStmt - * \param align_map The alignmap - * \param fvisit The recursive visitor - * \tparam FVisit the recursive visitor - */ -template -inline void VisitAssert( - const ir::AssertStmt* op, - std::unordered_map* align_map, - FVisit fvisit) { - using namespace ir; - auto& align_map_ = *align_map; - // Detect useful invariant pattern and use them to visit child. - // Pattern: Var % const == 0 - // TODO(tqchen) merge these pattern to a generic scope info visitor. - if (const EQ* eq = op->condition.as()) { - const Mod* mod = eq->a.as(); - int64_t factor = 0, offset = 0; - if (mod && arith::GetConst(eq->b, &offset)) { - const Variable *var = mod->a.as(); - if (var && arith::GetConst(mod->b, &factor)) { - arith::ModularEntry old = align_map_[var]; - if (factor > old.coeff) { - arith::ModularEntry e; - e.coeff = static_cast(factor); - e.base = static_cast(offset); - // new alignment info, - align_map_[var] = e; - fvisit(op->body); - // restore old info - align_map_[var] = old; - return; - } - } - } - } - fvisit(op->body); -} - -} // namespace codegen -} // namespace tvm - -#endif // TVM_CODEGEN_CODEGEN_COMMON_H_ diff --git a/src/codegen/llvm/codegen_llvm.cc b/src/codegen/llvm/codegen_llvm.cc index f80bd9e8d4360..6b69f97a66fe5 100644 --- a/src/codegen/llvm/codegen_llvm.cc +++ b/src/codegen/llvm/codegen_llvm.cc @@ -9,7 +9,6 @@ #include #include "codegen_llvm.h" #include "codegen_cpu.h" -#include "../codegen_common.h" #include "../../pass/ir_util.h" #include "../../arithmetic/compute_expr.h" @@ -84,9 +83,9 @@ void CodeGenLLVM::AddFunction(const LoweredFunc& f) { void CodeGenLLVM::InitFuncState() { var_map_.clear(); alias_var_set_.clear(); - align_map_.clear(); alloc_storage_info_.clear(); volatile_buf_.clear(); + analyzer_.reset(new arith::Analyzer()); } void CodeGenLLVM::AddFunctionInternal(const LoweredFunc& f, bool ret_void) { @@ -381,14 +380,16 @@ void CodeGenLLVM::GetAlignment(Type t, *p_native_bits = native_vector_bits_; } - arith::ModularEntry me = arith::EvalModular(index, align_map_); + arith::ModularSet me = analyzer_->modular_set(index); + int64_t base = me->base; + int64_t coeff = me->coeff; int align_bits = t.bits(); while (align_bits < max_align_bits && - me.base % 2 == 0 && - me.coeff % 2 == 0) { - me.base = me.base / 2; - me.coeff = me.coeff / 2; + base % 2 == 0 && + coeff % 2 == 0) { + base = base / 2; + coeff = coeff / 2; align_bits *= 2; } if (align_bits < 8) { @@ -874,7 +875,7 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const Select* op) { llvm::Value* CodeGenLLVM::VisitExpr_(const Let* op) { CHECK(!var_map_.count(op->var.get())); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -998,6 +999,7 @@ void CodeGenLLVM::VisitStmt_(const Store* op) { void CodeGenLLVM::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); if (op->for_type == ForType::Unrolled) { LOG(WARNING) << "Unroll hint get ignore at CodeGenLLVM backend, " << " consider set unroll_explicit=True"; @@ -1078,6 +1080,7 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } } else if (op->attr_key == ir::attr::storage_scope) { @@ -1099,21 +1102,19 @@ void CodeGenLLVM::VisitStmt_(const AttrStmt* op) { } void CodeGenLLVM::VisitStmt_(const AssertStmt* op) { - VisitAssert(op, &align_map_, [this](const Stmt& body) { - this->VisitStmt(body); - }); + arith::ConstraintContext cctx(analyzer_.get(), op->condition); + this->VisitStmt(op->body); } void CodeGenLLVM::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - CHECK(!align_map_.count(op->var.get())); if (op->var.type().is_handle()) { if (!is_restricted_) { alias_var_set_.insert(op->var.get()); } } var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); } diff --git a/src/codegen/llvm/codegen_llvm.h b/src/codegen/llvm/codegen_llvm.h index 0803063103709..ead1af883166c 100644 --- a/src/codegen/llvm/codegen_llvm.h +++ b/src/codegen/llvm/codegen_llvm.h @@ -23,7 +23,6 @@ namespace codegen { using namespace ir; - /*! * \brief A base class to generate a LLVM. */ @@ -267,8 +266,8 @@ class CodeGenLLVM : std::unordered_map str_map_; // Whether current function is restricted bool is_restricted_{true}; - // The alignment information - std::unordered_map align_map_; + // The analyzer information + std::unique_ptr analyzer_; // set of var that are not restricted(can alias) std::unordered_set alias_var_set_; // set of volatile buffer. diff --git a/src/codegen/spirv/codegen_spirv.cc b/src/codegen/spirv/codegen_spirv.cc index 812fee4a114e0..8b1cabd9e386d 100644 --- a/src/codegen/spirv/codegen_spirv.cc +++ b/src/codegen/spirv/codegen_spirv.cc @@ -6,7 +6,7 @@ #include #include #include -#include "../codegen_common.h" +#include "../../arithmetic/compute_expr.h" #include "codegen_spirv.h" namespace tvm { @@ -66,7 +66,7 @@ void CodeGenSPIRV::InitFuncState() { std::fill(workgroup_size_, workgroup_size_ + 3, 1); var_map_.clear(); storage_info_.clear(); - align_map_.clear(); + analyzer_.reset(new arith::Analyzer()); builder_.reset(new spirv::IRBuilder()); builder_->InitHeader(); } @@ -217,7 +217,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Select* op) { spirv::Value CodeGenSPIRV::VisitExpr_(const Let* op) { CHECK(!var_map_.count(op->var.get())); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); return MakeValue(op->body); } @@ -378,9 +378,9 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const Load* op) { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->type.lanes()); - arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); - CHECK((me.coeff % ramp->lanes) == 0 && - (me.base % ramp->lanes) == 0) + arith::ModularSet me = analyzer_->modular_set(ramp->base); + CHECK((me->coeff % ramp->lanes) == 0 && + (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.type(), ramp->lanes)); @@ -458,9 +458,9 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { if (const Ramp* ramp = op->index.as()) { if (is_one(ramp->stride)) { CHECK_EQ(ramp->lanes, op->value.type().lanes()); - arith::ModularEntry me = arith::EvalModular(ramp->base, align_map_); - CHECK((me.coeff % ramp->lanes) == 0 && - (me.base % ramp->lanes) == 0) + arith::ModularSet me = analyzer_->modular_set(ramp->base); + CHECK((me->coeff % ramp->lanes) == 0 && + (me->base % ramp->lanes) == 0) << "Only aligned vector access is allowed in SPIRV"; Expr vec_index = ir::Simplify( ramp->base / make_const(ramp->base.type(), ramp->lanes)); @@ -477,6 +477,7 @@ void CodeGenSPIRV::VisitStmt_(const Store* op) { void CodeGenSPIRV::VisitStmt_(const For* op) { CHECK(is_zero(op->min)); + analyzer_->Bind(op->loop_var, Range::make_by_min_extent(op->min, op->extent)); spirv::Value init_value = MakeValue(op->min); spirv::Value extent_value = MakeValue(op->extent); // Must get init label after making value(to make sure they are correct) @@ -589,6 +590,7 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { if (iv->thread_tag.length() != 0) { if (!var_map_.count(iv->var.get())) { var_map_[iv->var.get()] = GetThreadIndex(iv, op->value); + analyzer_->Bind(iv->var, Range::make_by_min_extent(0, op->value)); } } } else if (op->attr_key == ir::attr::storage_scope) { @@ -605,17 +607,15 @@ void CodeGenSPIRV::VisitStmt_(const AttrStmt* op) { } void CodeGenSPIRV::VisitStmt_(const AssertStmt* op) { - VisitAssert(op, &align_map_, [this](const Stmt& body) { - this->VisitStmt(body); - }); + arith::ConstraintContext cctx(analyzer_.get(), op->condition); + this->VisitStmt(op->body); } void CodeGenSPIRV::VisitStmt_(const LetStmt* op) { CHECK(!var_map_.count(op->var.get())); - CHECK(!align_map_.count(op->var.get())); CHECK(!op->var.type().is_handle()); var_map_[op->var.get()] = MakeValue(op->value); - align_map_[op->var.get()] = EvalModular(op->value, align_map_); + analyzer_->Bind(op->var, op->value); this->VisitStmt(op->body); } diff --git a/src/codegen/spirv/codegen_spirv.h b/src/codegen/spirv/codegen_spirv.h index 6a43182f7f2e3..94cf761b9f847 100644 --- a/src/codegen/spirv/codegen_spirv.h +++ b/src/codegen/spirv/codegen_spirv.h @@ -122,8 +122,8 @@ class CodeGenSPIRV: std::unordered_map storage_info_; // The definition of local variable. std::unordered_map var_map_; - // The alignment information - std::unordered_map align_map_; + // The analyzer. + std::unique_ptr analyzer_; }; } // namespace codegen