From 337f7d10ac46c455b344c8b50b005c138f6e724a Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 24 Mar 2019 20:50:09 -0700 Subject: [PATCH] Remove old canonical --- include/tvm/base.h | 1 + src/arithmetic/canonical.cc | 933 --------------------------- src/arithmetic/canonical.h | 56 -- src/arithmetic/canonical_simplify.cc | 2 +- src/arithmetic/rewrite_simplify.h | 1 + src/arithmetic/stmt_simplify.cc | 11 +- 6 files changed, 13 insertions(+), 991 deletions(-) delete mode 100644 src/arithmetic/canonical.cc delete mode 100644 src/arithmetic/canonical.h diff --git a/include/tvm/base.h b/include/tvm/base.h index 77b90b003f240..863bde52e2a5d 100644 --- a/include/tvm/base.h +++ b/include/tvm/base.h @@ -12,6 +12,7 @@ #include #include #include +#include #include "runtime/registry.h" namespace tvm { diff --git a/src/arithmetic/canonical.cc b/src/arithmetic/canonical.cc deleted file mode 100644 index a884fc0ca4542..0000000000000 --- a/src/arithmetic/canonical.cc +++ /dev/null @@ -1,933 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file canonical.cc - * \brief Canonicalize simplification. - */ -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include "canonical.h" -#include "compute_expr.h" -#include "arithmetic/Simplify.h" - -namespace tvm { -namespace arith { -using namespace ir; - -// Canonical entry for communicative ops. -struct ComExprEntry { - // the value of the expression. - Expr value; - // the level of the expression. - int level{0}; - // The integer scale on value - int64_t scale{1}; - - ComExprEntry() {} - ComExprEntry(Expr value, int level) - : value(value), level(level) {} - inline bool operator<(const ComExprEntry& other) const { - if (level < other.level) return true; - if (level > other.level) return false; - // compare top operator of entries and sort on that if possible (fast check) - if (value.type_index() < other.value.type_index()) return true; - if (value.type_index() > other.value.type_index()) return false; - // if none of the above distinguishes the terms, compare the expression tree of the entries. - // This is a slower check. - int compare_result = Compare(value, other.value); - if (compare_result < 0) return true; - if (compare_result > 0) return false; - // it's a problem if we see identical entries at this point. They should've been merged earlier. - LOG(WARNING) << "we should not have identical entries at this point"; - return false; - } -}; - -// canonical expression for communicative expression. -struct ComExprNode : public NodeBase { - // base constant value. - int64_t base{0}; - // The values to be sumed. - std::vector elem; -}; - -// canonical communicative expression -struct ComExpr { - public: - // constructor - ComExpr() {} - explicit ComExpr(NodePtr ptr) : ptr_(ptr) {} - // get member - ComExprNode* operator->() const { - return ptr_.get(); - } - void reset() { - ptr_.reset(); - } - bool defined() const { - return ptr_.get() != nullptr; - } - // comparator - bool operator<(const ComExpr& b) const { - const ComExpr& a = *this; - if (a->base < b->base) return true; - if (a->base > b->base) return false; - if (a->elem.size() < b->elem.size()) return true; - if (a->elem.size() > b->elem.size()) return false; - for (size_t i = 0; i < a->elem.size(); ++i) { - const ComExprEntry& ea = a->elem[i]; - const ComExprEntry& eb = b->elem[i]; - if (ea.level < eb.level) return true; - if (ea.level > eb.level) return false; - if (ea.value.get() < eb.value.get()) return true; - if (ea.value.get() > eb.value.get()) return false; - if (ea.scale < eb.scale) return true; - if (ea.scale > eb.scale) return false; - } - return false; - } - // equality - bool operator==(const ComExpr& b) const { - const ComExpr& a = *this; - if (a->base != b->base) return false; - if (a->elem.size() != b->elem.size()) return false; - for (size_t i = 0; i < a->elem.size(); ++i) { - const ComExprEntry& ea = a->elem[i]; - const ComExprEntry& eb = b->elem[i]; - if (ea.level != eb.level) return false; - if (ea.value.get() != eb.value.get()) return false; - if (ea.scale != eb.scale) return false; - } - return true; - } - - private: - NodePtr ptr_; -}; - -// binary comparison op. -struct BinaryExpr { - int kind; - Expr lhs, rhs; - // comparator - bool operator<(const BinaryExpr& b) const { - if (kind < b.kind) return true; - if (kind > b.kind) return false; - if (lhs.get() < b.lhs.get()) return true; - if (lhs.get() > b.lhs.get()) return false; - return rhs.get() < b.rhs.get(); - } - // equality - bool operator==(const BinaryExpr& b) const { - return kind == b.kind && - lhs.same_as(b.lhs) && - rhs.same_as(b.rhs); - } -}; - - -template -inline Expr Binary_(const T* op, - const Expr& e, - Expr a, Expr b) { - if (a.same_as(op->a) && b.same_as(op->b)) { - return e; - } else { - return T::make(a, b); - } -} - -// internal of canonical engine. -class Canonical::Internal : public IRMutator { - public: - explicit Internal(Map vrange) { - for (auto kv : vrange) { - SetRange(kv.first, kv.second, 0); - } - } - // stack entry. - struct StackEntry { - int max_level{0}; - bool has_side_effect{false}; - }; - // aggressively canonicalized expression - struct CacheEntry { - // The canonical value of the expression. - Expr value; - // The level of the expression. - int max_level{0}; - // whether the expression might have side effect. - bool has_side_effect{false}; - // if not null, corresponds to to sum - ComExpr sum; - // reset the return entry. - void reset() { - sum.reset(); - } - // as sum expr - ComExpr AsSum() const { - if (sum.defined()) return sum; - const int64_t *v1 = as_const_int(value); - const uint64_t *v2 = as_const_uint(value); - auto n = make_node(); - if (v1) { - n->base = *v1; - } else if (v2) { - CHECK_LE(*v2, - static_cast(std::numeric_limits::max())); - n->base = static_cast(*v2); - } else { - n->elem.push_back(ComExprEntry(value, max_level)); - } - return ComExpr(n); - } - }; - // Set range and level of var. - void SetRange(Var v, Range r, int level) { - var_range_[v.get()] = IntSet::range(r); - var_level_[v.get()] = level; - var_rec_.push_back(v); - } - // functions - Stmt Mutate(Stmt stmt) final { - stmt = IRMutator::Mutate(stmt); - return stmt; - } - Expr MutateExpr_(Expr expr) { - stack_.push_back(StackEntry()); - expr = IRMutator::Mutate(expr); - // update result of parent automatically during pop - if (stack_.size() > 1) { - StackEntry& back = stack_[stack_.size() - 1]; - StackEntry& prev = stack_[stack_.size() - 2]; - prev.max_level = std::max(prev.max_level, back.max_level); - if (back.has_side_effect) prev.has_side_effect = true; - } - // copy result from stack - ret_entry_.has_side_effect = stack_.back().has_side_effect; - ret_entry_.max_level = stack_.back().max_level; - stack_.pop_back(); - CHECK(expr.defined()); - if (const IntImm* op = expr.as()) { - return Mutate_(op, expr); - } - return expr; - } - // call produce to get a cache entry. - CacheEntry Produce(Expr expr) { - ret_entry_.reset(); - ret_entry_.value = MutateExpr_(expr); - CacheEntry ret = ret_entry_; - ret_entry_.reset(); - return ret; - } - Expr Mutate(Expr expr) final { - ret_entry_.reset(); - expr = MutateExpr_(expr); - ret_entry_.reset(); - return expr; - } - - // Check whether do special canonicalization. - bool EnableOpt(Type t) const { - return (t.lanes() == 1 && (t.is_int() || t.is_uint())); - } - // Max - Expr Mutate_(const Max* op, const Expr& e) final { - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return Binary(op, e); - } - // Min - Expr Mutate_(const Min* op, const Expr& e) final { - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return Binary(op, e); - } - // Add - Expr Mutate_(const Add* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return SumAdd(a, b, +1); - } - // Sub - Expr Mutate_(const Sub* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - return SumAdd(a, b, -1); - } - // Mul - Expr Mutate_(const Mul* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr(a.value, b.value); - } else if (is_const(a.value)) { - return SumMulConst(b.AsSum(), a.value); - } else if (is_const(b.value)) { - return SumMulConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - // Variable - Expr Mutate_(const Variable* op, const Expr& e) final { - auto it = var_level_.find(op); - if (it != var_level_.end()) { - stack_.back().max_level = it->second; - } - return IRMutator::Mutate_(op, e); - } - // comparison - Expr Mutate_(const LT* op, const Expr& e) { - if (!EnableOpt(op->a.type())) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - Expr b_sub_a = SumAdd(b, a, -1); - if (EvalSet(b_sub_a, var_range_).can_prove_positive()) { - return make_const(op->type, true); - } else { - return Binary_(op, e, a.value, b.value); - } - } - // IntImm - Expr Mutate_(const IntImm* op, const Expr& e) final { - if (op->type != Int(32)) return e; - auto it = cache_intimm_.find(op->value); - if (it != cache_intimm_.end()) { - return it->second; - } else { - cache_intimm_[op->value] = e; - return e; - } - } - // Div operator - Expr Mutate_(const Div* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr
(a.value, b.value); - } else if (is_const(b.value)) { - return SumDivConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - // Mod operator - Expr Mutate_(const Mod* op, const Expr& e) final { - if (!EnableOpt(op->type)) { - return Binary(op, e); - } - CacheEntry a = Produce(op->a); - CacheEntry b = Produce(op->b); - if (a.has_side_effect || b.has_side_effect) { - return Binary_(op, e, a.value, b.value); - } - if (is_const(a.value) && is_const(b.value)) { - return ComputeExpr(a.value, b.value); - } else if (is_const(b.value)) { - return SumModConst(a.AsSum(), b.value); - } else { - return Binary(op, e); - } - } - - Expr Mutate_(const And* op, const Expr& e) final { - Expr expr = IRMutator::Mutate_(op, e); - op = expr.as(); - if (is_one(op->a)) return op->b; - if (is_one(op->b)) return op->a; - return expr; - } - // Call - Expr Mutate_(const Call* op, const Expr& e) final { - if (!op->is_pure()) { - stack_.back().has_side_effect = true; - } - Expr expr = IRMutator::Mutate_(op, e); - op = expr.as(); - if (op->is_intrinsic(Call::likely) && is_const(op->args[0])) { - return op->args[0]; - } else { - return expr; - } - } - // For - Stmt Mutate_(const For* op, const Stmt& s) { - ++level_counter_; - Var loop_var(op->loop_var.node_); - this->SetRange(loop_var, - Range::make_by_min_extent(op->min, op->extent), - level_counter_); - Stmt stmt = IRMutator::Mutate_(op, s); - --level_counter_; - return stmt; - } - // IfThenElse - Stmt Mutate_(const IfThenElse* op, const Stmt& s) { - Stmt stmt = IRMutator::Mutate_(op, s); - op = stmt.as(); - if (is_one(op->condition)) return op->then_case; - return stmt; - } - // AttrStmt - Stmt Mutate_(const AttrStmt* op, const Stmt& s) { - if (op->attr_key == attr::thread_extent || - op->attr_key == attr::virtual_thread) { - ++level_counter_; - IterVar iv(op->node.node_); - CHECK_NE(iv->thread_tag.length(), 0U); - if (!var_level_.count(iv->var.get())) { - this->SetRange(iv->var, - Range::make_by_min_extent(0, op->value), - level_counter_); - } - Stmt stmt = IRMutator::Mutate_(op, s); - --level_counter_; - return stmt; - } else { - return IRMutator::Mutate_(op, s); - } - } - // The simplify statement. - static FMutateExpr& vtable_expr() { // NOLINT(*) - static FMutateExpr inst; return inst; - } - - private: - template - Expr Binary(const T* op, Expr e) { - Expr a = this->Mutate(op->a); - Expr b = this->Mutate(op->b); - BinaryExpr key{static_cast(T::_type_info), a, b}; - auto it = cache_binary_.find(key); - if (it != cache_binary_.end()) { - return it->second; - } else { - Expr ret = Binary_(op, e, a, b); - cache_binary_[key] = ret; - return ret; - } - } - // return entry - CacheEntry ret_entry_; - // internal information stack - std::vector stack_; - // cache sum - std::map cache_sum_; - // cache of normal binary op - std::map cache_binary_; - // cache of int constant - std::unordered_map cache_intimm_; - // range of each var - std::unordered_map var_range_; - // level of each var - std::unordered_map var_level_; - // record history vars, to avoid false positive. - std::vector var_rec_; - // level counter - int level_counter_{0}; - // get constant int value - int64_t GetConstIntValue(const Expr& v) { - int64_t value = 0; - const int64_t *v1 = as_const_int(v); - const uint64_t *v2 = as_const_uint(v); - CHECK(v1 || v2); - if (v1) { - value = *v1; - } else if (v2) { - CHECK_LE(*v2, - static_cast(std::numeric_limits::max())); - value = static_cast(*v2); - } - return value; - } - // Detect if a = q * coeff + r, where r \in [0, coeff), coeff > 0 - // (in Euclidean division) - // returns pair (q, r) if such detection is successful - // returns empty vector otherwise. - // Assumes that coeff is a constant integer - std::vector TryLinearEquation(const ComExpr& a, - const Expr& coeff) { - Type type = coeff.type(); - int64_t value = GetConstIntValue(coeff); - CHECK_NE(value, 0); - if (value < 0) return {}; - // Given that denominator (value variable) is positive, truncated division - // (i.e., TVM's division semantics) is equivalent to Euclidean division if and only if - // numerator is non-negative or numerator is divisible by denominator (i.e., value) - IntSet numerator_int_set = EvalSet(Sum2Expr(a, type), var_range_); - bool numerator_is_non_neg = numerator_int_set.can_prove_non_negative(); - // Try to separate terms of a into ones that can be proven to be - // divisible by coeff and ones that are not - // We will build q and r from divisible and non_divisible respectively - auto divisible = make_node(); - auto non_divisible = make_node(); - if (a->base % value == 0) { - divisible->base = a->base; - } else { - non_divisible->base = a->base; - } - for (const auto& e : a->elem) { - if (e.scale % value == 0) { - divisible->elem.push_back(e); - } else { - non_divisible->elem.push_back(e); - } - } - bool non_divisible_is_simplified = false; - int64_t div_result; - Expr non_divisible_res = Sum2Expr(ComExpr(non_divisible), type); - // if non_divisible part consists of only an integer and numerator is non-negative, - // we can simply divide it by coeff - if (is_const(non_divisible_res)) { - int64_t non_divisible_const = GetConstIntValue(non_divisible_res); - if (numerator_is_non_neg || non_divisible_const == 0) { - non_divisible_is_simplified = true; - // We need to do an Euclidean division here because (a*b + c)/b == a + c/b - // holds true only if division is Euclidean - div_result = HalideIR::Internal::div_imp(non_divisible_const , value); - } - } else { - // If we can prove that non_divisible part lies within [0, coeff), then - // non_divisible itself will be our r - IntSet non_divisible_set = EvalSet(non_divisible_res, var_range_); - if (non_divisible_set.min().type() == type && - non_divisible_set.max().type() == type) { - if ( (non_divisible_set.is_single_point() && - can_prove(non_divisible_set.point_value() == 0)) || - (numerator_is_non_neg && - can_prove(non_divisible_set.min() >= make_zero(type)) && - can_prove(non_divisible_set.max() < coeff)) ) { - non_divisible_is_simplified = true; - div_result = 0; - } - } - } - if (non_divisible_is_simplified) { - non_divisible->base -= div_result * value; - divisible->base /= value; - divisible->base += div_result; - for (auto& e : divisible->elem) { - e.scale /= value; - } - return {ComExpr(divisible), ComExpr(non_divisible)}; - } else { - return {}; - } - } - // subroutine to do produce a % v - Expr SumModConst(ComExpr a, Expr v) { - std::vector pair = TryLinearEquation(a, v); - if (pair.size() == 0) { - int64_t value = GetConstIntValue(v); - auto n = make_node(); - // FIXME(derisavi) : The following can be done only for Euclidean division/mod. - // Therefore, it's only valid when truncated division/mod is equivalent to Euclidean one, - // that is, if and only if a and v are - // both negative or both positive or a is divisible by v. - // Extend the code to handle cases where the above condition is not satisfied, i.e., - // a and v are of different signs and a is not divisible by v. - n->base = a->base % value; - for (auto e : a->elem) { - if (e.scale % value == 0) continue; - e.scale = e.scale % value; - n->elem.push_back(e); - } - Expr ret = Sum2Expr(ComExpr(n), v.type()) % v; - if (const Mod* mod = ret.as()) { - return Binary(mod, ret); - } else { - // Sometimes the result is a constant, this may happen when value is -1 - CHECK(is_const(ret)) << "CanonicalSimplify: " - << Sum2Expr(ComExpr(n), v.type()) << " % " << v << " is " << ret - << " which is neither Mod, nor a constant"; - return ret; - } - } - ret_entry_.sum = pair[1]; - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // subroutine to do produce a % v - Expr SumDivConst(ComExpr a, Expr v) { - std::vector pair = TryLinearEquation(a, v); - if (pair.size() == 0) { - Expr ret = Sum2Expr(a, v.type()) / v; - return Binary(ret.as
(), ret); - } - ret_entry_.sum = pair[0]; - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // subroutine to do produce - Expr SumMulConst(ComExpr a, Expr v) { - int64_t value = GetConstIntValue(v); - if (value == 0) { - return make_zero(v.type()); - } - auto vsum = make_node(*a.operator->()); - vsum->base *= value; - for (auto& e : vsum->elem) { - e.scale *= value; - } - ret_entry_.sum = ComExpr(vsum); - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, v.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // add two ComExpr together - ComExpr SumAdd_(const ComExpr& suma, - const ComExpr& sumb, - int bscale) { - auto n = make_node(); - n->base = suma->base + sumb->base * bscale; - // merge of suma and sumb; - size_t i = 0, j = 0; - while (i < suma->elem.size() && j < sumb->elem.size()) { - const auto& a = suma->elem[i]; - const auto& b = sumb->elem[j]; - if (a.value.same_as(b.value) && a.level == b.level) { - ComExprEntry e = a; - e.scale = a.scale + b.scale * bscale; - if (e.scale != 0) { - n->elem.push_back(e); - } - ++i; ++j; - } else if (a < b) { - n->elem.push_back(a); - ++i; - } else { - ComExprEntry e = b; - e.scale *= bscale; - n->elem.push_back(e); - ++j; - } - } - for (; i < suma->elem.size(); ++i) { - n->elem.push_back(suma->elem[i]); - } - for (; j < sumb->elem.size(); ++j) { - ComExprEntry e = sumb->elem[j]; - e.scale *= bscale; - n->elem.push_back(e); - } - return ComExpr(n); - } - // subroutine to do produce - Expr SumAdd(CacheEntry a, CacheEntry b, int bscale) { - ret_entry_.sum = SumAdd_(a.AsSum(), b.AsSum(), bscale); - CHECK_NE(stack_.size(), 0U); - ret_entry_.max_level = stack_.back().max_level; - ret_entry_.has_side_effect = stack_.back().has_side_effect; - auto it = cache_sum_.find(ret_entry_.sum); - if (it != cache_sum_.end()) { - ret_entry_ = it->second; - } else { - ret_entry_.value = Sum2Expr(ret_entry_.sum, a.value.type()); - cache_sum_[ret_entry_.sum] = ret_entry_; - } - return ret_entry_.value; - } - // convert sum to expr - Expr Sum2Expr(const ComExpr& com, Type t) { - Expr vsum; - if (com->base > 0) { - vsum = make_const(t, com->base); - } - for (const ComExprEntry& e : com->elem) { - if (e.scale > 0) { - Expr v = e.value; - if (e.scale != 1) { - v = Mul::make(v, make_const(t, e.scale)); - } - if (vsum.defined()) { - vsum = Add::make(vsum, v); - } else { - vsum = v; - } - } - } - if (com->base < 0) { - if (vsum.defined()) { - vsum = Sub::make(vsum, make_const(t, -com->base)); - } else { - vsum = make_const(t, com->base); - } - } - for (const ComExprEntry& e : com->elem) { - if (e.scale < 0) { - Expr v = e.value; - if (e.scale != -1) { - v = Mul::make(v, make_const(t, -e.scale)); - } - if (vsum.defined()) { - vsum = Sub::make(vsum, v); - } else { - vsum = Sub::make(make_zero(t), v); - } - } - } - if (vsum.defined()) { - return vsum; - } else { - return make_zero(t); - } - } -}; - -using CInternal = Canonical::Internal; - -Canonical::Canonical(Map vrange) - : ptr_(std::make_shared(vrange)) {} - -Expr Canonical::Simplify(Expr expr) { - return ptr_->Mutate(expr); -} - -Stmt Canonical::Simplify(Stmt stmt) { - return ptr_->Mutate(stmt); -} - -void Canonical::SetRange(Var v, Range r, int level) { - ptr_->SetRange(v, r, level); -} -} // namespace arith - -namespace ir { -Stmt CanonicalSimplifyX(Stmt stmt, Map vrange) { - return arith::Canonical(vrange).Simplify(stmt); -} - -Expr CanonicalSimplifyX(Expr expr, Map vrange) { - return arith::Canonical(vrange).Simplify(expr); -} - -template -T Simplify_(T a, Map vrange) { - using namespace HalideIR::Internal; - Scope rscope; - for (auto kv : vrange) { - Range r = kv.second; - rscope.push( - kv.first.get(), - Interval(r->min, - simplify(r->min + r->extent - make_const(r->min.type(), 1)))); - } - return HalideIR::Internal::simplify(a, true, rscope); -} - -/*! - * \brief Simplify just the combiner of the given reduce node. - * - * This function applies Simplify to the components of the top reduction's - * combiner, but not to the source or condition of the reduction. - * It also removes all components which are not used to - * compute the resulting value (the value_index-th value). - * - * If \p expr is not a reduction node, it is left unchanged. - * - * \param expr The expression to be simplifed. - * \return Simplified expression. - */ -Expr SimplifyCombiner(const Expr& expr, const Map& vrange = Map()) { - const Reduce* op = expr.as(); - if (!op) { - return expr; - } - - // First simplify the results - Array simplified_result; - for (const auto& res : op->combiner->result) { - simplified_result.push_back(Simplify(res, vrange)); - } - - // Which components to keep - std::vector used(op->combiner->result.size(), false); - - // This function recursively marks the used components starting from - // the index idx - std::function mark_used; - mark_used = [&used, &simplified_result, op, &mark_used](size_t idx) { - // if the idx-th component was marked as used before, do nothing - if (used[idx]) return; - used[idx] = true; - - // check if the idx-th result expr uses some lhs or rhs variables - // and recursively mark the corresponding components - for (size_t i = 0; i < simplified_result.size(); ++i) - if (!used[i]) { - if (ExprUseVar(simplified_result[idx], op->combiner->lhs[i]) || - ExprUseVar(simplified_result[idx], op->combiner->rhs[i])) - mark_used(i); - } - }; - - // mark all used components starting from the value_index - mark_used(op->value_index); - - // components which have side effects should also be preserved - for (size_t i = 0; i < used.size(); ++i) { - if (HasSideEffect(op->source[i]) || HasSideEffect(op->combiner->identity_element[i]) || - HasSideEffect(op->combiner->result[i])) { - mark_used(i); - } - } - - int new_value_index = op->value_index; - Array new_result; - Array new_identity; - Array new_lhs; - Array new_rhs; - Array new_source; - - // new stuff is old stuff which is used - for (size_t i = 0; i < used.size(); ++i) { - if (used[i]) { - // We simplify the result and identity, but not the source - new_result.push_back(simplified_result[i]); - new_identity.push_back(Simplify(op->combiner->identity_element[i], vrange)); - new_lhs.push_back(op->combiner->lhs[i]); - new_rhs.push_back(op->combiner->rhs[i]); - new_source.push_back(op->source[i]); - } else if (static_cast(i) < op->value_index) { - // value_index should also be adjusted - new_value_index--; - } - } - - CommReducer new_combiner = CommReducerNode::make(new_lhs, new_rhs, new_result, new_identity); - return Reduce::make(new_combiner, new_source, op->axis, op->condition, new_value_index); -} - -/*! - * \brief Remove a single reduction over empty axis. - * - * If \p e is a reduction node and its axis is empty, replace it with its source, - * otherwise return \p e unchanged. - * - * \param e The expression to be transformed. - * \return The transformed expression. - */ -Expr RemoveEmptyReduction(const Expr& e) { - const Reduce* r = e.as(); - if (r && r->axis.empty()) { - // Note that here we assume that the identity element is indeed identity. Without this - // assumption we would have to perform a single iteration of the loop, i.e. use - // `(*r->combiner.get())(r->combiner->identity_element, r->source)[r->value_index]` - // instead of `r->source[r->value_index]`. The former may be more difficult to simplify. - return Select::make(r->condition, - r->source[r->value_index], - r->combiner->identity_element[r->value_index]); - } - return e; -} - -Expr Simplify(Expr a, Map vrange) { - // We should not pass an expression having a non-HalideIR op to - // Halide::Internal::simplify. Reduce op is the only such op at this time - // and it only appears as the top op in an expression. So we strip it - // first and send the sub-expressions to the simplifier. - if (const Reduce* r = a.as()) { - // If axis is empty, we can remove the reduce op completely. - if (r->axis.empty()) - return Simplify_(RemoveEmptyReduction(a), vrange); - - // Simplify the combiner of the reduction - a = SimplifyCombiner(a, vrange); - r = a.as(); - - // If axis is not empty then we add the information about ranges to vrange - for (const IterVar& iv : r->axis) { - if (vrange.count(iv->var)) { - Range existing_range = vrange[iv->var]; - CHECK(Equal(existing_range->min, iv->dom->min) && - Equal(existing_range->extent, iv->dom->extent)) - << "Simplify was given vrange stating that the range of the reduction var " - << iv << " is " << existing_range << ". This is probably a mistake."; - } - vrange.Set(iv->var, iv->dom); - } - - Array new_source; - for (auto& e : r->source) { - new_source.push_back(Simplify_(e, vrange)); - } - Expr new_condition = Simplify_(r->condition, vrange); - if (r->source.same_as(new_source) && - r->condition.same_as(new_condition)) { - return a; - } else { - return Reduce::make( - r->combiner, new_source, r->axis, new_condition, r->value_index); - } - } - return Simplify_(a, vrange); -} - -} // namespace ir -} // namespace tvm diff --git a/src/arithmetic/canonical.h b/src/arithmetic/canonical.h deleted file mode 100644 index a02dbeef7e3a8..0000000000000 --- a/src/arithmetic/canonical.h +++ /dev/null @@ -1,56 +0,0 @@ -/*! - * Copyright (c) 2017 by Contributors - * \file canonical.h - * \brief Internal canonicalized expression simplification engine. - */ -#ifndef TVM_ARITHMETIC_CANONICAL_H_ -#define TVM_ARITHMETIC_CANONICAL_H_ - -#include -#include -#include - -namespace tvm { -namespace arith { - -/*! - * \brief A stateful CanonicalEngine over SSA. - * - * Simplify and CSE with canonicalization expressions. - * Each call's result will get cached, so next call will - * simply return the cached result. - */ -class Canonical { - public: - /*! \brief constructor */ - explicit Canonical(Map var_range); - /*! - * \brief simplify expression e. - * \param expr The expression to be simplified. - */ - Expr Simplify(Expr expr); - /*! - * \brief simplify stmt. - * \param stmt The stmt to be simplified. - */ - Stmt Simplify(Stmt expr); - /*! - * \brief Set range and level variable - * \param v The variable - * \param r The range of the variable, can be undefined. - * \param level The scope level of the variable, - * affect the order of formula in communicative ops. - */ - void SetRange(Var v, Range r, int level); - - class Internal; - private: - // Internal pointer - std::shared_ptr ptr_; -}; - - -} // namespace arith -} // namespace tvm - -#endif // TVM_ARITHMETIC_CANONICAL_H_ diff --git a/src/arithmetic/canonical_simplify.cc b/src/arithmetic/canonical_simplify.cc index 5c44abbb76260..cd3d75d858746 100644 --- a/src/arithmetic/canonical_simplify.cc +++ b/src/arithmetic/canonical_simplify.cc @@ -647,7 +647,7 @@ Mutate_(const Mod* op, const Expr& self) { SumExpr lhs, extra; if (TryLinearEquation(psum, cval, &lhs, &extra)) { Expr temp = Normalize(extra); - if (const auto* pconst = temp.as()) { + if (temp.as()) { return temp % c1.Eval(); } else { // If temp < cval && temp >=0 then can remove the mod. diff --git a/src/arithmetic/rewrite_simplify.h b/src/arithmetic/rewrite_simplify.h index e825dad79978c..e3435fe9b197b 100644 --- a/src/arithmetic/rewrite_simplify.h +++ b/src/arithmetic/rewrite_simplify.h @@ -9,6 +9,7 @@ #include #include #include +#include #include "const_fold.h" #include "pattern_match.h" diff --git a/src/arithmetic/stmt_simplify.cc b/src/arithmetic/stmt_simplify.cc index ef98d243997f2..4184ccb73202f 100644 --- a/src/arithmetic/stmt_simplify.cc +++ b/src/arithmetic/stmt_simplify.cc @@ -126,7 +126,8 @@ Expr CanonicalSimplify(Expr expr, Map vrange) { return analyzer.canonical_simplify(expr); } -Stmt Simplify(Stmt a, Map vrange) { +template +T Simplify_(T a, Map vrange) { using namespace HalideIR::Internal; Scope rscope; for (auto kv : vrange) { @@ -138,5 +139,13 @@ Stmt Simplify(Stmt a, Map vrange) { } return HalideIR::Internal::simplify(a, true, rscope); } + +Expr Simplify(Expr a, Map vrange) { + return Simplify_(a, vrange); +} + +Stmt Simplify(Stmt a, Map vrange) { + return Simplify_(a, vrange); +} } // namespace ir } // namespace tvm