Skip to content

Commit

Permalink
[ARITH] Subspace division (apache#7760)
Browse files Browse the repository at this point in the history
  • Loading branch information
spectrometerHBH authored and Trevor Morris committed May 6, 2021
1 parent b16784a commit 80ee78a
Show file tree
Hide file tree
Showing 5 changed files with 715 additions and 1 deletion.
30 changes: 30 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,36 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);

/*!
* \brief Detect if bindings can be written as
* [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
*
* where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
* b = some-quasi-affine-iter-map(sub_iters)
* c is constant symbols
* e is the extent of b
*
* For example, z*12 + y*3 + x + c = (z*4+y)*3 + x, if sub_iters={x}
*
* \param bindings The input bindings
* \param input_iters Map from variable to iterator's range.
* \param sub_iters Iterators of subspace.
* \param predicate The predicate constraints on the input iterators
* \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
* \param analyzer Analyzer used to get context information.
*
* \return The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
The first expr is the basis of the quotient space.
The second expr is the basis of the subspace.
len(bindings): the predicate of outer space and inner space
Empty array if no match can be found.
*/
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide
45 changes: 45 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,48 @@ def normalize_iter_map_to_expr(expr):
the corresponding normal PrimExpr
"""
return _ffi_api.NormalizeIterMapToExpr(expr)


def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False):
"""Detect if bindings can be written as
[a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
b = some-quasi-affine-iter-map(sub_iters)
c is constant symbols
e is the extent of b
For example, z*12 + y*3 + x + c = (z*4+y)*3 + x
bindings = [z*12 + y*3 + x + c]
input_iters = [z, y, x]
sub_iter = [x]
Then the result will be [a, b] where
a = [z*4 + y]
b = [x]
Parameters
----------
bindings : List[PrimExpr]
The input bindings
input_iters : Map[Var, Range]
The domain of input iterator, which is the basis of the whole space
sub_iters : Array[Var]
The subset of input_iters, which is the basis of the subspace
predicate : PrimExpr
The predicate constraints on the input iterators
require_bijective : bool
A boolean flag that indicates whether the bindings should be bijective
Returns
-------
results : List[List[PrimExpr]]
The result list has length len(bindings) + 1
[0, len(bindings)): The iter map matching result. The inner list is of length 2.
The first expr is the basis of the quotient space.
The second expr is the basis of the subspace.
len(bindings): the predicate of outer space and inner space
Empty array if no match can be found.
"""
return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective)
299 changes: 299 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1086,5 +1086,304 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

/*!
* \brief Divider to divide the bindings into two sets of bindings(outer and inner)
* such that binding_i = Y_i * E(Xi) + Xi, where E(X) is the extent of X.
* We do message passing among IterSplitExpr and IterSumExpr.
*
* Example
* - If we encounter sum = i*10 + j*5 + k, and i, j, k are splits,
* and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing,
* then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk.
* - If we encounter split = (i / 2) % 4, and we know i = Y*E(X) + X through message passing.
* We inspect all the splits of i, which are i / 8, (i / 2) % 4, i % 2.
* Their extents are 2, 4, 2, if E(X) = 2, 8, 16, the splits can be divided.
*/
class SubspaceDivider {
public:
explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
: analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}

size_t unresolved_count() const { return unresolved_count_; }

// Denotes outer*inner_extent + inner, used as message passing carrier
struct DivisionResult {
public:
// IterMapExpr of outer iters
IterMapExpr outer;
// IterMapExpr of inner iters
IterMapExpr inner;
// extent of outer
PrimExpr outer_extent;
// extent of inner
PrimExpr inner_extent;

DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
PrimExpr inner_extent)
: outer(std::move(outer)),
inner(std::move(inner)),
outer_extent(std::move(outer_extent)),
inner_extent(std::move(inner_extent)) {}

// whether the division result is totally in outer subspace
bool IsOuter() const { return is_one(inner_extent); }

// whether the division result is totally in inner subspace
bool IsInner() const { return is_one(outer_extent); }

IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }

IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }

static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(IterSumExpr({}, 0), 1, iter, extent);
}

static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(iter, extent, IterSumExpr({}, 0), 1);
}

private:
static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return GetRef<IterSplitExpr>(op);
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return IterSplitExpr(IterMark(GetRef<IterSumExpr>(op), extent));
} else {
LOG(FATAL) << "Unknown IterMapExpr type";
return NullValue<IterSplitExpr>();
}
}
};

// Divide an IterSumExpr
DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) {
if (expr->args.empty()) {
// base
return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
} else if (expr->args.size() == 1) {
// arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
if (!is_one(expr->args[0]->scale)) return Fail();
DivisionResult res = DivideIterSplitExpr(expr->args[0]);
if (!is_zero(expr->base)) res = AddBase(res, expr->base);
return res;
}
// arg1 + arg2 + ... + argn + base
// then we can write it as Y*E(X)+X
// if it starts with contiguous outer splits, followed by contiguous inner splits
PrimExpr extent = 1;
std::vector<IterSplitExpr> outer_args, inner_args;
bool inner = true, scale_is_one = false;
// we check in inverse order so we can visit from inner to outer
for (auto it = expr->args.rbegin(); it != expr->args.rend(); ++it) {
const IterSplitExpr& arg = *it;
if (is_one(arg->scale)) scale_is_one = true;
DivisionResult arg_division = DivideIterSplitExpr(arg);
IterSplitExpr new_arg;
if (arg_division.IsInner()) {
if (!inner) return Fail();
new_arg = arg_division.GetInnerAsSplit();
inner_args.push_back(new_arg);
inner = true;
} else if (arg_division.IsOuter()) {
new_arg = arg_division.GetOuterAsSplit();
outer_args.push_back(new_arg);
inner = false;
} else {
return Fail();
}
extent *= new_arg->extent;
}
if (!scale_is_one) return Fail();
bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
IterSumExpr outer_source = Downcast<IterSumExpr>(outer_mark->source);
IterSumExpr inner_source = Downcast<IterSumExpr>(inner_mark->source);
if (need_predicate) {
// if we have a predicate on this sum expr, then we cannot divide it into Y*E+X
// it should either be Y*1+0 or 0*E(X)+X
IterMapToExprNormalizer converter(analyzer_);
if (inner_args.empty()) {
// Y*1+0
outer_preds_ = outer_preds_ && (converter.Convert(outer_source) < mark_extent);
return DivisionResult::Outer(outer_source, mark_extent);
} else if (outer_args.empty()) {
// 0*E(X)+X
inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
return DivisionResult::Inner(inner_source, mark_extent);
} else {
return Fail();
}
}
return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
}

PrimExpr GetOuterPreds() const { return outer_preds_; }
PrimExpr GetInnerPreds() const { return inner_preds_; }

private:
DivisionResult Fail() {
unresolved_count_++;
return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
}

DivisionResult AddBase(DivisionResult division, PrimExpr base) {
DivisionResult res = division;
if (const auto* op = division.inner.as<IterSplitExprNode>()) {
res.inner = IterSumExpr({GetRef<IterSplitExpr>(op)}, base);
} else if (const auto* op = division.inner.as<IterSumExprNode>()) {
const auto& expr = GetRef<IterSumExpr>(op);
res.inner = IterSumExpr(expr->args, expr->base + base);
}
return res;
}

// args are sorted from inner to outer
static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args, PrimExpr base) {
std::vector<IterSplitExpr> res;
PrimExpr extent = 1;
for (const IterSplitExpr& it : args) {
IterSplitExpr arg = it;
arg.CopyOnWrite()->scale = extent;
extent *= arg->extent;
res.push_back(arg);
}
return IterMark(IterSumExpr(Array<IterSplitExpr>(res.rbegin(), res.rend()), base), extent);
}

DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) {
auto it = split_map_.find(expr);
if (it != split_map_.end()) {
// We will calculate all the splits of an IterMark's division form when we first
// encounter one of them. If we encounter another later, we directly return the record.
return it->second;
}
const Array<IterSplitExpr>& splits = collector_.mark2splits_.at(expr->source);
if (const auto* iter_ptr = expr->source->source.as<VarNode>()) {
// source is input_iter
bool inner = sub_iters_.count(GetRef<Var>(iter_ptr));
for (const IterSplitExpr& split : splits) {
if (inner) {
// 0*E(split)+split
split_map_.emplace(split, DivisionResult::Inner(split, split->extent));
} else {
// split*1 + 0
split_map_.emplace(split, DivisionResult::Outer(split, split->extent));
}
}
} else if (const auto* iter_ptr = expr->source->source.as<IterSumExprNode>()) {
// source = Y*E+X
// splits = [s1, s2, ..., sn]
// we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y)
// extent(si+1)...extent(sn)=extent(X)
// For example, if source = Y*3+X \in [0, 12), Y \in [0, 4), X \in [0, 3)
// Case 1. splits = [s1, s2, s3] = [source / 6, (source / 3) % 2, source % 3],
// where extent(s1) = 2, extent(s2) = 2, extent(s3) = 3.
// Since extent(s1)extent(s2) = extent(Y), extent(s3) = extent(X), we have
// s1 = (Y / 2)*1 + 0, s2 = (Y % 2)*1 + 0, s3 = 0*3 + X
// Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2],
// where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2.
// It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X.
DivisionResult mark_division =
DivideIterSumExpr(GetRef<IterSumExpr>(iter_ptr), expr->source->extent);
if (splits.size() == 1) {
return mark_division;
}
IterMark outer_mark(Downcast<IterSumExpr>(mark_division.outer), mark_division.outer_extent);
IterMark inner_mark(Downcast<IterSumExpr>(mark_division.inner), mark_division.inner_extent);
bool encountered_boundary = mark_division.IsOuter();
std::vector<bool> used(splits.size(), false);
std::vector<IterSplitExpr> inner_iters, outer_iters;
PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1);
// find the boundary of outer and inner, like case 1 above
for (size_t i = 0; i < splits.size(); ++i) {
size_t j = 0;
for (; j < splits.size(); ++j) {
if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
break;
}
if (j == splits.size()) return Fail();
used[j] = true;
if (!encountered_boundary) {
inner_iters.push_back(splits[j]);
} else {
outer_iters.push_back(splits[j]);
}
expected_lower_factor *= splits[j]->extent;
if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
encountered_boundary = true;
}
if (!encountered_boundary) return Fail();
for (const IterSplitExpr& inner_iter : inner_iters) {
IterSplitExpr new_iter = inner_iter;
new_iter.CopyOnWrite()->source = inner_mark;
split_map_.emplace(inner_iter, DivisionResult::Inner(new_iter, inner_iter->extent));
}
for (const IterSplitExpr& outer_iter : outer_iters) {
IterSplitExpr new_iter = outer_iter;
new_iter.CopyOnWrite()->source = outer_mark;
new_iter.CopyOnWrite()->lower_factor =
floordiv(outer_iter->lower_factor, outer_iters[0]->lower_factor);
split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
}
} else {
return Fail();
}
return split_map_.at(expr);
}

size_t unresolved_count_{0};
// arithmetic analyzer used to call CanProve
Analyzer* analyzer_;
// collector that collects the outgoing split reference of each IterMark
const IterMarkSplitCollector collector_;
// the set of subspace iters
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters_;
// map from SplitExpr to its corresponding DivisionResult(Y*E(X)+X)
std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
// predicate of outer space and inner space;
PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)};
};

Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
const Map<Var, Range>& input_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective, arith::Analyzer* analyzer) {
const Array<IterSumExpr>& maps =
DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer);
if (maps.empty()) return {};

std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> inner_iter_set;
for (const Var& inner_iter : sub_iters) {
inner_iter_set.insert(inner_iter);
}

IterMarkSplitCollector collector;
collector.Collect(maps);
SubspaceDivider subspace_divider(analyzer, collector, inner_iter_set);

std::vector<Array<IterMark>> results;
for (const IterSumExpr& expr : maps) {
SubspaceDivider::DivisionResult res = subspace_divider.DivideIterSumExpr(expr, 0);
if (subspace_divider.unresolved_count()) return {};
results.push_back(
{IterMark(res.outer, res.outer_extent), IterMark(res.inner, res.inner_extent)});
}

results.push_back({IterMark(IterSumExpr({}, 0), subspace_divider.GetOuterPreds()),
IterMark(IterSumExpr({}, 0), subspace_divider.GetInnerPreds())});
return results;
}

TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
.set_body_typed([](const Array<PrimExpr>& bindings, const Map<Var, Range>& root_iters,
const Array<Var>& sub_iters, const PrimExpr& predicate,
bool require_bijective) {
arith::Analyzer ana;
return SubspaceDivide(bindings, root_iters, sub_iters, predicate, require_bijective, &ana);
});

} // namespace arith
} // namespace tvm
Loading

0 comments on commit 80ee78a

Please sign in to comment.