Revert "support overlapped itersum (apache#12039)" (apache#12137)
This reverts commit 3e7a2ad.
gigiblender authored and Mikael Sevenier committed Jul 26, 2022
1 parent 8908a04 commit b96dd14
Showing 8 changed files with 58 additions and 176 deletions.
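Note on the reverted feature: apache#12039 extended iter-map fusion so that an iterator sum could be accepted at the "surjective" check level even when its terms overlap rather than tile, e.g. `x * 7 + y` with `x` in `[0, 3)` and `y` in `[0, 8)`. The plain-Python sketch below is illustrative only (not part of the patch) and shows what "overlapped" means here.

```python
# Enumerate x * 7 + y over x in [0, 3), y in [0, 8): the image is the
# contiguous range [0, 22), so a surjective fused iterator exists, but some
# values are hit by two (x, y) pairs, so the mapping is not bijective
# ("overlapped").
from collections import Counter

hits = Counter(x * 7 + y for x in range(3) for y in range(8))
assert sorted(hits) == list(range(22))  # surjective onto [0, 22)
assert hits[7] == 2  # 7 == 1 * 7 + 0 == 0 * 7 + 7, hence overlapped
```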
91 changes: 25 additions & 66 deletions src/arith/iter_affine_map.cc
@@ -177,12 +177,8 @@ class IterMapRewriter : public ExprMutator {
   using Parent = ExprMutator;
 
   explicit IterMapRewriter(Analyzer* analyzer, const Map<Var, Range>& input_iters,
-                           IterMapLevel check_level, bool simplify_trivial_iterators,
-                           Array<String>* errors)
-      : analyzer_(analyzer),
-        check_level_(check_level),
-        errors_(*errors),
-        padding_predicate_(const_false()) {
+                           bool simplify_trivial_iterators, Array<String>* errors)
+      : analyzer_(analyzer), errors_(*errors), padding_predicate_(const_false()) {
     for (auto kv : input_iters) {
       const Var& var = kv.first;
       const Range& vrng = kv.second;
@@ -423,8 +419,6 @@ class IterMapRewriter : public ExprMutator {
 
   // Internal analyzer
   Analyzer* analyzer_;
-  // Iter map check level
-  IterMapLevel check_level_;
   // Error messages for each unresolved expression.
   Array<String>& errors_;
   // The var map
@@ -657,7 +651,7 @@ class IterMapRewriter : public ExprMutator {
       if (predicate_induced_max.defined())
         predicate_induced_max = predicate_induced_max.value() - base;
     }
-    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
     ICHECK(!opt.defined() || opt.value()->args.size() == 1);
     // scale should be 1
     if (opt.defined() && is_one(opt.value()->args[0]->scale)) {
@@ -708,7 +702,7 @@ class IterMapRewriter : public ExprMutator {
   IterSumExpr NormalizeToIterWithOffset(IterSumExpr expr) {
     // We are normalizing a regular iter
     if (expr->args.size() < 1) return expr;
-    Optional<IterSumExpr> opt = TryFuseIters(expr, check_level_);
+    Optional<IterSumExpr> opt = TryFuseIters(expr);
     if (opt.defined()) {
       return opt.value();
     } else {
@@ -741,10 +735,9 @@ class IterMapRewriter : public ExprMutator {
    * return a corresponding IterSumExpr with extra offset if needed.
    * Try to normalize IterSum into a fused IterMark
    * \param expr The input sum.
-   * \param check_level The check level if iter mapping.
    * \return The sum with the fused IterMark and extra offset if succeed.
    */
-  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr, IterMapLevel check_level) {
+  Optional<IterSumExpr> TryFuseIters(IterSumExpr expr) {
     // select the iterators in order
     std::vector<bool> visited(expr->args.size(), false);
     std::vector<IterSplitExpr> flattened_iters, grouped_iters;
@@ -765,42 +758,14 @@ class IterMapRewriter : public ExprMutator {
     }
     // check if it can be remapped into a fused pattern.
     PrimExpr expected_extra_base = 0;
-    PrimExpr tail_extent = 0;
     PrimExpr expected_scale = base_scale.value();
     for (size_t i = 0; i < expr->args.size();) {
-      // find position such that expr->args[j] match expected scale
-      int j = i == 0 ? base_index : expr->args.size() - 1;
-
-      size_t matched_pos = expr->args.size();
-      PrimExpr matched_scale{nullptr};
-      bool is_exact_match{false};
-
-      for (; j >= 0; --j) {
-        if (visited[j]) {
-          continue;
-        }
-        const PrimExpr& cur_scale = expr->args[j]->scale;
-
-        // for bijective mapping, the matched scale must equal to expected scale
-        if (analyzer_->CanProveEqual(cur_scale, expected_scale)) {
-          matched_pos = j;
-          matched_scale = cur_scale;
-          is_exact_match = true;
-          break;
-        }
-        if (check_level != IterMapLevel::Bijective && base_scale.value()->value == 1) {
-          // find the closest scale which is less or equal to expected scale
-          if (analyzer_->CanProveGreaterEqual(expected_scale - cur_scale, 0) &&
-              analyzer_->CanProveGreaterEqual(cur_scale, 0)) {
-            if (matched_pos == expr->args.size() ||
-                analyzer_->CanProveLess(matched_scale - cur_scale, 0)) {
-              matched_pos = j;
-              matched_scale = cur_scale;
-            }
-          }
-        }
+      // find j such that expr->args[j] has expected scale
+      size_t j = i == 0 ? base_index : 0;
+      for (; j < expr->args.size(); ++j) {
+        if (!visited[j] && analyzer_->CanProveEqual(expr->args[j]->scale, expected_scale)) break;
       }
-      if (matched_pos == expr->args.size()) {
+      if (j == expr->args.size()) {
         return NullOpt;
       }
       // look for the longest constrained iter started from expr->args[j]
@@ -810,8 +775,8 @@ class IterMapRewriter : public ExprMutator {
       // otherwise we expect the scale of i to be 2*5=10
       Optional<IterSumExpr> constraint_to_match;
       for (const IterSumExpr& iter : constrained_iters_flattened_) {
-        if (IterSplitEqual(expr->args[matched_pos], iter->args.back(), false)) {
-          // find a predicate started from match position
+        if (IterSplitEqual(expr->args[j], iter->args.back(), false)) {
+          // find a predicate started from expr->args[j]
           if (!constraint_to_match ||
               constraint_to_match.value()->args.size() < iter->args.size()) {
             constraint_to_match = iter;
@@ -828,7 +793,7 @@ class IterMapRewriter : public ExprMutator {
         size_t k = 0;
         for (; k < expr->args.size(); ++k) {
           if (!visited[k] && IterSplitEqual(expr->args[k], *it, false)) {
-            if (analyzer_->CanProveEqual((*it)->scale * matched_scale, expr->args[k]->scale))
+            if (analyzer_->CanProveEqual((*it)->scale * expected_scale, expr->args[k]->scale))
               break;
           }
         }
@@ -841,25 +806,20 @@ class IterMapRewriter : public ExprMutator {
         auto iter = sum_fuse_map_.find(constraint_to_match.value());
         ICHECK(iter != sum_fuse_map_.end());
         const IterMarkWithOffset& iter_matched = iter->second;
-        grouped_iters.emplace_back(iter_matched.mark, div(matched_scale, base_scale.value()));
-        expected_extra_base += iter_matched.offset * matched_scale;
-        if (!is_exact_match) {
-          tail_extent += expected_scale - matched_scale;
-        }
-        expected_scale = matched_scale * iter_matched.mark->extent;
+        grouped_iters.emplace_back(iter_matched.mark, expected_scale);
+        expected_extra_base += iter_matched.offset * expected_scale;
+        expected_scale *= iter_matched.mark->extent;
         // move forward
         i += constraint_to_match.value()->args.size();
       } else {
         // constraint_to_match not found, skip this iterator
-        visited[matched_pos] = true;
-        IterSplitExpr arg = expr->args[matched_pos];
-        arg.CopyOnWrite()->scale = analyzer_->Simplify(div(arg->scale, base_scale.value()));
+        visited[j] = true;
+        IterSplitExpr arg = expr->args[j];
+        arg.CopyOnWrite()->scale =
+            analyzer_->Simplify(div(expr->args[j]->scale, base_scale.value()));
         flattened_iters.push_back(arg);
         grouped_iters.push_back(arg);
-        if (!is_exact_match) {
-          tail_extent += expected_scale - matched_scale;
-        }
-        expected_scale = matched_scale * expr->args[matched_pos]->extent;
+        expected_scale *= expr->args[j]->extent;
         ++i;
       }
     }
@@ -883,8 +843,7 @@ class IterMapRewriter : public ExprMutator {
                         expr->base + expected_extra_base);
     } else {
       // new iter, form a new mark
-      IterMark mark =
-          IterMark(structured_form, div(expected_scale, base_scale.value()) + tail_extent);
+      IterMark mark = IterMark(structured_form, div(expected_scale, base_scale.value()));
       sum_fuse_map_[flattened_form] = IterMarkWithOffset(mark, 0);
       flattened_map_[structured_form] = flattened_form;
       return IterSumExpr({IterSplitExpr(mark, base_scale.value())},
@@ -1127,8 +1086,8 @@ IterMapResult DetectIterMap(const Array<PrimExpr>& indices, const Map<Var, Range
       constraints.begin(), constraints.end(),
       [](const IterConstraint& a, const IterConstraint& b) { return a.expr_size < b.expr_size; });
 
-  IterMapRewriter rewriter(analyzer, constrained_input_iters, check_level,
-                           simplify_trivial_iterators, &result->errors);
+  IterMapRewriter rewriter(analyzer, constrained_input_iters, simplify_trivial_iterators,
+                           &result->errors);
   // Step0.0: rewrite constraints in the order from size-small ones to size-big ones
   for (const IterConstraint& constraint : constraints) {
     auto res = rewriter.RewriteIterConstraint(constraint.iter, constraint.lower_bound,
@@ -1322,7 +1281,7 @@ IterSumExpr IterMapRewriter::PreprocessDividend(IterMapExpr dividend, PrimExpr o
   } else if (sum->args.size() == 1) {
     return sum;
   }
-  auto opt_fused = TryFuseIters(sum, check_level_);
+  auto opt_fused = TryFuseIters(sum);
   if (!opt_fused) {
     ErrorLogger(this) << "Dividend " << tvm::PrettyPrint(original_dividend)
                      << ", can't be written as a single fused IterSum";
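With the revert, `TryFuseIters` again requires each selected term's scale to exactly equal the running `expected_scale` (the product of the scale and extent of everything fused so far), so only sums whose terms tile a contiguous range can be fused. A minimal pure-Python model of this rule, assuming a base scale of 1 (`can_fuse_exact` is a hypothetical helper, not TVM API):

```python
def can_fuse_exact(terms):
    """terms: (scale, extent) pairs in any order; True if they tile a range."""
    expected = 1
    remaining = list(terms)
    while remaining:
        # mirror of the restored search: find an arg whose scale is exactly
        # the expected scale; anything else means the fusion fails
        match = next((t for t in remaining if t[0] == expected), None)
        if match is None:
            return False  # no arg with the expected scale -> NullOpt above
        remaining.remove(match)
        expected = match[0] * match[1]  # expected_scale *= extent
    return True

assert can_fuse_exact([(8, 3), (1, 8)])      # x * 8 + y tiles [0, 24)
assert not can_fuse_exact([(7, 3), (1, 8)])  # x * 7 + y overlaps, rejected
```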
7 changes: 5 additions & 2 deletions tests/python/unittest/test_arith_intset.py
@@ -323,6 +323,10 @@ def do_test_point_access(point, predicates, var_dom, expect):
 
 
 def test_region_lower_bound_unfusable():
+    # This test is designed to trigger an error in DetectIterMap,
+    # resulting from a numerator which required multiple input
+    # variables. The bug resulted in an exception being thrown,
+    # rather than a return value of None.
     var_dom = {
         tvm.tir.Var("i", "int32"): tvm.ir.Range(8),
         tvm.tir.Var("j", "int32"): tvm.ir.Range(4),
@@ -332,8 +336,7 @@ def test_region_lower_bound_unfusable():
         tvm.ir.Range.from_min_extent((i + j) // 2, 1),
     ]
     result = tvm.arith.estimate_region_lower_bound(region, var_dom, predicate=True)
-    assert result[0].min_value == 0
-    assert result[0].max_value == 5
+    assert result is None
 
 
 def test_union_lower_bound():
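The removed assertions checked the bounds the overlapped-fusion path used to produce; after the revert, `estimate_region_lower_bound` returns `None` because the numerator `i + j` mixes two input variables and cannot be expressed as a single fused iterator. The 0 and 5 in the removed lines are indeed the true bounds of `(i + j) // 2` over this domain, as a plain-Python check (illustrative only) confirms:

```python
# min/max of (i + j) // 2 for i in [0, 8), j in [0, 4)
values = [(i + j) // 2 for i in range(8) for j in range(4)]
assert min(values) == 0
assert max(values) == (7 + 3) // 2 == 5
```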
58 changes: 1 addition & 57 deletions tests/python/unittest/test_arith_iter_affine_map.py
@@ -61,6 +61,7 @@ def assert_iter_sum_pattern(
     )
     indices = res.indices
     assert len(indices) == len(keys), res.errors
+    print(indices)
     for i, input_iter in enumerate(keys):
         spec = expect_dict[input_iter]
         (
@@ -445,13 +446,6 @@ def test_predicate():
         predicate=xo * 129 + xi < 128,
     )
 
-    # strided iteration predicate
-    assert_iter_sum_pattern(
-        {xo * 16 + xi * 4: (10, 0, 4)},
-        var_dom([(xo, 3), (xi, 4)]),
-        predicate=xo * 4 + xi < 10,
-    )
-
 
 def convert_division(divisions):
     if divisions is None or len(divisions) == 0:
@@ -1016,55 +1010,5 @@ def test_padding():
     assert_iter_sum_failure({flm(x, 16)}, var_dom([(x, 3)]))
 
 
-def test_overlapped_fuse():
-    x = tvm.tir.Var("x", "int32")
-    y = tvm.tir.Var("y", "int32")
-    z = tvm.tir.Var("z", "int32")
-    a = tvm.tir.Var("x", "int32")
-    b = tvm.tir.Var("y", "int32")
-
-    # non-bijective fuse of two
-    assert_iter_sum_pattern(
-        {
-            x * 7 + y: (22, 0, 1),
-        },
-        var_dom([(x, 3), (y, 8)]),
-        check_level="surjective",
-    )
-    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 3), (y, 8)]), check_level="bijective")
-
-    # non-bijective fuse of three
-    assert_iter_sum_pattern(
-        {
-            x * 18 + y * 7 + z: (40, 0, 1),
-        },
-        var_dom([(x, 2), (y, 3), (z, 8)]),
-        check_level="surjective",
-    )
-    assert_iter_sum_failure([x * 7 + y], var_dom([(x, 2), (y, 3), (z, 8)]), check_level="bijective")
-
-    # negative scale fusion is not allowed
-    assert_iter_sum_failure([x * -7 + y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
-    assert_iter_sum_failure([x * 7 - y], var_dom([(x, 3), (y, 8)]), check_level="surjective")
-
-    # with predicate
-    assert_iter_sum_pattern(
-        {
-            a * 40 + b * 20 + x * 18 + y * 3 + z: (125, 6, 1),
-        },
-        var_dom([(a, 3), (b, 2), (x, 2), (y, 6), (z, 8)]),
-        predicate=tvm.tir.all(z < 4, 1 < x * 6 + y, x * 6 + y < 10),
-        check_level="surjective",
-    )
-
-    # stride=1 kernel
-    assert_iter_sum_pattern(
-        {x + a: (230, 0, 1)}, var_dom([(x, 224), (a, 7)]), check_level="surjective"
-    )
-
-    # do not allow both strided and overlapped
-    assert_iter_sum_failure([5 * x + 2 * y], var_dom([(x, 4), (y, 3)]), check_level="surjective")
-
 
 if __name__ == "__main__":
     tvm.testing.main()
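The deleted "strided iteration predicate" case in `test_predicate` relied on the same inexact matching: under `xo * 4 + xi < 10`, the index `xo * 16 + xi * 4` takes exactly the ten values `0, 4, ..., 36`, i.e. the `(extent=10, base=0, scale=4)` pattern the removed `assert_iter_sum_pattern` call expected. A plain-Python enumeration of the removed case (illustrative only):

```python
values = sorted(
    xo * 16 + xi * 4
    for xo in range(3)
    for xi in range(4)
    if xo * 4 + xi < 10
)
assert values == [4 * k for k in range(10)]  # extent 10, base 0, scale 4
```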
26 changes: 13 additions & 13 deletions tests/python/unittest/test_meta_schedule_space_cpu.py
@@ -48,11 +48,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
         for i0_0, i1_0, i2_0, i0_1_1, i1_1_1, i2_1_1 in T.grid(1, 1, 2, 1, 1, 8):
             for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                 with T.block("conv1d_nlc"):
-                    n = T.axis.spatial(1, i0_1_1 + i0_2 + i0_3 + i0_0)
-                    l = T.axis.spatial(128, i1_0 * 128 + i1_1_1 * 128 + i1_2 * 2 + i1_3)
-                    co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2)
+                    n = T.axis.spatial(1, i0_0 + i0_1_1 + i0_2 + i0_3)
+                    l = T.axis.spatial(128, i1_1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
+                    co = T.axis.spatial(128, (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
                     rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                    rc = T.axis.reduce(64, i4_1 + i4_0)
+                    rc = T.axis.reduce(64, i4_0 + i4_1)
                     T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                     T.writes(conv1d_nlc_global[n, l, co])
                     T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -89,11 +89,11 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                         PadInput[i0, i1, i2] = T.if_then_else(1 <= i1 and i1 < 257, inputs[i0, i1 - 1, i2], T.float32(0), dtype="float32")
                 for i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
                     with T.block("conv1d_nlc"):
-                        n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
-                        l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
-                        co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
+                        n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                        l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
+                        co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
                         rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                        rc = T.axis.reduce(64, i4_1 + i4_0)
+                        rc = T.axis.reduce(64, i4_0 + i4_1)
                         T.reads(PadInput[n, l * 2 + rl, co // 128 * 64 + rc], weight[rl, rc, co])
                         T.writes(conv1d_nlc_global[n, l, co])
                         T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
@@ -107,7 +107,7 @@ def c1d_1(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                 T.reads(conv1d_nlc_global[v0, v1, v2])
                 T.writes(conv1d_nlc[v0, v1, v2])
                 conv1d_nlc[v0, v1, v2] = conv1d_nlc_global[v0, v1, v2]
-
+
 @T.prim_func
 def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 128), "float32"], conv1d_nlc: T.Buffer[(1, 128, 128), "float32"]) -> None:
     # function attr dict
@@ -119,11 +119,11 @@ def c1d_2(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
         T.block_attr({"meta_schedule.parallel":288, "meta_schedule.unroll_explicit":16, "meta_schedule.vectorize":64})
         for i0_0, i1_0, i2_0, i0_1, i1_1, i2_1, i3_0, i4_0, i0_2, i1_2, i2_2, i3_1, i4_1, i0_3, i1_3, i2_3 in T.grid(1, 1, 2, 1, 1, 8, 1, 64, 1, 64, 8, 3, 1, 1, 2, 1):
             with T.block("conv1d_nlc"):
-                n = T.axis.spatial(1, i0_1 + i0_2 + i0_3 + i0_0)
-                l = T.axis.spatial(128, i1_0 * 128 + i1_1 * 128 + i1_2 * 2 + i1_3)
-                co = T.axis.spatial(128, i2_3 + i2_0 * 64 + i2_1 * 8 + i2_2)
+                n = T.axis.spatial(1, i0_0 + i0_1 + i0_2 + i0_3)
+                l = T.axis.spatial(128, i1_1 * 128 + i1_0 * 128 + i1_2 * 2 + i1_3)
+                co = T.axis.spatial(128, (i2_0 * 8 + i2_1) * 8 + i2_2 + i2_3)
                 rl = T.axis.reduce(3, i3_0 * 3 + i3_1)
-                rc = T.axis.reduce(64, i4_1 + i4_0)
+                rc = T.axis.reduce(64, i4_0 + i4_1)
                 T.reads(inputs[n, l * 2 + rl - 1, co // 128 * 64 + rc], weight[rl, rc, co])
                 T.writes(conv1d_nlc[n, l, co])
                 T.block_attr({"meta_schedule.tiling_structure":"SSRSRS"})
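The expected TIR in these CPU search-space tests changes only in how IterMap prints the fused axis bindings; the old and new forms are numerically identical. For example, for the `co` axis of `c1d_0` (loop extents taken from the `T.grid` calls above), a plain-Python check (illustrative only):

```python
# i2_0 has extent 2, i2_1_1 extent 8, i2_2 extent 8, i2_3 extent 1
for i2_0 in range(2):
    for i2_1_1 in range(8):
        for i2_2 in range(8):
            i2_3 = 0
            assert (i2_3 + i2_0 * 64 + i2_1_1 * 8 + i2_2
                    == (i2_0 * 8 + i2_1_1) * 8 + i2_2 + i2_3)
```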
12 changes: 6 additions & 6 deletions tests/python/unittest/test_meta_schedule_space_cuda.py
@@ -47,7 +47,7 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                     for ax0_ax1_ax2_fused in T.serial(260):
                         with T.block("PadInput_shared"):
                             v0 = T.axis.spatial(1, 0)
-                            v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused // 4)
+                            v1 = T.axis.spatial(258, i0_0_i1_0_i2_0_fused * 64 + ax0_ax1_ax2_fused % 260 // 4)
                             v2 = T.axis.spatial(64, i4_0 * 4 + ax0_ax1_ax2_fused % 4)
                             T.reads(inputs[v0, v1 - 1, v2])
                             T.writes(PadInput_shared[v0, v1, v2])
@@ -64,11 +64,11 @@ def c1d_0(inputs: T.Buffer[(1, 256, 64), "float32"], weight: T.Buffer[(3, 64, 12
                             weight_shared[v0, v1, v2] = weight[v0, v1, v2]
                     for i3_1, i4_1, i0_3, i1_3, i2_3, i3_2, i4_2, i0_4, i1_4, i2_4 in T.grid(1, 2, 1, 1, 2, 3, 2, 1, 4, 8):
                         with T.block("conv1d_nlc"):
-                            n = T.axis.spatial(1, i0_4 + i0_3)
-                            l = T.axis.spatial(128, i0_0_i1_0_i2_0_fused * 32 + i0_1_i1_1_i2_1_fused // 2 * 4 + i1_3 * 4 + i1_4)
-                            co = T.axis.spatial(128, i0_1_i1_1_i2_1_fused % 2 * 64 + i0_2_i1_2_i2_2_fused * 16 + i2_3 * 8 + i2_4)
-                            rl = T.axis.reduce(3, i3_0 * 3 + i3_1 * 3 + i3_2)
-                            rc = T.axis.reduce(64, i4_0 * 4 + i4_1 * 2 + i4_2)
+                            n = T.axis.spatial(1, i0_4 + i0_3 + 0 + 0 + 0)
+                            l = T.axis.spatial(128, (i0_0_i1_0_i2_0_fused % 4 * 8 + i0_1_i1_1_i2_1_fused % 16 // 2 + 0 + i1_3) * 4 + i1_4)
+                            co = T.axis.spatial(128, (((0 * 2 + i0_1_i1_1_i2_1_fused % 2) * 4 + i0_2_i1_2_i2_2_fused % 4) * 2 + i2_3) * 8 + i2_4)
+                            rl = T.axis.reduce(3, (i3_0 + i3_1) * 3 + i3_2)
+                            rc = T.axis.reduce(64, (i4_0 * 2 + i4_1) * 2 + i4_2)
                             T.reads(PadInput_shared[n, l * 2 + rl, co // 128 * 64 + rc], weight_shared[rl, rc, co])
                             T.writes(conv1d_nlc_local[n, l, co])
                             T.block_attr({"meta_schedule.thread_extent_high_inclusive":1024, "meta_schedule.thread_extent_low_inclusive":32, "meta_schedule.tiling_structure":"SSSRRSRS"})
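Likewise for the CUDA expected output: the reintroduced `% 260` is a no-op since `ax0_ax1_ax2_fused` already ranges over `[0, 260)`, and the re-nested reduction bindings expand to the same affine expressions. A plain-Python check (illustrative only; the extent of 16 assumed for `i4_0` is inferred from `rc` covering `[0, 64)`):

```python
for fused in range(260):
    assert fused % 260 // 4 == fused // 4  # % 260 is redundant here

for i4_0 in range(16):
    for i4_1 in range(2):
        for i4_2 in range(2):
            assert (i4_0 * 2 + i4_1) * 2 + i4_2 == i4_0 * 4 + i4_1 * 2 + i4_2
```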