
[TOPI, CUDA] Bug fix properly bind gpu threads to injective op #1603

Closed

Conversation

masahi (Member) commented on Aug 15, 2018

We had an interesting error report on the discussion forum, where a reshape op is fused into a convolution. The picture below shows the fused graph; nodes in red are fused by NNVM. This fusion happens even without the recent change I made in #1548.

[Image: the NNVM graph, with the fused nodes shown in red]

The error occurred for the cuda target because the reshape op, which is an injective op, is not bound to GPU threads when it is fused with the convolution. The output of tvm.lower(...) is below.

lower function  fuse_reshape_broadcast_mul_conv2d_broadcast_add_1
// attr [tensor] storage_scope = "global"
allocate tensor[float32 * 1 * 16 * 1 * 1]
produce tensor {
  for (ax1, 0, 16) {
    tensor[ax1] = input0[ax1]
  }
}
produce tensor {
  // attr [iter_var(blockIdx.z, , blockIdx.z)] thread_extent = 1
  // attr [compute] storage_scope = "local"
  allocate compute[float32 * 2 * 4 * 1 * 1 * 1 * 1]
  // attr [pad_temp.shared] storage_scope = "shared"
  allocate pad_temp.shared[float32 * 1 * 8 * 1 * 56]
  // attr [input2.shared] storage_scope = "shared"
  allocate input2.shared[float32 * 32 * 8 * 1 * 1]
  // attr [iter_var(blockIdx.y, , blockIdx.y)] thread_extent = 56
  // attr [iter_var(blockIdx.x, , blockIdx.x)] thread_extent = 2
  // attr [iter_var(threadIdx.y, Range(min=0, extent=28), threadIdx.y)] thread_extent = 28
  // attr [iter_var(threadIdx.x, Range(min=0, extent=8), threadIdx.x)] thread_extent = 8
  produce compute {
    compute[0] = 0.000000f
    compute[4] = 0.000000f
    compute[1] = 0.000000f
    compute[5] = 0.000000f
    compute[2] = 0.000000f
    compute[6] = 0.000000f
    compute[3] = 0.000000f
    compute[7] = 0.000000f
    for (rx.ry.fused.rc.outer.fused, 0, 18) {
      produce pad_temp.shared {
        for (ax3.outer, 0, 2) {
          pad_temp.shared[((threadIdx.y + (threadIdx.x*56)) + (ax3.outer*28))] = tvm_if_then_else((((((1 - ((rx.ry.fused.rc.outer.fused % 9)/3)) <= blockIdx.y) && (blockIdx.y < (57 - ((rx.ry.fused.rc.outer.fused % 9)/3)))) && (((1 - (ax3.outer*28)) - ((rx.ry.fused.rc.outer.fused % 9) % 3)) <= threadIdx.y)) && (threadIdx.y < ((57 - (ax3.outer*28)) - ((rx.ry.fused.rc.outer.fused % 9) % 3)))), input1[((((((((blockIdx.y*56) + threadIdx.y) + (threadIdx.x*3136)) + ((rx.ry.fused.rc.outer.fused/9)*25088)) + (((rx.ry.fused.rc.outer.fused % 9)/3)*56)) + ((rx.ry.fused.rc.outer.fused % 9) % 3)) + (ax3.outer*28)) + -57)], 0.000000f)
        }
      }
      produce input2.shared {
        for (ax0.outer, 0, 2) {
          if (likely((threadIdx.y < (32 - (ax0.outer*28))))) {
            if (likely(((blockIdx.x*32) < ((16 - (ax0.outer*28)) - threadIdx.y)))) {
              input2.shared[(((threadIdx.y*8) + threadIdx.x) + (ax0.outer*224))] = input2[((((((((blockIdx.x*32) + threadIdx.y)*16) + threadIdx.x) + ((rx.ry.fused.rc.outer.fused/9)*8))*9) + (rx.ry.fused.rc.outer.fused % 9)) + (ax0.outer*4032))]
            }
          }
        }
      }
...

I replaced the use of tag.is_broadcast with tag.is_injective to make sure injective ops, as well as broadcast ops, are inlined correctly. But I'm not quite sure whether this is a valid change, so please review, @tqchen @Laurawly @merrymercy. If this looks good, I will also replace the other uses of tag.is_broadcast() in the other backends.
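For reference, here is a minimal sketch of the kind of schedule-traversal change involved, assuming the usual TOPI traverse idiom; the function signature and the schedule_conv2d callback are placeholders for illustration, not the exact code in this diff:

```python
import tvm
from topi import tag

def traverse(s, op, schedule_conv2d):
    """Inline every injective stage (not only broadcast ones) so a fused
    reshape does not survive as a separate stage with no GPU thread binding."""
    if tag.is_injective(op.tag):            # previously: tag.is_broadcast(op.tag)
        if op not in s.outputs:
            s[op].compute_inline()          # fold the stage into its consumer
        for tensor in op.input_tensors:     # keep walking toward the inputs
            if isinstance(tensor.op, tvm.tensor.ComputeOp):
                traverse(s, tensor.op, schedule_conv2d)
    elif op.tag == "conv2d_nchw":
        schedule_conv2d(op)                 # regular conv2d schedule for the master op
```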

I added a simplified test case which fails without this PR.

masahi changed the title from "bind gpu threads to injective op properly" to "[TOPI, CUDA] Bug fix properly bind gpu threads to injective op" on Aug 15, 2018
tqchen (Member) commented on Aug 15, 2018

This seems to be a problem in the fusor: currently we do not encourage fusing an injective op into conv2d; only elemwise ops are allowed.

masahi (Member, Author) commented on Aug 15, 2018

OK, I'll close it for now.

But to change the fusor, we would need to realize conv + batchnorm and broadcast_mul before the final elemwise_add, and I think that logic becomes very tricky.

masahi closed this on Aug 15, 2018
tqchen (Member) commented on Aug 15, 2018

The fusor's logic should enforce that broadcast_mul is not fused into conv2d + bn, because broadcast_mul should be marked as injective (since it is fused with the reshape).

masahi (Member, Author) commented on Aug 15, 2018

The difficulty seems to be that broadcast_mul is a sibling node of conv + bn, so it cannot 'see' the conv + bn node during the initial partitioning stage. But both are assigned FuseRule::kFuseToMaster toward the elemwise_add at the bottom, so they end up fused during the last grouping stage.

tqchen (Member) commented on Aug 16, 2018

We should be able to check the fused type of the group, which is injective instead of elemwise.

masahi (Member, Author) commented on Aug 16, 2018

OK, my plan is as follows (a rough sketch follows the list):

  • During the initial partition step, if an injective op is followed by a broadcast op, mark the broadcast op's pattern as injective.
  • During the last grouping step, check the input nodes' op patterns. If one of them is kOutEWiseFusable and the other is kInjective, ignore the kOutEWiseFusable node and fuse only the kInjective node.
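
Illustrative pseudocode only, written in plain Python over a toy graph representation; the real fusor is NNVM's C++ graph-fuse pass, and its data structures look nothing like this:

```python
# Toy stand-ins for NNVM's TOpPattern values (kElemWise=0, ..., kOutEWiseFusable=4).
K_ELEMWISE, K_BROADCAST, K_INJECTIVE, K_OUT_EWISE_FUSABLE = 0, 1, 2, 4

def demote_broadcast_after_injective(nodes, pattern, producers):
    """Step 1 (initial partition): a broadcast op fed by an injective op
    (e.g. broadcast_mul fed by reshape) is treated as injective itself."""
    for n in nodes:
        if pattern[n] == K_BROADCAST and any(
                pattern[p] == K_INJECTIVE for p in producers[n]):
            pattern[n] = K_INJECTIVE

def pick_fuse_inputs(consumer, pattern, producers):
    """Step 2 (final grouping): if an elemwise consumer sees both a
    kOutEWiseFusable producer (the conv2d group) and a kInjective producer,
    fuse only the injective side and leave the conv2d group alone."""
    prods = producers[consumer]
    has_conv = any(pattern[p] == K_OUT_EWISE_FUSABLE for p in prods)
    has_inj = any(pattern[p] == K_INJECTIVE for p in prods)
    if has_conv and has_inj:
        return [p for p in prods if pattern[p] == K_INJECTIVE]
    return prods
```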
