From d87ebc0af5f2ac5af68246d5ffe62113b64a00e1 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 23:42:44 +0800 Subject: [PATCH 1/5] [ARITH] subspace division --- include/tvm/arith/iter_affine_map.h | 26 ++ python/tvm/arith/iter_affine_map.py | 43 +++ src/arith/iter_affine_map.cc | 297 +++++++++++++++ .../unittest/test_arith_iter_affine_map.py | 340 ++++++++++++++++++ 4 files changed, 706 insertions(+) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index f786c013443c..142992abf844 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -283,6 +283,32 @@ Array DetectIterMap(const Array& indices, const Map> SubspaceDivide(const Array& bindings, + const Map& input_iters, + const Array& sub_iters, const PrimExpr& predicate, + bool require_bijective, arith::Analyzer* analyzer); + } // namespace arith } // namespace tvm #endif // TVM_ARITH_ITER_AFFINE_MAP_H_ diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 5aa817bd7a24..58d2ecbf6917 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -128,3 +128,46 @@ def normalize_iter_map_to_expr(expr): the corresponding normal PrimExpr """ return _ffi_api.NormalizeIterMapToExpr(expr) + + +def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False): + """Detect if bindings can be written as + [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n] + where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters) + b = some-quasi-affine-iter-map(sub_iters) + c is constant symbols + e is the extent of b + For example, z*12 + y*3 + x + c = (z*4+y)*3 + x + bindings = [z*12 + y*3 + x + c] + input_iters = [z, y, x] + sub_iter = [x] + Then the result will be [a, b] where + a = [z*4 + y] + b = [x] + + Parameters + ---------- + bindings : List[PrimExpr] + The input bindings + + input_iters : Map[Var, Range] + The domain of 
the input iterators, which form the basis of the whole space + + sub_iters : Array[Var] + The subset of input_iters, which is the basis of the subspace + + predicate : PrimExpr + The predicate constraints on the input iterators + + require_bijective : bool + A boolean flag that indicates whether the bindings should be bijective + + Returns + ------- + results : List[List[PrimExpr]] + The iter map matching result. The inner list is of length 2. + The first expr is the basis of the quotient space. + The second expr is the basis of the subspace. + Empty array if no match can be found. + """ + return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index a49478a43635..6e7f919d8c28 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1086,5 +1086,302 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter return NormalizeIterMapToExpr(expr); }); +/*! + * \brief Divider to divide the bindings into two sets of bindings (outer and inner) + * such that binding_i = Y_i * E(X_i) + X_i, where E(X) is the extent of X. + * We do message passing among IterSplitExpr and IterSumExpr. + * + * Example + * - If we encounter sum = i*10 + j*5 + k, and i, j, k are splits, + * and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing, + * then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk. + * - If we encounter split = (i / 2) % 4, and we know i = Y*E(X) + X through message passing, + * we inspect all the splits of i, which are i / 8, (i / 2) % 4, i % 2. + * Their extents are 2, 4 and 2, so the splits can be divided if E(X) is 2, 8 or 16.
+ */ +class SubspaceDivider { + public: + explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector, + const std::unordered_set& sub_iters) + : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {} + + size_t unresolved_count() const { return unresolved_count_; } + + // Denotes outer*inner_extent + inner, used as message passing carrier + struct DivisionResult { + public: + // IterMapExpr of outer iters + IterMapExpr outer; + // IterMapExpr of inner iters + IterMapExpr inner; + // extent of outer + PrimExpr outer_extent; + // extent of inner + PrimExpr inner_extent; + + DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner, + PrimExpr inner_extent) + : outer(std::move(outer)), + inner(std::move(inner)), + outer_extent(std::move(outer_extent)), + inner_extent(std::move(inner_extent)) {} + + // whether the division result is totally in outer subspace + bool IsOuter() const { return is_one(inner_extent); } + + // whether the division result is totally in inner subspace + bool IsInner() const { return is_one(outer_extent); } + + IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); } + + IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); } + + static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) { + return DivisionResult(IterSumExpr({}, 0), 1, iter, extent); + } + + static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) { + return DivisionResult(iter, extent, IterSumExpr({}, 0), 1); + } + + private: + static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) { + if (const auto* op = expr.as()) { + return GetRef(op); + } else if (const auto* op = expr.as()) { + return IterSplitExpr(IterMark(GetRef(op), extent)); + } else { + LOG(FATAL); + return NullValue(); + } + } + }; + + // Divide an IterSumExpr + DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) 
{ + if (expr->args.empty()) { + // base + return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1); + } else if (expr->args.size() == 1) { + // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base) + if (!is_one(expr->args[0]->scale)) return Fail(); + DivisionResult res = DivideIterSplitExpr(expr->args[0]); + if (!is_zero(expr->base)) res = AddBase(res, expr->base); + return res; + } + // arg1 + arg2 + ... + argn + base + // then we can write it as Y*E(X)+X + // if it starts with contiguous outer splits, followed by contiguous inner splits + PrimExpr extent = 1; + std::vector outer_args, inner_args; + bool inner = true, scale_is_one = false; + // we check in inverse order so we can visit from inner to outer + for (auto it = expr->args.rbegin(); it != expr->args.rend(); ++it) { + const IterSplitExpr& arg = *it; + if (is_one(arg->scale)) scale_is_one = true; + DivisionResult arg_division = DivideIterSplitExpr(arg); + IterSplitExpr new_arg; + if (arg_division.IsInner()) { + if (!inner) return Fail(); + new_arg = arg_division.GetInnerAsSplit(); + inner_args.push_back(new_arg); + inner = true; + } else if (arg_division.IsOuter()) { + new_arg = arg_division.GetOuterAsSplit(); + outer_args.push_back(new_arg); + inner = false; + } else { + return Fail(); + } + extent *= new_arg->extent; + } + if (!scale_is_one) return Fail(); + bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent); + const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0); + const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base); + IterSumExpr outer_source = Downcast(outer_mark->source); + IterSumExpr inner_source = Downcast(inner_mark->source); + if (need_predicate) { + // if we have a predicate on this sum expr, then we cannot divide it into Y*E+X + // it should either be Y*1+0 or 0*E(X)+X + IterMapToExprNormalizer converter(analyzer_); + if (inner_args.empty()) { + // Y*1+0 + outer_preds_ = outer_preds_ && 
(converter.Convert(outer_source) < mark_extent); + return DivisionResult::Outer(outer_source, mark_extent); + } else if (outer_args.empty()) { + // 0*E(X)+X + inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent); + return DivisionResult::Inner(inner_source, mark_extent); + } else { + return Fail(); + } + } + return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent); + } + + PrimExpr GetOuterPreds() const { return outer_preds_; } + PrimExpr GetInnerPreds() const { return inner_preds_; } + + private: + DivisionResult Fail() { + unresolved_count_++; + return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0); + } + + DivisionResult AddBase(DivisionResult division, PrimExpr base) { + DivisionResult res = division; + if (const auto* op = division.inner.as()) { + res.inner = IterSumExpr({GetRef(op)}, base); + } else if (const auto* op = division.inner.as()) { + const auto& expr = GetRef(op); + res.inner = IterSumExpr(expr->args, expr->base + base); + } + return res; + } + + // args are sorted from inner to outer + static IterMark MarkFromArgsAndBase(const std::vector& args, PrimExpr base) { + std::vector res; + PrimExpr extent = 1; + for (const IterSplitExpr& it : args) { + IterSplitExpr arg = it; + arg.CopyOnWrite()->scale = extent; + extent *= arg->extent; + res.push_back(arg); + } + return IterMark(IterSumExpr(Array(res.rbegin(), res.rend()), base), extent); + } + + DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) { + auto it = split_map_.find(expr); + if (it != split_map_.end()) { + // We will calculate all the splits of an IterMark's division form when we first + // encounter one of them. If we encounter another later, we directly return the record. 
+ return it->second; + } + const Array& splits = collector_.mark2splits_.at(expr->source); + if (const auto* iter_ptr = expr->source->source.as()) { + // source is input_iter, + bool inner = sub_iters_.count(GetRef(iter_ptr)); + for (const IterSplitExpr& split : splits) { + if (inner) { + // 0*E(split)+split + split_map_.emplace(split, DivisionResult::Inner(split, split->extent)); + } else { + // split*1 + 0 + split_map_.emplace(split, DivisionResult::Outer(split, split->extent)); + } + } + } else if (const auto* iter_ptr = expr->source->source.as()) { + // source = Y*E+X + // splits = [s1, s2, ..., sn] + // we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y) + // extent(si+1)...extent(sn)=extent(X) + // For example, if source = Y*3+X \in [0, 12), Y \in [0, 4), X \in [0, 3) + // Case 1. splits = [s1, s2, s3] = [source / 6, (source / 3) % 2, source % 3], + // where extent(s1) = 2, extent(s2) = 2, extent(s3) = 3. + // Since extent(s1)extent(s2) = extent(Y), extent(s3) = extent(X), we have + // s1 = (Y / 2)*1 + 0, s2 = (Y % 2)*1 + 0, s3 = 0*3 + X + // Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2], + // where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2. + // It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X. 
+ DivisionResult mark_division = + DivideIterSumExpr(GetRef(iter_ptr), expr->source->extent); + if (splits.size() == 1) { + return mark_division; + } + IterMark outer_mark(Downcast(mark_division.outer), mark_division.outer_extent); + IterMark inner_mark(Downcast(mark_division.inner), mark_division.inner_extent); + bool encountered_boundary = mark_division.IsOuter(); + std::vector used(splits.size(), false); + std::vector inner_iters, outer_iters; + PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1); + // find the boundary of outer and inner, like case 1 above + for (size_t i = 0; i < splits.size(); ++i) { + size_t j = 0; + for (; j < splits.size(); ++j) { + if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor)) + break; + } + if (j == splits.size()) return Fail(); + used[j] = true; + if (!encountered_boundary) { + inner_iters.push_back(splits[j]); + } else { + outer_iters.push_back(splits[j]); + } + expected_lower_factor *= splits[j]->extent; + if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent)) + encountered_boundary = true; + } + if (!encountered_boundary) return Fail(); + for (const IterSplitExpr& inner_iter : inner_iters) { + IterSplitExpr new_iter = inner_iter; + new_iter.CopyOnWrite()->source = inner_mark; + split_map_.emplace(inner_iter, DivisionResult::Inner(new_iter, inner_iter->extent)); + } + for (const IterSplitExpr& outer_iter : outer_iters) { + IterSplitExpr new_iter = outer_iter; + new_iter.CopyOnWrite()->source = outer_mark; + new_iter.CopyOnWrite()->lower_factor = + floordiv(outer_iter->lower_factor, outer_iters[0]->lower_factor); + split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent)); + } + } else { + return Fail(); + } + return split_map_.at(expr); + } + + size_t unresolved_count_{0}; + Analyzer* analyzer_; + const IterMarkSplitCollector collector_; + // the set of subspace iters + const std::unordered_set& sub_iters_; + // map 
from SplitExpr to its corresponding DivisionResult(Y*E(X)+X) + std::unordered_map split_map_; + // predicate of outer space and inner space; + PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)}; +}; + +Array> SubspaceDivide(const Array& bindings, + const Map& input_iters, + const Array& sub_iters, const PrimExpr& predicate, + bool require_bijective, arith::Analyzer* analyzer) { + const Array& maps = + DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer); + if (maps.empty()) return {}; + + std::unordered_set inner_iter_set; + for (const Var& inner_iter : sub_iters) { + inner_iter_set.insert(inner_iter); + } + + IterMarkSplitCollector collector; + collector.Collect(maps); + SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set); + + std::vector> results; + for (const IterSumExpr& expr : maps) { + SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0); + if (subspace_divider.unresolved_count()) return {}; + results.push_back( + {IterMark(res.outer, res.outer_extent), IterMark(res.inner, res.inner_extent)}); + } + + results.push_back({IterMark(IterSumExpr({}, 0), subspace_divider.GetOuterPreds()), + IterMark(IterSumExpr({}, 0), subspace_divider.GetInnerPreds())}); + return results; +} + +TVM_REGISTER_GLOBAL("arith.SubspaceDivide") + .set_body_typed([](const Array& bindings, const Map& root_iters, + const Array& sub_iters, const PrimExpr& predicate, + bool require_bijective) { + arith::Analyzer ana; + return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana); + }); + } // namespace arith } // namespace tvm diff --git a/tests/python/unittest/test_arith_iter_affine_map.py b/tests/python/unittest/test_arith_iter_affine_map.py index 5ce68aaaf51b..7bfdfc676b67 100644 --- a/tests/python/unittest/test_arith_iter_affine_map.py +++ b/tests/python/unittest/test_arith_iter_affine_map.py @@ -17,6 +17,7 @@ import tvm import tvm.testing from tvm import te +from tvm.tir import 
floormod, floordiv def ifuse(inputs, pred_extent=None): @@ -285,6 +286,343 @@ def test_predicate(): assert len(res) == 0 +def convert_division(divisions): + if divisions is None or len(divisions) == 0: + return [] + res = [] + for division in divisions[:-1]: + res.append( + [ + tvm.arith.normalize_iter_map_to_expr(division[0].source), + tvm.arith.normalize_iter_map_to_expr(division[1].source), + ] + ) + res.append([divisions[-1][0].extent, divisions[-1][1].extent]) + return res + + +def create_iter(name, extent): + return tvm.tir.Var(name, "int32"), extent + + +def test_subspace_division(): + x = tvm.tir.Var("x", "int32") + y = tvm.tir.Var("y", "int32") + z = tvm.tir.Var("z", "int32") + c = tvm.tir.SizeVar("c", "int32") + + # simple 1.1 + res = tvm.arith.subspace_divide( + [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x] + ) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) + tvm.ir.assert_structural_equal(res[0][1], x + c) + + # simple 1.2 + res = tvm.arith.subspace_divide( + [z * 12 + y * 3 + x + c], var_dom([(x, 3), (y, 4), (z, 5)]), [x], z * 4 + y < 18 + ) + res = convert_division(res) + assert len(res) == 2 + tvm.ir.assert_structural_equal(res[0][0], z * 4 + y) + tvm.ir.assert_structural_equal(res[0][1], x + c) + tvm.ir.assert_structural_equal(res[1][0], z * 4 + y < 18) + tvm.ir.assert_structural_equal(res[1][1], True) + + # compound 1 + i0 = create_iter("i0", 4) + j0 = create_iter("j0", 8) + i3 = create_iter("i3", 2) + + i1, i2 = isplit(j0, 4) + k0 = ifuse([i0, i1]) + k1 = ifuse([i2, i3]) + + # compound 1.1 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]]) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][1], i3[0]) + + res1 = 
tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + assert len(res2) == 2 + + # compound 1.2 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]]) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], i0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + assert len(res2) == 2 + + # compound 1.3 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i0[0], i3[0]]) + res = convert_division(res) + assert len(res) == 0 + + # compound 1.4 + res = tvm.arith.subspace_divide([k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], k0[0] < 7) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], (i0[0] * 2) + floordiv(j0[0], 4)) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], floormod(j0[0], 4)) + tvm.ir.assert_structural_equal(res[1][1], i3[0]) + tvm.ir.assert_structural_equal(res[2][0], (i0[0] * 2) + floordiv(j0[0], 4) < 7) + tvm.ir.assert_structural_equal(res[2][1], True) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([i3])) + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0, j0])) + assert len(res2) == 2 + + # compound 1.5 + res = tvm.arith.subspace_divide( + [k0[0], k1[0]], var_dom([i0, j0, i3]), [j0[0], i3[0]], k1[0] < 7 + ) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], i0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(j0[0], 4)) + 
tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], (floormod(j0[0], 4) * 2) + i3[0]) + tvm.ir.assert_structural_equal(res[2][0], True) + tvm.ir.assert_structural_equal(res[2][1], (floormod(j0[0], 4) * 2) + i3[0] < 7) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([j0, i3])) + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([i0])) + assert len(res2) == 2 + + # compound 1.6 + res = tvm.arith.subspace_divide( + [k0[0], k1[0]], var_dom([i0, j0, i3]), [i3[0]], tvm.tir.all(k0[0] < 7, k1[0] < 7) + ) + res = convert_division(res) + assert len(res) == 0 + + # compound 2 + j0 = create_iter("j0", 4) + l0 = create_iter("l0", 2) + l1 = create_iter("l1", 6) + j3 = create_iter("j3", 3) + + k0 = ifuse([l0, l1]) + i1, j2 = isplit(k0, 3) + j1, i1 = isplit(i1, 2) + i0 = ifuse([j0, j1]) + i2 = ifuse([j2, j3]) + + # compound 2.1 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l1[0], j3[0]] + ) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) + tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + assert len(res2) == 3 + + # compound 2.2 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], l1[0], j3[0]] + ) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], j0[0]) + tvm.ir.assert_structural_equal(res[0][1], floordiv(l0[0] * 6 + l1[0], 6)) + 
tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], floormod(floordiv(l0[0] * 6 + l1[0], 3), 2)) + tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l0[0] * 6 + l1[0], 3) * 3) + j3[0]) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l0, l1, j3])) + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0])) + assert len(res2) == 3 + + # compound 2.3 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [l0[0], j3[0]] + ) + res = convert_division(res) + assert len(res) == 0 + + # compound 2.4 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], + var_dom([j0, l0, l1, j3]), + [l1[0], j3[0]], + tvm.tir.all(i0[0] < 7, i2[0] < 8), + ) + res = convert_division(res) + assert len(res) == 4 + tvm.ir.assert_structural_equal(res[0][0], (j0[0] * 2) + l0[0]) + tvm.ir.assert_structural_equal(res[0][1], 0) + tvm.ir.assert_structural_equal(res[1][0], 0) + tvm.ir.assert_structural_equal(res[1][1], floordiv(l1[0], 3)) + tvm.ir.assert_structural_equal(res[2][0], 0) + tvm.ir.assert_structural_equal(res[2][1], (floormod(l1[0], 3) * 3) + j3[0]) + tvm.ir.assert_structural_equal(res[3][0], (j0[0] * 2) + l0[0] < 7) + tvm.ir.assert_structural_equal(res[3][1], (floormod(l1[0], 3) * 3) + j3[0] < 8) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1], res[2][1]], var_dom([l1, j3])) + assert len(res1) == 3 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0], res[2][0]], var_dom([j0, l0])) + assert len(res2) == 3 + + # compound 2.5 + res = tvm.arith.subspace_divide( + [i0[0], i1[0], i2[0]], var_dom([j0, l0, l1, j3]), [j3[0]], i2[0] < 8 + ) + res = convert_division(res) + assert len(res) == 0 + + +def test_complex(): + n0 = create_iter("n0", 2) + n1 = create_iter("n1", 4) + + m0 = ifuse([n0, n1], 6) + m1 = create_iter("m1", 3) + + l0 = create_iter("l0", 4) + l1 = create_iter("l1", 
8) + l2 = ifuse([m0, m1], 16) + l3 = create_iter("l3", 32) + + k0, k4 = isplit(l0, 2) + k1, k5 = isplit(l1, 2) + k2, k6 = isplit(l2, 4) + k3, k7 = isplit(l3, 4) + + j0 = ifuse([k0, k1], 7) + j1 = ifuse([k2, k3]) + j2 = ifuse([k4, k5]) + j3 = ifuse([k6, k7], 15) + + i0 = ifuse([j0, j1], 200) + i1 = ifuse([j2, j3], 50) + + res = tvm.arith.detect_iter_map( + [i0[0], i1[0]], + var_dom([l0, l1, n0, n1, m1, l3]), + tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), + ) + assert len(res) == 2 + + n0_mark = tvm.arith.IterMark(n0[0], n0[1]) + n1_mark = tvm.arith.IterMark(n1[0], n1[1]) + l0_mark = tvm.arith.IterMark(l0[0], l0[1]) + l1_mark = tvm.arith.IterMark(l1[0], l1[1]) + m1_mark = tvm.arith.IterMark(m1[0], m1[1]) + l3_mark = tvm.arith.IterMark(l3[0], l3[1]) + + m0_expr = tvm.arith.IterSumExpr( + [ + tvm.arith.IterSplitExpr(n0_mark, 1, n0[1], 4), + tvm.arith.IterSplitExpr(n1_mark, 1, n1[1], 1), + ], + 0, + ) + m0_mark = tvm.arith.IterMark(m0_expr, 6) + l2_expr = tvm.arith.IterSumExpr( + [tvm.arith.IterSplitExpr(m0_mark, 1, 6, 3), tvm.arith.IterSplitExpr(m1_mark, 1, m1[1], 1)], + 0, + ) + l2_mark = tvm.arith.IterMark(l2_expr, 16) + k0_expr = tvm.arith.IterSplitExpr(l0_mark, 2, 2, 4) + k1_expr = tvm.arith.IterSplitExpr(l1_mark, 2, 4, 1) + k2_expr = tvm.arith.IterSplitExpr(l2_mark, 4, 4, 8) + k3_expr = tvm.arith.IterSplitExpr(l3_mark, 4, 8, 1) + k4_expr = tvm.arith.IterSplitExpr(l0_mark, 1, 2, 30) + k5_expr = tvm.arith.IterSplitExpr(l1_mark, 1, 2, 15) + k6_expr = tvm.arith.IterSplitExpr(l2_mark, 1, 4, 4) + k7_expr = tvm.arith.IterSplitExpr(l3_mark, 1, 4, 1) + + j0_expr = tvm.arith.IterSumExpr([k0_expr, k1_expr], 0) + j0_mark = tvm.arith.IterMark(j0_expr, 7) + i0_expr = tvm.arith.IterSumExpr( + [tvm.arith.IterSplitExpr(j0_mark, 1, 7, 32), k2_expr, k3_expr], 0 + ) + + j3_expr = tvm.arith.IterSumExpr([k6_expr, k7_expr], 0) + j3_mark = tvm.arith.IterMark(j3_expr, 15) + i1_expr = tvm.arith.IterSumExpr( + [k4_expr, k5_expr, 
tvm.arith.IterSplitExpr(j3_mark, 1, 15, 1)], 0 + ) + + i0_mark = tvm.arith.IterMark(i0_expr, i0[1]) + i1_mark = tvm.arith.IterMark(i1_expr, i1[1]) + + i0_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i0_mark, 1, i0[1], 1)], 0) + i1_final = tvm.arith.IterSumExpr([tvm.arith.IterSplitExpr(i1_mark, 1, i1[1], 1)], 0) + + tvm.ir.assert_structural_equal(i0_final, res[0]) + tvm.ir.assert_structural_equal(i1_final, res[1]) + + # wrong constraint + res = tvm.arith.detect_iter_map( + [i0[0], i1[0]], + var_dom([l0, l1, n0, n1, m1, l3]), + tvm.tir.all(i0[0] < 200, i1[0] < 50, m0[0] < 9, l2[0] < 16, j0[0] < 7, j3[0] < 14), + ) + assert len(res) == 0 + + # subspace_division + res = tvm.arith.subspace_divide( + [i0[0], i1[0]], + var_dom([l0, l1, n0, n1, m1, l3]), + [n0[0], n1[0], m1[0], l3[0]], + tvm.tir.all(m0[0] < 6, l2[0] < 16, j0[0] < 7, j3[0] < 15), + ) + res = convert_division(res) + assert len(res) == 3 + tvm.ir.assert_structural_equal(res[0][0], floordiv(l0[0], 2) * 4 + floordiv(l1[0], 2)) + tvm.ir.assert_structural_equal( + res[0][1], (floordiv((n0[0] * 4 + n1[0]) * 3 + m1[0], 4) * 8) + floordiv(l3[0], 4) + ) + tvm.ir.assert_structural_equal(res[1][0], ((floormod(l0[0], 2) * 2) + floormod(l1[0], 2))) + tvm.ir.assert_structural_equal( + res[1][1], ((floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4) + floormod(l3[0], 4)) + ) + tvm.ir.assert_structural_equal(res[2][0], (floordiv(l0[0], 2) * 4) + floordiv(l1[0], 2) < 7) + tvm.ir.assert_structural_equal( + res[2][1], + tvm.tir.all( + n0[0] * 4 + n1[0] < 6, + (n0[0] * 4 + n1[0]) * 3 + m1[0] < 16, + floormod(((n0[0] * 4 + n1[0]) * 3 + m1[0]), 4) * 4 + floormod(l3[0], 4) < 15, + ), + ) + + res1 = tvm.arith.detect_iter_map([res[0][1], res[1][1]], var_dom([n0, n1, m1, l3]), res[2][1]) + assert len(res1) == 2 + res2 = tvm.arith.detect_iter_map([res[0][0], res[1][0]], var_dom([l0, l1])) + assert len(res2) == 2 + + def test_normalize_iter_map_to_expr(): fld = tvm.tir.floordiv flm = tvm.tir.floormod @@ -312,3 +650,5 @@ def 
test_normalize_iter_map_to_expr(): test_compound() test_predicate() test_normalize_iter_map_to_expr() + test_subspace_division() + test_complex() From dbaa0178c71629a808799cfb6a2de8f0c37dd170 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Sat, 27 Mar 2021 23:45:51 +0800 Subject: [PATCH 2/5] [ARITH] subspace division --- python/tvm/arith/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/tvm/arith/__init__.py b/python/tvm/arith/__init__.py index 05843ede9284..a4cdb9839b22 100644 --- a/python/tvm/arith/__init__.py +++ b/python/tvm/arith/__init__.py @@ -22,4 +22,4 @@ from .pattern import detect_linear_equation, detect_clip_bound from .int_solver import solve_linear_equations, solve_linear_inequalities from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr -from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr +from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide From 1af9a42cc5861e7de4ac0775260af72fd2badcb8 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 30 Mar 2021 09:21:53 +0800 Subject: [PATCH 3/5] [ARITH] process comments --- src/arith/iter_affine_map.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 6e7f919d8c28..d71a65405565 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1151,7 +1151,7 @@ class SubspaceDivider { } else if (const auto* op = expr.as()) { return IterSplitExpr(IterMark(GetRef(op), extent)); } else { - LOG(FATAL); + LOG(FATAL) << "Unknown IterMapExpr type"; return NullValue(); } } From b434d8fa8d60f36e17fecaa9f102c7cb4e6d3ae5 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Tue, 30 Mar 2021 18:27:50 +0800 Subject: [PATCH 4/5] [ARITH] process comment --- src/arith/iter_affine_map.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/arith/iter_affine_map.cc 
b/src/arith/iter_affine_map.cc index d71a65405565..499d7dfa8632 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1262,7 +1262,7 @@ class SubspaceDivider { } const Array& splits = collector_.mark2splits_.at(expr->source); if (const auto* iter_ptr = expr->source->source.as()) { - // source is input_iter, + // source is input_iter bool inner = sub_iters_.count(GetRef(iter_ptr)); for (const IterSplitExpr& split : splits) { if (inner) { From eb2527f7074212a93932aa12cb8b609f37e8b506 Mon Sep 17 00:00:00 2001 From: spectrometerHBH Date: Thu, 1 Apr 2021 10:28:14 +0800 Subject: [PATCH 5/5] [ARITH] fix comment --- include/tvm/arith/iter_affine_map.h | 8 ++++++-- python/tvm/arith/iter_affine_map.py | 8 +++++--- src/arith/iter_affine_map.cc | 2 ++ 3 files changed, 13 insertions(+), 5 deletions(-) diff --git a/include/tvm/arith/iter_affine_map.h b/include/tvm/arith/iter_affine_map.h index 142992abf844..641d0e0f5321 100644 --- a/include/tvm/arith/iter_affine_map.h +++ b/include/tvm/arith/iter_affine_map.h @@ -301,8 +301,12 @@ Array DetectIterMap(const Array& indices, const Map> SubspaceDivide(const Array& bindings, const Map& input_iters, diff --git a/python/tvm/arith/iter_affine_map.py b/python/tvm/arith/iter_affine_map.py index 58d2ecbf6917..bfd5dfadc800 100644 --- a/python/tvm/arith/iter_affine_map.py +++ b/python/tvm/arith/iter_affine_map.py @@ -165,9 +165,11 @@ def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bi Returns ------- results : List[List[PrimExpr]] - The iter map matching result. The inner list is of length 2. - The first expr is the basis of the quotient space. - The second expr is the basis of the subspace. + The result list has length len(bindings) + 1 + [0, len(bindings)): The iter map matching result. The inner list is of length 2. + The first expr is the basis of the quotient space. + The second expr is the basis of the subspace. 
+ len(bindings): the predicates of the outer space and the inner space Empty array if no match can be found. """ return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective) diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc index 499d7dfa8632..edcb6f8a2c92 100644 --- a/src/arith/iter_affine_map.cc +++ b/src/arith/iter_affine_map.cc @@ -1335,7 +1335,9 @@ class SubspaceDivider { } size_t unresolved_count_{0}; + // arithmetic analyzer used to call CanProve Analyzer* analyzer_; + // collector that collects the outgoing split references of each IterMark const IterMarkSplitCollector collector_; // the set of subspace iters const std::unordered_set& sub_iters_;