[Arith][GPU] Rewrite simplify fix for Vectorized Cooperative Fetching #5924

Merged 7 commits on Jun 26, 2020
13 changes: 11 additions & 2 deletions src/arith/rewrite_simplify.cc
@@ -722,8 +722,15 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorDivNode* op) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = floordiv(bmod->base, c2val);
int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val, c2val);
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
return broadcast(floordiv(b1, c2), lanes).Eval();
if (ramp_min == ramp_max) {
// If the coefficient of b1's modular set is divisible by c2
if (bmod->coeff % c2val == 0) {
return broadcast(floordiv(b1, c2), lanes).Eval();
}
// If all lane indices are guaranteed to stay within a single coeff-sized range
if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {
Review comment (Member): Please add a unit test in tests_arith_rewrite_simplify to cover this rule.
return broadcast(floordiv(b1, c2), lanes).Eval();
}
}
}
}
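To make the new FloorDiv rule concrete, here is a minimal standalone sketch (not part of the diff) that feeds the same pattern through tvm.arith.Analyzer; it assumes the TVM Python API used by the tests below, and the variable names are illustrative only.

import tvm
from tvm import te

x = te.var("x")
ana = tvm.arith.Analyzer()

# Ramp(x*4, 1, 4) covers the indices 4x, 4x+1, 4x+2, 4x+3.  Since the ramp
# base has coefficient 4, 4 divides 64, and the lane offsets 0..3 never cross
# a 4-aligned boundary, all four lanes fall into the same 64-wide block, so
# the division collapses to a broadcast of floordiv(x*4, 64) == floordiv(x, 16).
expr = tvm.tir.floordiv(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4))
print(ana.rewrite_simplify(expr))  # expected: broadcast of floordiv(x, 16), 4 lanes

This mirrors the new fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)) case added to test_arith_rewrite_simplify.py below.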
@@ -847,6 +854,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const FloorModNode* op) {
} else {
return floormod(ramp(floormod(bmod->base, c2), c1, lanes), broadcast(c2, lanes)).Eval();
}
} else if (c2val % bmod->coeff == 0 && ramp_min == ramp_max) {
return ramp(floormod(b1, c2), c1, lanes).Eval();
}
}
}
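Similarly, a minimal sketch of the new FloorMod branch (again outside the diff, assuming the same API): when c2 is a multiple of the ramp base's coefficient and all lanes stay within one block, the modulo preserves the ramp structure instead of being evaluated lane by lane.

import tvm
from tvm import te

x = te.var("x")
ana = tvm.arith.Analyzer()

# The lanes 4x, 4x+1, 4x+2, 4x+3 all share the same floordiv(., 64) block, so
# taking floormod by 64 keeps the stride: the result is
# Ramp(floormod(x*4, 64), 1, 4) rather than a per-lane modulo.
expr = tvm.tir.floormod(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4))
print(ana.rewrite_simplify(expr))  # expected: ramp(floormod(x*4, 64), 1, 4)

The flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)) case in the test diff below exercises exactly this rewrite.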
46 changes: 46 additions & 0 deletions tests/python/unittest/test_arith_rewrite_simplify.py
@@ -38,6 +38,8 @@ def test_vector_simplify():
tvm.tir.Ramp(y + x, 1, 2))
ck.verify(y.astype("int32x2") + x.astype("int32x2"),
(y + x).astype("int32x2"))
ck.verify(tvm.tir.Broadcast(0, 4) + y,
tvm.tir.Broadcast(y, 4))
# Sub rules
ck.verify(tvm.tir.Ramp(x, 4, 4) - tvm.tir.Ramp(y, 2, 4),
tvm.tir.Ramp(x - y, 2, 4))
@@ -55,6 +57,8 @@ def test_vector_simplify():
tvm.tir.Ramp(x * 2, 8, 4))
ck.verify(2 * tvm.tir.Ramp(x, 4, 4),
tvm.tir.Ramp(x * 2, 8, 4))
ck.verify(tvm.tir.Broadcast(0, 4) * x,
tvm.tir.Broadcast(0, 4))

## DivMod rules
tdiv = tvm.tir.truncdiv
@@ -69,6 +73,7 @@
(x).astype("int32x4"))
ck.verify(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
# trunc mod
ck.verify(tmod(y.astype("int32x2"), x.astype("int32x2")),
tmod(y, x).astype("int32x2"))
ck.verify(tmod(tvm.tir.Ramp(x, 4, 4), 2),
@@ -90,6 +95,27 @@
(x).astype("int32x4"))
ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8),
fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8))
ck.verify(fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)),
tvm.tir.Ramp(fld(x, 4), 2, 5))
ck.verify(fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(x * 2, 4))
ck.verify(fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Broadcast(fld(x, 16), 4))
ck.verify(fld(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Broadcast(fld(x, 8), 4))
ck.verify(fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
fld(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)))
ck.verify(fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)))
ck.verify(fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
fld(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))
# floor mod
ck.verify(flm(y.astype("int32x2"), x.astype("int32x2")),
flm(y, x).astype("int32x2"))
ck.verify(flm(tvm.tir.Ramp(x, 4, 4), 2),
@@ -98,6 +124,26 @@
tvm.tir.Ramp(1, 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8),
flm(tvm.tir.Ramp(1, 15, 4), 8))
ck.verify(flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(flm(x, 4), 4))
ck.verify(flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)),
flm(tvm.tir.Ramp(x, 7, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Ramp(0, 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8, 1, 5), tvm.tir.Broadcast(4, 5)),
flm(tvm.tir.Ramp(0, 1, 5), tvm.tir.Broadcast(4, 5)))
ck.verify(flm(tvm.tir.Ramp(x * 8 + 7, 1, 4), tvm.tir.Broadcast(4, 4)),
flm(tvm.tir.Ramp(3, 1, 4), tvm.tir.Broadcast(4, 4)))
ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 4, 64), 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 8, 2, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 8, 64), 2, 4))
ck.verify(flm(tvm.tir.Ramp(x * 4, 1, 5), tvm.tir.Broadcast(64, 5)),
tvm.tir.Ramp(flm(x * 4, 64), 1, 5))
ck.verify(flm(tvm.tir.Ramp(x * 4 + 3, 1, 4), tvm.tir.Broadcast(64, 4)),
tvm.tir.Ramp(flm(x * 4 + 3, 64), 1, 4))
ck.verify(flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)),
flm(tvm.tir.Ramp(x * 7, 1, 4), tvm.tir.Broadcast(64, 4)))

# Min/Max rules
vx = te.var("vx", dtype="int32x2")
184 changes: 184 additions & 0 deletions tests/python/unittest/test_target_codegen_cuda.py
@@ -485,6 +485,33 @@ def test_cuda_floordiv_with_vectorization():
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

def test_cuda_floormod_with_vectorization():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
return

with tvm.target.cuda():
# B[i] = A[floormod(i, k)]
n = 256
k = 37
A = te.placeholder((n,), name='A')
B = te.compute((n,), lambda i: A[tvm.tir.floormod(i, k)], name='B')
s = te.create_schedule(B.op)
xo, xi = s[B].split(B.op.axis[0], nparts=1)
xio, xii = s[B].split(xi, factor=4)
s[B].vectorize(xii)
s[B].bind(xo, bx)
s[B].bind(xio, tx)
func = tvm.build(s, [A, B], 'cuda')

ctx = tvm.gpu(0)
a_np = np.random.uniform(size=(n,)).astype(A.dtype)
b_np = np.array([a_np[i % k] for i in range(0, n)])
a_nd = tvm.nd.array(a_np, ctx)
b_nd = tvm.nd.array(np.zeros(b_np.shape, dtype=b_np.dtype), ctx)
func(a_nd, b_nd)
tvm.testing.assert_allclose(b_nd.asnumpy(), b_np, rtol=1e-3)

def test_vectorized_casts():
if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("skip because cuda is not enabled..")
@@ -693,6 +720,160 @@ def check_cuda(dtype, n, l, padding, lanes):
check_cuda("float16", 64, 16, 3, 4)
check_cuda("float32", 64, 16, 3, 4)

def vcf_check_common(s, args):
N = 512

# Check that every vectorized loop has been lowered to Ramp-based accesses
stmt = tvm.lower(s, args)
# One-element list so the nested visitors below can flip this flag to mark
# whether the current stmt sits inside a BroadcastNode
inside_broadcast = [False]

# Possible patterns:
# Reduce init: Store[Ramp] = Broadcast(0)
# Shared memory copy: Store[Ramp] = Load[Ramp]
# Compute: Store[Ramp] = Load[Ramp] ... Broadcast[Load]

def pre_visit(stmt):
if isinstance(stmt, tvm.tir.Broadcast):
inside_broadcast[0] = True
# Check Broadcast[Imm numbers] or Broadcast[Load] patterns
assert isinstance(stmt.value, (tvm.tir.IntImm, tvm.tir.FloatImm, tvm.tir.Load))
if isinstance(stmt, tvm.tir.Store):
# Check Store[Ramp] pattern
assert isinstance(stmt.index, tvm.tir.Ramp)
if isinstance(stmt, tvm.tir.Load):
# Check Broadcast[Load] or Load[Ramp] patterns
assert inside_broadcast[0] or isinstance(stmt.index, tvm.tir.Ramp)
# Skip the rest
return stmt
return None

def post_visit(stmt):
if isinstance(stmt, tvm.tir.Broadcast):
inside_broadcast[0] = False
return None

tvm.tir.stmt_functor.ir_transform(stmt['main'].body, pre_visit, post_visit)

if not tvm.gpu(0).exist or not tvm.runtime.enabled("cuda"):
print("CUDA device not found, skip the verification.")
return
else:
tgt = tvm.target.cuda()
mod = tvm.build(s, args, tgt)
# Check that every vectorized loop compiles to the intended vector instructions
# print(mod.imported_modules[0].get_source())

ctx = tvm.context("cuda", 0)
a = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx)
b = tvm.nd.array(np.random.uniform(size=(512, 512)).astype("float32"), ctx)
c = tvm.nd.array(np.zeros((512, 512), dtype="float32"), ctx)
mod(a, b, c)
tvm.testing.assert_allclose(c.asnumpy(), np.dot(
a.asnumpy(), b.asnumpy()), rtol=1e-5)

def test_vectorized_cooperative_fetching_x():
N = 512
A = te.placeholder((N, N), name='A', dtype='float32')
B = te.placeholder((N, N), name='B', dtype='float32')
k = te.reduce_axis((0, N), name='k')
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
s = te.create_schedule(C.op)
i, j = s[C].op.axis
k = s[C].op.reduce_axis[0]

AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])

i3, i4 = s[C].split(i, factor=4)
i2, i3 = s[C].split(i3, factor=2)
i1, i2 = s[C].split(i2, factor=8)
i0, i1 = s[C].split(i1, factor=1)
j3, j4 = s[C].split(j, factor=4)
j2, j3 = s[C].split(j3, factor=2)
j1, j2 = s[C].split(j2, factor=8)
j0, j1 = s[C].split(j1, factor=2)
k1, k2 = s[C].split(k, factor=8)
k0, k1 = s[C].split(k1, factor=8)
s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
block_it = s[C].fuse(i0, j0)
s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
vthread_it = s[C].fuse(i1, j1)
s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
thread_it = s[C].fuse(i2, j2)
s[C].bind(thread_it, tvm.te.thread_axis("threadIdx.x"))
s[C].vectorize(j4)

s[AA].compute_at(s[C], k0)
iaa, jaa = s[AA].op.axis
s[BB].compute_at(s[C], k0)
ibb, jbb = s[BB].op.axis
aa_fused = s[AA].fuse(iaa, jaa)
bb_fused = s[BB].fuse(ibb, jbb)
aa1, aa2 = s[AA].split(aa_fused, factor=4)
aa0, aa1 = s[AA].split(aa1, factor=64)
bb1, bb2 = s[BB].split(bb_fused, factor=4)
bb0, bb1 = s[BB].split(bb1, factor=64)
s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.x"))
s[AA].vectorize(aa2)
s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.x"))
s[BB].vectorize(bb2)

vcf_check_common(s, [A, B, C])

def test_vectorized_cooperative_fetching_xy():
N = 512
A = te.placeholder((N, N), name='A')
B = te.placeholder((N, N), name='B')
k = te.reduce_axis((0, N), name='k')
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k))
s = te.create_schedule(C.op)
i, j = s[C].op.axis
k = s[C].op.reduce_axis[0]

AA = s.cache_read(A, "shared", [C])
BB = s.cache_read(B, "shared", [C])

i3, i4 = s[C].split(i, factor=4)
i2, i3 = s[C].split(i3, factor=2)
i1, i2 = s[C].split(i2, factor=8)
i0, i1 = s[C].split(i1, factor=1)
j3, j4 = s[C].split(j, factor=4)
j2, j3 = s[C].split(j3, factor=2)
j1, j2 = s[C].split(j2, factor=8)
j0, j1 = s[C].split(j1, factor=2)
k1, k2 = s[C].split(k, factor=8)
k0, k1 = s[C].split(k1, factor=8)
s[C].reorder(i0, j0, i1, j1, i2, j2, k0, k1, i3, j3, k2, i4, j4)
block_it = s[C].fuse(i0, j0)
s[C].bind(block_it, tvm.te.thread_axis("blockIdx.x"))
vthread_it = s[C].fuse(i1, j1)
s[C].bind(vthread_it, tvm.te.thread_axis("vthread"))
s[C].bind(i2, tvm.te.thread_axis("threadIdx.y"))
s[C].bind(j2, tvm.te.thread_axis("threadIdx.x"))
s[C].vectorize(j4)

s[AA].compute_at(s[C], k0)
iaa, jaa = s[AA].op.axis
s[BB].compute_at(s[C], k0)
ibb, jbb = s[BB].op.axis
aa_fused = s[AA].fuse(iaa, jaa)
bb_fused = s[BB].fuse(ibb, jbb)
aa2, aa3 = s[AA].split(aa_fused, factor=4)
aa1, aa2 = s[AA].split(aa2, factor=8)
aa0, aa1 = s[AA].split(aa1, factor=8)
bb2, bb3 = s[BB].split(bb_fused, factor=4)
bb1, bb2 = s[BB].split(bb2, factor=8)
bb0, bb1 = s[BB].split(bb1, factor=8)
s[AA].bind(aa1, tvm.te.thread_axis("threadIdx.y"))
s[AA].bind(aa2, tvm.te.thread_axis("threadIdx.x"))
s[AA].vectorize(aa3)
s[BB].bind(bb1, tvm.te.thread_axis("threadIdx.y"))
s[BB].bind(bb2, tvm.te.thread_axis("threadIdx.x"))
s[BB].vectorize(bb3)

vcf_check_common(s, [A, B, C])

if __name__ == "__main__":
test_cuda_vectorize_add()
test_cuda_multiply_add()
@@ -709,7 +890,10 @@ def check_cuda(dtype, n, l, padding, lanes):
test_cuda_reduction()
test_cuda_mix_threaded_and_normal_reduction()
test_cuda_floordiv_with_vectorization()
test_cuda_floormod_with_vectorization()
test_vectorized_intrin1()
test_vectorized_intrin2()
test_vectorized_popcount()
test_cuda_vectorize_load_permute_pad()
test_vectorized_cooperative_fetching_x()
test_vectorized_cooperative_fetching_xy()