diff --git a/ci/jenkins/generated/arm_jenkinsfile.groovy b/ci/jenkins/generated/arm_jenkinsfile.groovy index 4c830dce2c30..14ad1ad78022 100644 --- a/ci/jenkins/generated/arm_jenkinsfile.groovy +++ b/ci/jenkins/generated/arm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.672484 +// Generated at 2023-04-25T11:40:51.453275 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/cortexm_jenkinsfile.groovy b/ci/jenkins/generated/cortexm_jenkinsfile.groovy index d8a4d4671e86..3f0347c37b74 100644 --- a/ci/jenkins/generated/cortexm_jenkinsfile.groovy +++ b/ci/jenkins/generated/cortexm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.614676 +// Generated at 2023-04-25T11:40:51.505590 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/cpu_jenkinsfile.groovy b/ci/jenkins/generated/cpu_jenkinsfile.groovy index cdd2564e0591..caaeafcb7863 100644 --- a/ci/jenkins/generated/cpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/cpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.563887 +// Generated at 2023-04-25T11:40:51.627063 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/docker_jenkinsfile.groovy b/ci/jenkins/generated/docker_jenkinsfile.groovy index 32dec7863bcf..d0b37bdc5227 100644 --- a/ci/jenkins/generated/docker_jenkinsfile.groovy +++ b/ci/jenkins/generated/docker_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.699838 +// Generated at 2023-04-26T17:36:59.403201 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( @@ -845,7 +846,7 @@ def deploy() { -if (rebuild_docker_images) { +if (false && rebuild_docker_images) { stage('Docker Image Build') { parallel( 'ci_arm': { diff --git a/ci/jenkins/generated/gpu_jenkinsfile.groovy b/ci/jenkins/generated/gpu_jenkinsfile.groovy index 390c8ddc3dc2..428caedcbfd1 100644 --- a/ci/jenkins/generated/gpu_jenkinsfile.groovy +++ b/ci/jenkins/generated/gpu_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.640362 +// Generated at 2023-04-25T11:40:51.523364 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/hexagon_jenkinsfile.groovy b/ci/jenkins/generated/hexagon_jenkinsfile.groovy index 58fe4d14c969..e774518be8cb 100644 --- a/ci/jenkins/generated/hexagon_jenkinsfile.groovy +++ b/ci/jenkins/generated/hexagon_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.512545 +// Generated at 2023-04-25T11:40:51.434735 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/i386_jenkinsfile.groovy b/ci/jenkins/generated/i386_jenkinsfile.groovy index b5bf5cb1fe40..3f3fed8244e4 100644 --- a/ci/jenkins/generated/i386_jenkinsfile.groovy +++ b/ci/jenkins/generated/i386_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.590456 +// Generated at 2023-04-25T11:40:51.488582 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/lint_jenkinsfile.groovy b/ci/jenkins/generated/lint_jenkinsfile.groovy index ed5aa8d67954..52d6036d2cc6 100644 --- a/ci/jenkins/generated/lint_jenkinsfile.groovy +++ b/ci/jenkins/generated/lint_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.725728 +// Generated at 2023-04-25T11:40:51.545459 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy index 4c748e3f20d7..f6d8f52c6459 100644 --- a/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_cross_isa_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-07T23:01:16.071376 +// Generated at 2023-04-25T11:40:51.596303 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/minimal_jenkinsfile.groovy b/ci/jenkins/generated/minimal_jenkinsfile.groovy index 72864ec4ca0f..6b25b3706354 100644 --- a/ci/jenkins/generated/minimal_jenkinsfile.groovy +++ b/ci/jenkins/generated/minimal_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.540335 +// Generated at 2023-04-25T11:40:51.561737 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/riscv_jenkinsfile.groovy b/ci/jenkins/generated/riscv_jenkinsfile.groovy index 2dfeb3561281..47f2d6c92f09 100644 --- a/ci/jenkins/generated/riscv_jenkinsfile.groovy +++ b/ci/jenkins/generated/riscv_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.792163 +// Generated at 2023-04-25T11:40:51.472038 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/generated/wasm_jenkinsfile.groovy b/ci/jenkins/generated/wasm_jenkinsfile.groovy index 27e8f6570ed0..bd84e4fef240 100644 --- a/ci/jenkins/generated/wasm_jenkinsfile.groovy +++ b/ci/jenkins/generated/wasm_jenkinsfile.groovy @@ -60,7 +60,7 @@ // 'python3 jenkins/generate.py' // Note: This timestamp is here to ensure that updates to the Jenkinsfile are // always rebased on main before merging: -// Generated at 2023-02-02T20:12:16.748767 +// Generated at 2023-04-25T11:40:51.612532 import org.jenkinsci.plugins.pipeline.modeldefinition.Utils // These are set at runtime from data in ci/jenkins/docker-images.yml, update @@ -150,7 +150,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 index beb9b478bafb..f395f45dca34 100644 --- a/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 +++ b/ci/jenkins/templates/docker_jenkinsfile.groovy.j2 @@ -179,7 +179,7 @@ def deploy() { -if (rebuild_docker_images) { +if (false && rebuild_docker_images) { stage('Docker Image Build') { parallel( {% for image in images %} diff --git a/ci/jenkins/templates/utils/Prepare.groovy.j2 b/ci/jenkins/templates/utils/Prepare.groovy.j2 index d5aebdc07008..66db58c3667d 100644 --- a/ci/jenkins/templates/utils/Prepare.groovy.j2 +++ b/ci/jenkins/templates/utils/Prepare.groovy.j2 @@ -20,7 +20,8 @@ def init_git() { update_upstream_revision("HEAD") } else { // This is PR branch so merge with latest main. - merge_with_main() + // merge_with_main() + update_upstream_revision("HEAD") } sh( diff --git a/src/arith/bound_deducer.cc b/src/arith/bound_deducer.cc index 7cfe8681bea3..d52ae7e6fde3 100644 --- a/src/arith/bound_deducer.cc +++ b/src/arith/bound_deducer.cc @@ -344,6 +344,10 @@ void BoundDeducer::Deduce() { expr_map_ = EvalSetForEachSubExpr(expr_, hint_map_); this->VisitExpr(expr_); + + if (success_) { + result_ = analyzer_.Simplify(result_); + } } void BoundDeducer::Relax() { diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc index 11fb041511f9..14c91934d3b2 100644 --- a/src/arith/canonical_simplify.cc +++ b/src/arith/canonical_simplify.cc @@ -633,6 +633,27 @@ class CanonicalSimplifier::Impl : public RewriteSimplifier::Impl { */ void SeparateDivisibleParts(const SumExprNode* psum, int64_t coeff, SumExpr* out_divisible, SumExpr* out_non_divisible); + /*! + * \brief Pattern match and check whether lhs is fully divisible by + * rhs using prod pattern simiplification expressions. + * + * The following two relations holds for floordiv/mod and truncdiv/mod + * Note that the relation do not hold for euclidean divide and mod. + * + * This is because the floordiv/mod and truncdiv/mod result can be + * uniquely determined by the value of the realdiv result and the + * relation holds for realdiv. + * + * - div((a0 * a1 * c), (b0 * b1 * c)) = div((a0 * a1), (b0 * b1)) + * - mod((a0 * a1 * c), (b0 * b1 * c)) = mod((a0 * a1), (b0 * b1)) * c + * + * \param lhs The left operand to be updated. + * \param rhs The right operand to be updated. + * \param common_scale The common scale between lhs and rhs. + * \returns The simplified result if it is successful. + * \note This simplification mainly target when rhs is symbolic. + */ + bool ProdDivSimplify(PrimExpr* lhs, PrimExpr* rhs, PrimExpr* common_scale); /*! * \brief Normalize expr to normal expr. * \param expr The input expression. @@ -862,6 +883,66 @@ SplitExpr CanonicalSimplifier::Impl::SplitDivConst(SplitExpr lhs, int64_t cval, return lhs; } +bool CanonicalSimplifier::Impl::ProdDivSimplify(PrimExpr* plhs, PrimExpr* prhs, + PrimExpr* common_scale) { + // the constant rhs case is covered by other simplifier so + // we just skip to save the time + if (prhs->as()) return false; + // collect lhs products and try to eliminate by matching them to prod in rhs + Array> lhs_prods; + PrimExpr new_rhs = make_const(prhs->dtype(), 1); + PrimExpr new_common_scale = make_const(prhs->dtype(), 1); + int64_t lhs_cscale = 1, rhs_cscale = 1; + int num_elimination = 0; + + // collect lhs product and constant scale. + auto fcollect_lhs = [&](PrimExpr value) { + if (auto* intimm = value.as()) { + lhs_cscale *= intimm->value; + } else { + lhs_prods.push_back(value); + } + }; + UnpackReduction(*plhs, fcollect_lhs); + + // collect rhs product and try to eliminate when possible + PEqualChecker deep_equal; + auto fcollect_rhs = [&](PrimExpr value) { + if (auto* intimm = value.as()) { + rhs_cscale *= intimm->value; + } else { + // try eliminate from lhs + for (size_t i = 0; i < lhs_prods.size(); ++i) { + if (lhs_prods[i].defined() && deep_equal(value, lhs_prods[i].value())) { + lhs_prods.Set(i, NullOpt); + ++num_elimination; + new_common_scale = new_common_scale * value; + return; + } + } + // if elimination is not possible then construct the expression. + new_rhs = new_rhs * value; + } + }; + UnpackReduction(*prhs, fcollect_rhs); + // find gcd of const scales. + int64_t cscale_gcd = ZeroAwareGCD(lhs_cscale, rhs_cscale); + lhs_cscale /= cscale_gcd; + rhs_cscale /= cscale_gcd; + // if no elimination is possible + if (num_elimination == 0 && cscale_gcd == 1) return false; + + // construct prod via canonical form + PrimExpr new_lhs = make_const(plhs->dtype(), 1); + for (Optional val : lhs_prods) { + if (val.defined()) new_lhs = new_lhs * val.value(); + } + *plhs = new_lhs * make_const(plhs->dtype(), lhs_cscale); + *prhs = new_rhs * make_const(prhs->dtype(), rhs_cscale); + *common_scale = new_common_scale * make_const(prhs->dtype(), cscale_gcd); + return true; +} + PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { if (!IsIndexType(op->dtype)) { return Rewriter::VisitExpr_(op); @@ -913,6 +994,12 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const DivNode* op) { // normal path a = Normalize(a); b = Normalize(b); + PrimExpr scale; + // note this is the case where b is not constant + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator ver so it can constant fold if b == 1 + return truncdiv(a, b); + } if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -967,6 +1054,11 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorDivNode* op) { // normal path a = Normalize(a); b = Normalize(b); + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator ver so it can const fold. + return floordiv(a, b); + } if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -1088,6 +1180,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const ModNode* op) { // normal path a = Normalize(a); b = Normalize(b); + + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator version here so it can const fold b == 1 + return truncmod(a, b) * scale; + } + if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { @@ -1146,6 +1245,13 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const FloorModNode* op) { // normal path a = Normalize(a); b = Normalize(b); + + PrimExpr scale; + if (ProdDivSimplify(&a, &b, &scale)) { + // use operator version here so it can const fold b == 1 + return floormod(a, b) * scale; + } + if (op->a.same_as(a) && op->b.same_as(b)) { return GetRef(op); } else { diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h index 55b51d7a315b..0bb172e56053 100644 --- a/src/arith/pattern_match.h +++ b/src/arith/pattern_match.h @@ -915,6 +915,23 @@ matches_one_of(const TPattern&... patterns) { return PMatchesOneOf(patterns...); } +/*! + * \brief Unpack reduction by calling each leaf via fleaf. + * + * \param value The expression value. + * \tparam TNode the reduction node to match. + * \tparam FLeaf The callback function at leaf. + */ +template +inline void UnpackReduction(const PrimExpr& value, FLeaf fleaf) { + if (const TNode* node = value.as()) { + UnpackReduction(node->a, fleaf); + UnpackReduction(node->b, fleaf); + } else { + fleaf(value); + } +} + } // namespace arith } // namespace tvm #endif // TVM_ARITH_PATTERN_MATCH_H_ diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py index 45ecb6275549..a36fd214794b 100644 --- a/tests/python/unittest/test_arith_deduce_bound.py +++ b/tests/python/unittest/test_arith_deduce_bound.py @@ -114,12 +114,10 @@ def test_deduce(): assert str(res9.max_value) == "neg_inf" assert str(res9.min_value) == "pos_inf" - # Unsatisfiable Mul in `EQ` - res10 = tvm.arith.deduce_bound( - a, (b * a == b), {b: b_s}, {} - ) # simplifier is not able to prove that (b % b == 0) - assert str(res10.max_value) == "neg_inf" - assert str(res10.min_value) == "pos_inf" + res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {}) + # simplifier is now able to prove symbolic relation (b * a % b == 0) + tvm.testing.assert_prim_expr_equal(res10.max_value, 1) + tvm.testing.assert_prim_expr_equal(res10.min_value, 1) def test_check():