Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ARITH] Subspace division #7760

Merged
merged 5 commits into from
Apr 1, 2021
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions include/tvm/arith/iter_affine_map.h
Original file line number Diff line number Diff line change
Expand Up @@ -283,6 +283,32 @@ Array<IterSumExpr> DetectIterMap(const Array<PrimExpr>& indices, const Map<Var,
const PrimExpr& predicate, bool require_bijective,
arith::Analyzer* analyzer);

/*!
 * \brief Detect if bindings can be written as
 *  [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]
 *
 *  where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
 *        b = some-quasi-affine-iter-map(sub_iters)
 *        c is constant symbols
 *        e is the extent of b
 *
 *  For example, z*12 + y*3 + x + c = (z*4+y)*3 + x + c, if sub_iters={x}
 *  (here a = z*4 + y, e = 3, and x together with the constant c forms the inner part).
 *
 * \param bindings The input bindings
 * \param input_iters Map from variable to iterator's range.
 * \param sub_iters Iterators of subspace.
 * \param predicate The predicate constraints on the input iterators
 * \param require_bijective A boolean flag that indicates whether the mapping should be bijective.
 * \param analyzer Analyzer used to get context information.
 *
 * \return If a match exists: one [a_i, b_i] pair of IterMarks per binding, followed by one
 *         extra pair whose IterMark extents hold the predicates collected for the outer and
 *         the inner space. Otherwise return an empty array.
 */
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                      const Map<Var, Range>& input_iters,
                                      const Array<Var>& sub_iters, const PrimExpr& predicate,
                                      bool require_bijective, arith::Analyzer* analyzer);

} // namespace arith
} // namespace tvm
#endif // TVM_ARITH_ITER_AFFINE_MAP_H_
2 changes: 1 addition & 1 deletion python/tvm/arith/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,4 @@
from .pattern import detect_linear_equation, detect_clip_bound
from .int_solver import solve_linear_equations, solve_linear_inequalities
from .iter_affine_map import IterMapExpr, IterMark, IterSplitExpr, IterSumExpr
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr
from .iter_affine_map import detect_iter_map, normalize_iter_map_to_expr, subspace_divide
43 changes: 43 additions & 0 deletions python/tvm/arith/iter_affine_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,3 +128,46 @@ def normalize_iter_map_to_expr(expr):
the corresponding normal PrimExpr
"""
return _ffi_api.NormalizeIterMapToExpr(expr)


def subspace_divide(bindings, input_iters, sub_iters, predicate=True, require_bijective=False):
    """Detect if bindings can be written as
    [a_0*e_0 + b_0 + c_0, a_1*e_1 + b_1, ..., a_n*e_n + b_n]

    where a = some-quasi-affine-iter-map(input_iters set_minus sub_iters)
          b = some-quasi-affine-iter-map(sub_iters)
          c is constant symbols
          e is the extent of b

    For example, z*12 + y*3 + x + c = (z*4+y)*3 + x + c
        bindings = [z*12 + y*3 + x + c]
        input_iters = [z, y, x]  (with x in [0, 3) and y in [0, 4))
        sub_iters = [x]
    Then the result will be [a, b] where
        a = [z*4 + y]
        b = [x]  (the constant offset c is kept together with this inner part)

    Parameters
    ----------
    bindings : List[PrimExpr]
        The input bindings

    input_iters : Map[Var, Range]
        The domain of input iterator, which is the basis of the whole space

    sub_iters : Array[Var]
        The subset of input_iters, which is the basis of the subspace

    predicate : PrimExpr
        The predicate constraints on the input iterators

    require_bijective : bool
        A boolean flag that indicates whether the bindings should be bijective

    Returns
    -------
    results : List[List[IterMark]]
        The result of the subspace division. Each inner list is of length 2.
        For every binding, the first mark is the basis of the quotient space
        and the second mark is the basis of the subspace.
        One extra pair is appended at the end; the extents of its two marks
        hold the predicates of the outer space and of the inner space.
        Empty array if no match can be found.
    """
    return _ffi_api.SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective)
297 changes: 297 additions & 0 deletions src/arith/iter_affine_map.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1086,5 +1086,302 @@ TVM_REGISTER_GLOBAL("arith.NormalizeIterMapToExpr").set_body_typed([](const Iter
return NormalizeIterMapToExpr(expr);
});

/*!
 * \brief Divider that splits the bindings into two sets of bindings (outer and inner)
 *  such that binding_i = Y_i * E(X_i) + X_i, where E(X) is the extent of X.
 *  We do message passing among IterSplitExpr and IterSumExpr.
 *
 *  Example
 *  - If we encounter sum = i*10 + j*5 + k, where i, j, k are splits,
 *    and we know i = Yi*1 + 0, j = 0*E(Xj) + Xj, k = 0*E(Xk) + Xk through message passing,
 *    then sum = Yi*10 + (Xj*5 + Xk) = Y*E(X) + X, where Y = Yi, X = Xj*5 + Xk.
 *  - If we encounter split = (i / 2) % 4, and we know i = Y*E(X) + X through message passing,
 *    we inspect all the splits of i, which are i / 8, (i / 2) % 4, i % 2.
 *    Their extents are 2, 4, 2; the splits can be divided if E(X) = 2, 8 or 16.
 */
class SubspaceDivider {
public:
// Construct a divider.
// \param analyzer Analyzer used when comparing extents; stored as a raw pointer.
// \param collector Collector holding the splits of every IterMark; copied into this object.
// \param sub_iters The iterators of the subspace; stored by reference, so the set must
//        stay alive for as long as this divider is used.
explicit SubspaceDivider(Analyzer* analyzer, const IterMarkSplitCollector& collector,
                         const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters)
    : analyzer_(analyzer), collector_(collector), sub_iters_(sub_iters) {}

// Number of division attempts that have failed so far (each call to Fail() adds one);
// a non-zero value means at least one expression could not be divided.
size_t unresolved_count() const { return unresolved_count_; }

// Denotes outer*inner_extent + inner, used as message passing carrier
struct DivisionResult {
public:
// IterMapExpr of outer iters
IterMapExpr outer;
// IterMapExpr of inner iters
IterMapExpr inner;
// extent of outer
PrimExpr outer_extent;
// extent of inner
PrimExpr inner_extent;

DivisionResult(IterMapExpr outer, PrimExpr outer_extent, IterMapExpr inner,
PrimExpr inner_extent)
: outer(std::move(outer)),
inner(std::move(inner)),
outer_extent(std::move(outer_extent)),
inner_extent(std::move(inner_extent)) {}

// whether the division result is totally in outer subspace
bool IsOuter() const { return is_one(inner_extent); }

// whether the division result is totally in inner subspace
bool IsInner() const { return is_one(outer_extent); }

IterSplitExpr GetOuterAsSplit() const { return GetAsSplit(outer, outer_extent); }

IterSplitExpr GetInnerAsSplit() const { return GetAsSplit(inner, inner_extent); }

static DivisionResult Inner(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(IterSumExpr({}, 0), 1, iter, extent);
}

static DivisionResult Outer(const IterMapExpr& iter, const PrimExpr& extent) {
return DivisionResult(iter, extent, IterSumExpr({}, 0), 1);
}

private:
static IterSplitExpr GetAsSplit(const IterMapExpr& expr, const PrimExpr& extent) {
if (const auto* op = expr.as<IterSplitExprNode>()) {
return GetRef<IterSplitExpr>(op);
} else if (const auto* op = expr.as<IterSumExprNode>()) {
return IterSplitExpr(IterMark(GetRef<IterSumExpr>(op), extent));
} else {
LOG(FATAL) << "Unknown IterMapExpr type";
return NullValue<IterSplitExpr>();
}
}
};

// Divide an IterSumExpr into the form outer * E(inner) + inner.
//
// \param expr The sum to divide.
// \param mark_extent The extent the sum is compared against. Inside this class it is the
//        extent of the IterMark whose source is expr; SubspaceDivide passes 0 for the
//        top-level bindings.
// \return The division of expr. When expr cannot be divided the value of Fail() is
//         returned, which also increases the unresolved counter.
DivisionResult DivideIterSumExpr(const IterSumExpr& expr, const PrimExpr& mark_extent) {
  if (expr->args.empty()) {
    // base only: no iterators are involved, the constant is placed in the inner part
    return DivisionResult(IterSumExpr({}, 0), 1, IterSumExpr({}, expr->base), 1);
  } else if (expr->args.size() == 1) {
    // arg + base, if arg=Y*E(X)+X, then arg+base = Y*E(X)+(X+base)
    // a single argument is only handled when its scale is one
    if (!is_one(expr->args[0]->scale)) return Fail();
    DivisionResult res = DivideIterSplitExpr(expr->args[0]);
    if (!is_zero(expr->base)) res = AddBase(res, expr->base);
    return res;
  }
  // arg1 + arg2 + ... + argn + base
  // then we can write it as Y*E(X)+X
  // if it starts with contiguous outer splits, followed by contiguous inner splits
  PrimExpr extent = 1;
  std::vector<IterSplitExpr> outer_args, inner_args;
  // `inner` is true while only inner splits have been seen so far;
  // `scale_is_one` records whether some argument has scale one
  bool inner = true, scale_is_one = false;
  // we check in inverse order so we can visit from inner to outer
  for (auto it = expr->args.rbegin(); it != expr->args.rend(); ++it) {
    const IterSplitExpr& arg = *it;
    if (is_one(arg->scale)) scale_is_one = true;
    DivisionResult arg_division = DivideIterSplitExpr(arg);
    IterSplitExpr new_arg;
    if (arg_division.IsInner()) {
      // an inner split that appears after an outer one breaks contiguity
      if (!inner) return Fail();
      new_arg = arg_division.GetInnerAsSplit();
      inner_args.push_back(new_arg);
      inner = true;
    } else if (arg_division.IsOuter()) {
      new_arg = arg_division.GetOuterAsSplit();
      outer_args.push_back(new_arg);
      inner = false;
    } else {
      // the argument itself mixes outer and inner parts
      return Fail();
    }
    // accumulate the product of the extents of all arguments
    extent *= new_arg->extent;
  }
  if (!scale_is_one) return Fail();
  // a predicate is needed when the product of extents cannot be proven equal to mark_extent
  bool need_predicate = !analyzer_->CanProveEqual(extent, mark_extent);
  // rebuild the two parts with fresh scales; the base of expr goes to the inner part
  const IterMark& outer_mark = MarkFromArgsAndBase(outer_args, 0);
  const IterMark& inner_mark = MarkFromArgsAndBase(inner_args, expr->base);
  IterSumExpr outer_source = Downcast<IterSumExpr>(outer_mark->source);
  IterSumExpr inner_source = Downcast<IterSumExpr>(inner_mark->source);
  if (need_predicate) {
    // if we have a predicate on this sum expr, then we cannot divide it into Y*E+X
    // it should either be Y*1+0 or 0*E(X)+X
    IterMapToExprNormalizer converter(analyzer_);
    if (inner_args.empty()) {
      // Y*1+0: record `outer < mark_extent` as a predicate of the outer space
      outer_preds_ = outer_preds_ && (converter.Convert(outer_source) < mark_extent);
      return DivisionResult::Outer(outer_source, mark_extent);
    } else if (outer_args.empty()) {
      // 0*E(X)+X: record `inner < mark_extent` as a predicate of the inner space
      inner_preds_ = inner_preds_ && (converter.Convert(inner_source) < mark_extent);
      return DivisionResult::Inner(inner_source, mark_extent);
    } else {
      return Fail();
    }
  }
  return DivisionResult(outer_source, outer_mark->extent, inner_source, inner_mark->extent);
}

// Conjunction of the predicates recorded so far for the outer space (initially true).
PrimExpr GetOuterPreds() const { return outer_preds_; }
// Conjunction of the predicates recorded so far for the inner space (initially true).
PrimExpr GetInnerPreds() const { return inner_preds_; }

private:
// Record one failed division and return a placeholder result
// (empty sums with zero extents for both the outer and the inner part).
DivisionResult Fail() {
  unresolved_count_++;
  return DivisionResult(IterSumExpr({}, 0), 0, IterSumExpr({}, 0), 0);
}

// Return a copy of `division` whose inner part has `base` added to it.
// A split inner part becomes a one-argument sum with `base` as its base; a sum inner part
// keeps its arguments and gets `base` added to its own base. The outer part and both
// extents are left as they are.
DivisionResult AddBase(DivisionResult division, PrimExpr base) {
  if (const auto* split_node = division.inner.as<IterSplitExprNode>()) {
    IterSumExpr shifted({GetRef<IterSplitExpr>(split_node)}, base);
    division.inner = shifted;
  } else if (const auto* sum_node = division.inner.as<IterSumExprNode>()) {
    const auto& inner_sum = GetRef<IterSumExpr>(sum_node);
    IterSumExpr shifted(inner_sum->args, inner_sum->base + base);
    division.inner = shifted;
  }
  return division;
}

// Build an IterMark from splits that are given innermost first.
// Every split is copied and its scale is replaced by the product of the extents of the
// splits that precede it in `args`; the mark's source is the sum of the rescaled splits
// (stored outermost first) plus `base`, and its extent is the product of all extents.
static IterMark MarkFromArgsAndBase(const std::vector<IterSplitExpr>& args, PrimExpr base) {
  std::vector<IterSplitExpr> rescaled;
  PrimExpr stride = 1;
  for (size_t i = 0; i < args.size(); ++i) {
    IterSplitExpr split = args[i];
    split.CopyOnWrite()->scale = stride;
    stride *= split->extent;
    rescaled.push_back(split);
  }
  Array<IterSplitExpr> outer_first(rescaled.rbegin(), rescaled.rend());
  return IterMark(IterSumExpr(outer_first, base), stride);
}

// Divide an IterSplitExpr into the form outer * E(inner) + inner.
//
// Results are memoised in split_map_. All splits that share the source mark of `expr` are
// looked up in the collector and, except for the single-split case below, are classified
// together the first time one of them is seen.
// \param expr The split to divide.
// \return The division of expr, or the value of Fail() when it cannot be divided.
DivisionResult DivideIterSplitExpr(const IterSplitExpr& expr) {
  auto it = split_map_.find(expr);
  if (it != split_map_.end()) {
    // We will calculate all the splits of an IterMark's division form when we first
    // encounter one of them. If we encounter another later, we directly return the record.
    return it->second;
  }
  // every split that was collected for the same source mark
  const Array<IterSplitExpr>& splits = collector_.mark2splits_.at(expr->source);
  if (const auto* iter_ptr = expr->source->source.as<VarNode>()) {
    // source is input_iter
    // the whole iterator is inner if it belongs to the subspace, outer otherwise
    bool inner = sub_iters_.count(GetRef<Var>(iter_ptr));
    for (const IterSplitExpr& split : splits) {
      if (inner) {
        // 0*E(split)+split
        split_map_.emplace(split, DivisionResult::Inner(split, split->extent));
      } else {
        // split*1 + 0
        split_map_.emplace(split, DivisionResult::Outer(split, split->extent));
      }
    }
  } else if (const auto* iter_ptr = expr->source->source.as<IterSumExprNode>()) {
    // source = Y*E+X
    // splits = [s1, s2, ..., sn]
    // we can divide if there exists i, such that extent(s1)extent(s2)...extent(si)=extent(Y)
    // extent(si+1)...extent(sn)=extent(X)
    // For example, if source = Y*3+X \in [0, 12), Y \in [0, 4), X \in [0, 3)
    // Case 1. splits = [s1, s2, s3] = [source / 6, (source / 3) % 2, source % 3],
    // where extent(s1) = 2, extent(s2) = 2, extent(s3) = 3.
    // Since extent(s1)extent(s2) = extent(Y), extent(s3) = extent(X), we have
    // s1 = (Y / 2)*1 + 0, s2 = (Y % 2)*1 + 0, s3 = 0*3 + X
    // Case 2. splits = [s1, s2, s3] = [source / 4, (source / 2) % 2, source % 2],
    // where extent(s1) = 3, extent(s2) = 2, extent(s3) = 2.
    // It's impossible to rewrite s1, s2, s3 in the form of Y*E(X) + X.
    DivisionResult mark_division =
        DivideIterSumExpr(GetRef<IterSumExpr>(iter_ptr), expr->source->extent);
    if (splits.size() == 1) {
      // the mark has a single split, so the division of the mark is returned for it
      // directly (this result is not stored in split_map_)
      return mark_division;
    }
    IterMark outer_mark(Downcast<IterSumExpr>(mark_division.outer), mark_division.outer_extent);
    IterMark inner_mark(Downcast<IterSumExpr>(mark_division.inner), mark_division.inner_extent);
    // if the mark has no inner part, every split is treated as outer from the start
    bool encountered_boundary = mark_division.IsOuter();
    std::vector<bool> used(splits.size(), false);
    std::vector<IterSplitExpr> inner_iters, outer_iters;
    // lower factor the next split (going from inner to outer) is expected to have
    PrimExpr expected_lower_factor = make_const(expr->source->source->dtype, 1);
    // find the boundary of outer and inner, like case 1 above
    for (size_t i = 0; i < splits.size(); ++i) {
      // pick an unused split whose lower factor equals the product of the extents so far
      size_t j = 0;
      for (; j < splits.size(); ++j) {
        if (!used[j] && analyzer_->CanProveEqual(splits[j]->lower_factor, expected_lower_factor))
          break;
      }
      // no split continues the chain
      if (j == splits.size()) return Fail();
      used[j] = true;
      if (!encountered_boundary) {
        inner_iters.push_back(splits[j]);
      } else {
        outer_iters.push_back(splits[j]);
      }
      expected_lower_factor *= splits[j]->extent;
      // the boundary is reached once the accumulated extent equals the inner extent
      if (analyzer_->CanProveEqual(expected_lower_factor, mark_division.inner_extent))
        encountered_boundary = true;
    }
    // the inner extent never lined up with a split boundary, like case 2 above
    if (!encountered_boundary) return Fail();
    for (const IterSplitExpr& inner_iter : inner_iters) {
      // inner splits now read from the inner mark
      IterSplitExpr new_iter = inner_iter;
      new_iter.CopyOnWrite()->source = inner_mark;
      split_map_.emplace(inner_iter, DivisionResult::Inner(new_iter, inner_iter->extent));
    }
    for (const IterSplitExpr& outer_iter : outer_iters) {
      // outer splits now read from the outer mark; their lower factors are divided by the
      // lower factor of the first outer split
      IterSplitExpr new_iter = outer_iter;
      new_iter.CopyOnWrite()->source = outer_mark;
      new_iter.CopyOnWrite()->lower_factor =
          floordiv(outer_iter->lower_factor, outer_iters[0]->lower_factor);
      split_map_.emplace(outer_iter, DivisionResult::Outer(new_iter, outer_iter->extent));
    }
  } else {
    // the source is neither a variable nor an IterSumExpr
    return Fail();
  }
  return split_map_.at(expr);
}

// number of failed divisions so far, incremented by Fail()
size_t unresolved_count_{0};
// analyzer used to prove equalities between extents and lower factors
Analyzer* analyzer_;
// copy of the collector; its mark2splits_ table gives all the splits of an IterMark
const IterMarkSplitCollector collector_;
// the set of subspace iters (held by reference, the referenced set must outlive this object)
const std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>& sub_iters_;
// map from SplitExpr to its corresponding DivisionResult(Y*E(X)+X)
std::unordered_map<IterSplitExpr, DivisionResult, ObjectPtrHash, ObjectPtrEqual> split_map_;
// predicates of the outer space and of the inner space; both start as true and are
// and-ed with a new condition whenever DivideIterSumExpr needs a predicate
PrimExpr outer_preds_{Bool(true)}, inner_preds_{Bool(true)};
};

// Divide every binding into an outer part and an inner part with respect to sub_iters.
//
// The bindings are first run through DetectIterMap; each resulting IterSumExpr is then
// divided by a SubspaceDivider. The result holds one [outer, inner] pair of IterMarks per
// binding, followed by one extra pair whose extents are the outer and inner predicates.
// An empty array is returned when DetectIterMap finds no map or when any division fails.
Array<Array<IterMark>> SubspaceDivide(const Array<PrimExpr>& bindings,
                                      const Map<Var, Range>& input_iters,
                                      const Array<Var>& sub_iters, const PrimExpr& predicate,
                                      bool require_bijective, arith::Analyzer* analyzer) {
  const Array<IterSumExpr> iter_maps =
      DetectIterMap(bindings, input_iters, predicate, require_bijective, analyzer);
  if (iter_maps.empty()) {
    return {};
  }

  // the iterators that span the inner subspace
  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> subspace_vars;
  for (size_t i = 0; i < sub_iters.size(); ++i) {
    subspace_vars.insert(sub_iters[i]);
  }

  IterMarkSplitCollector split_collector;
  split_collector.Collect(iter_maps);
  SubspaceDivider divider(analyzer, split_collector, subspace_vars);

  std::vector<Array<IterMark>> divided;
  for (const IterSumExpr& iter_map : iter_maps) {
    SubspaceDivider::DivisionResult division = divider.DivideIterSumExpr(iter_map, 0);
    // give up as soon as any binding could not be divided
    if (divider.unresolved_count() != 0) {
      return {};
    }
    IterMark outer_mark(division.outer, division.outer_extent);
    IterMark inner_mark(division.inner, division.inner_extent);
    divided.push_back({outer_mark, inner_mark});
  }

  // trailing entry: the accumulated predicates, carried in the extents of two empty marks
  IterMark outer_pred_mark(IterSumExpr({}, 0), divider.GetOuterPreds());
  IterMark inner_pred_mark(IterSumExpr({}, 0), divider.GetInnerPreds());
  divided.push_back({outer_pred_mark, inner_pred_mark});
  return divided;
}

// Global function "arith.SubspaceDivide": forwards its arguments to SubspaceDivide,
// using an analyzer that is created for the duration of the call.
TVM_REGISTER_GLOBAL("arith.SubspaceDivide")
    .set_body_typed([](const Array<PrimExpr>& bindings, const Map<Var, Range>& input_iters,
                       const Array<Var>& sub_iters, const PrimExpr& predicate,
                       bool require_bijective) {
      arith::Analyzer local_analyzer;
      return SubspaceDivide(bindings, input_iters, sub_iters, predicate, require_bijective,
                            &local_analyzer);
    });

} // namespace arith
} // namespace tvm
Loading