[Arith][GPU]Rewrite simplify fix for Vectorized Cooperative Fetching #5924

Merged 7 commits on Jun 26, 2020

@jcf94 jcf94 commented Jun 25, 2020

This PR is part of #5883. It fixes a rewrite_simplify error that occurs in some cases of vectorized cooperative fetching.

The code generated with the bug looks like this:

A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
(float32x4*)A_2[(((broadcast(((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)), 4) + (floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))*broadcast(512, 4))) + broadcast((k.outer.outer*64), 4)) + floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)))])

This eventually lowers to incorrect CUDA C instructions.
The index should instead be simplified so that it forms a single, correct RampNode:

A.shared[ramp(((ax0.ax1.fused.outer.outer*256) + (threadIdx.x_1*4)), 1, 4)] =
(float32x4*)A_2[ramp((((((floordiv(blockIdx.x, 4)*32768) + (ax0.ax1.fused.outer.outer*2048)) + (floordiv(threadIdx.x_1, 16)*512)) + (k.outer.outer*64)) + (floormod(threadIdx.x_1, 16)*4)), 1, 4)])

The main problems inside this expression are:

threadIdx.x_1 = [0, 64]
floordiv(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4)) * broadcast(512, 4)
floormod(ramp((threadIdx.x_1*4), 1, 4), broadcast(64, 4))

should be simplified to:

threadIdx.x_1 = [0, 64]
broadcast(floordiv(threadIdx.x_1, 16)*512, 4)
ramp(floormod(threadIdx.x_1, 16)*4, 1, 4)
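
The two rewrites above can be sanity-checked lane by lane without TVM. The following is only a sketch in plain Python: `ramp`, `broadcast`, and `check_thread` are hypothetical helpers modelling the vector nodes as lists, not TVM APIs, and Python's `//` and `%` are assumed to stand in for floordiv and floormod. The loop covers 0 through 64 inclusive so that either reading of the range notation is included.

```python
LANES = 4


def ramp(base, stride, lanes):
    """Lane values of ramp(base, stride, lanes), modelled as a plain list."""
    return [base + i * stride for i in range(lanes)]


def broadcast(value, lanes):
    """Lane values of broadcast(value, lanes), modelled as a plain list."""
    return [value] * lanes


def check_thread(t):
    """Compare the original and simplified forms for one thread index."""
    index = ramp(t * 4, 1, LANES)

    # Original forms: floordiv/floormod applied to every lane of the ramp.
    div_orig = [(v // 64) * 512 for v in index]
    mod_orig = [v % 64 for v in index]

    # Simplified forms proposed in this PR description.
    div_simplified = broadcast((t // 16) * 512, LANES)
    mod_simplified = ramp((t % 16) * 4, 1, LANES)

    return div_orig == div_simplified and mod_orig == mod_simplified


# Hypothetical sweep over the thread index range.
all_match = all(check_thread(t) for t in range(0, 65))
print(all_match)
```

The check passes because the four lanes of `ramp(t*4, 1, 4)` never cross a multiple of 64, so they share one quotient and keep consecutive remainders.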

@merrymercy Our former simplify rules caused some additional bugs; I found a better way to fix this.
@tqchen For the UTs, I'm not sure whether there is a better way to check that all of the inner AST blocks are RampNodes. The UTs I added can still pass even if vectorization fails.

cc @minminsun @FrozenGene @yangjunpro

tqchen commented Jun 25, 2020

Thanks @jcf94. We should add a test case to test_arith_rewrite_simplify by constructing the case, and:

  • Verify each of the rules added in this PR.
  • Use isinstance(x, tvm.tir.Ramp) to assert the ramp node.
  • You mentioned a bug in the previous rule; it would be great if the test case covers that bug.

src/arith/rewrite_simplify.cc (outdated)
return broadcast(floordiv(b1, c2), lanes).Eval();
}
// If all indices can be guaranteed to settle inside a coeff range
if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) * c1val < bmod->coeff) {

Please add a unit test in test_arith_rewrite_simplify to cover this rule.
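
Independently of the requested TVM unit test, the condition in the rule can be checked by brute force in plain Python. This is a sketch under stated assumptions: the ramp base `b1` is taken to have the form `coeff * k + base` with `0 <= base < coeff`, the stride `c1` is non-negative, and `c2` is positive. The helper names `rule_applies` and `lanes_share_quotient` are hypothetical, and any additional guards in the surrounding C++ are not modelled.

```python
def rule_applies(base, coeff, c1, c2, lanes):
    """Mirror of the condition shown in the reviewed snippet."""
    return c2 % coeff == 0 and base + (lanes - 1) * c1 < coeff


def lanes_share_quotient(base, coeff, c1, c2, lanes, k_values=range(-8, 9)):
    """True if floordiv(ramp(b1, c1, lanes), c2) equals
    broadcast(floordiv(b1, c2), lanes) for every b1 = coeff * k + base."""
    for k in k_values:
        b1 = coeff * k + base
        expected = b1 // c2  # Python's // floors, like floordiv
        if any((b1 + i * c1) // c2 != expected for i in range(lanes)):
            return False
    return True


# Exhaustive search over a small, hypothetical parameter grid.
no_counterexample = all(
    lanes_share_quotient(base, coeff, c1, c2, lanes)
    for coeff in range(1, 9)
    for base in range(coeff)
    for c1 in range(0, 5)
    for c2 in range(1, 33)
    for lanes in range(1, 5)
    if rule_applies(base, coeff, c1, c2, lanes)
)
print(no_counterexample)

# When the condition fails, lanes may straddle a multiple of c2.
print(lanes_share_quotient(base=0, coeff=4, c1=2, c2=4, lanes=4))
```

For the case in this PR (`base=0`, `coeff=4`, `c1=1`, `c2=64`, `lanes=4`) the condition holds, which matches the expected broadcast result.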

@tqchen tqchen added status: need test case need test cases to cover the change status: need update need update based on feedbacks labels Jun 25, 2020
jcf94 commented Jun 26, 2020

> Thanks @jcf94. We should add a test case to test_arith_rewrite_simplify by constructing the case, and:
>
>   • Verify each of the rules added in this PR.
>   • Use isinstance(x, tvm.tir.Ramp) to assert the ramp node.
>   • You mentioned a bug in the previous rule; it would be great if the test case covers that bug.

Comments are all addressed.

  • Added several test cases in test_arith_rewrite_simplify for the simplify rules.
  • Updated the test_target_codegen_cuda UTs to use pre-/post-order functions to check the RampNode patterns.
  • The bug mentioned above was introduced by our former implementation; everything currently works fine.

@merrymercy

@jcf94 Did our old rule affect the correctness of common operators?

jcf94 commented Jun 26, 2020

> @jcf94 Did our old rule affect the correctness of common operators?

Yes, with those rules several other UTs fail.
For example, in test_arith_intset.py:test_mod(),

ck.verify(flm(y, 8), {y : tvm.arith.IntervalSet(z*8+x*4, z*8+x*4+3)}, (0, 7))

Our rules turn the result into:

(((z*8) + (x*4)) - (8*floordiv(((z*8) + (x*4)), 8))), ((((z*8) + (x*4)) + 3) - (8*floordiv(((z*8) + (x*4)), 8)))
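
To illustrate why this symbolic interval does not match the constant bound (0, 7) that the test expects, the rewritten lower and upper expressions can be evaluated for a few concrete values. This is a plain-Python sketch, not TVM code: `rewritten_bounds` is a hypothetical helper, the sampled ranges for `z` and `x` are arbitrary, and `//` is assumed to match floordiv.

```python
def rewritten_bounds(z, x):
    """Evaluate the rewritten (lower, upper) expressions for concrete z and x."""
    lo = z * 8 + x * 4
    hi = lo + 3
    q = lo // 8  # floordiv(((z*8) + (x*4)), 8)
    return (lo - 8 * q, hi - 8 * q)


# Collect the distinct bounds over an arbitrary sample of z and x.
observed = {rewritten_bounds(z, x) for z in range(-3, 4) for x in range(-3, 4)}
print(sorted(observed))  # [(0, 3), (4, 7)]
```

Each evaluated interval stays inside [0, 7], but the bounds vary with `x`, so the symbolic form cannot compare equal to the single constant interval in the test.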

@tqchen tqchen merged commit fcbebea into apache:master Jun 26, 2020
tqchen commented Jun 26, 2020

Thanks @jcf94 @merrymercy. This PR is now merged.

@tqchen tqchen added status: accepted and removed status: need test case need test cases to cover the change status: need update need update based on feedbacks labels Jun 26, 2020
@jcf94 jcf94 deleted the rewrite_simplify_fix_for_vcf branch June 27, 2020 01:45