Bug in WarpScan? #112
NVCC compiles the code twice, once for the host and once for the device, and __CUDA_ARCH__ is defined differently during those two passes: for the host it is undefined, and for the device it is whatever the compute capability is. I think the root of your problem is that you are using the WarpScan template instance as specialized on the host to specify the amount of dynamic shared memory for the kernel. My suggestion is to either:
(a) Use statically allocated shared memory instead of dynamic, or
(b) Instantiate the WarpScan template on the host using all three parameters, e.g., "WarpScan<int, 32, 350>". But you will have to make sure the last one (PTX_ARCH) actually matches the compile string.
Let me know if you have any problems after fixing your bug.
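A minimal sketch of option (b), assuming the code is built with `-arch=sm_35`; the kernel name, block layout, and sizes below are illustrative, not taken from the original report:

```c++
#include <cub/cub.cuh>

constexpr int WARPS_PER_BLOCK = 4;            // assumed block layout
// Pin PTX_ARCH explicitly (350 here); it must agree with -arch=sm_35,
// otherwise the host-side sizeof(TempStorage) may differ from the device's.
using WarpScanT = cub::WarpScan<int, 32, 350>;

__global__ void scan_kernel()
{
    // Dynamic shared memory: one TempStorage per warp.
    extern __shared__ WarpScanT::TempStorage temp_storage[];
    int warp_id = threadIdx.x / 32;
    int x = threadIdx.x;
    WarpScanT(temp_storage[warp_id]).InclusiveSum(x, x);
}

int main()
{
    // Size the dynamic allocation with the SAME specialization the kernel
    // uses, so both compilation passes agree on sizeof(TempStorage).
    size_t smem = WARPS_PER_BLOCK * sizeof(WarpScanT::TempStorage);
    scan_kernel<<<1, 32 * WARPS_PER_BLOCK, smem>>>();
    cudaDeviceSynchronize();
    return 0;
}
```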
Thanks @dumerrill for your answer! See this minimal example, in which I have removed the PTX_ARCH mismatch issue.

```c++
#include <cub/cub.cuh>
#include <cstdio>

constexpr int tpp = 2; // logical warp size (threads per logical warp)

template<int tpp>
__global__ void test(int N){
    using WarpScan = cub::WarpScan<int, tpp>;
    // One TempStorage per logical warp (here N/tpp == tpp == 2).
    __shared__ typename WarpScan::TempStorage temp_storage[tpp];
    int input = threadIdx.x;
    int inclusive_output = 0;
    int warp_aggregate = 0;
    int warp_id = threadIdx.x / tpp;
    __syncthreads();
    WarpScan(temp_storage[warp_id]).InclusiveSum(input, inclusive_output, warp_aggregate);
    __syncthreads();
    int lane = threadIdx.x % tpp;   // lane within the logical warp
    int delta = tpp - lane - 1;     // distance to the logical warp's last lane
    // The output is the expected one only if this line is uncommented:
    //warp_aggregate = __shfl_down(inclusive_output, delta, 32);
    printf("threadIdx.x: %d - input = %d - inclusive_output = %d - warp_aggregate = %d\n",
           threadIdx.x, input, inclusive_output, warp_aggregate);
}

int main(){
    int N = 4;
    test<tpp><<<1, N>>>(N);
    cudaDeviceSynchronize();
}
```
I would expect InclusiveSum to give inclusive_output = {0, 1, 2, 5} for threads 0-3, with warp_aggregate = 1 for the first logical warp (threads 0-1) and 5 for the second (threads 2-3).
But what I get (unless I uncomment the __shfl_down line, which fixes it) is warp_aggregate = 1 for all threads, which is not what I expected. But maybe this is how it is intended to behave!
I have been having similar issues with WarpScan for logical warp sizes < 32, using a reproducer similar to the one above. I think I have narrowed the issue down to the ShuffleIndex used to obtain the warp aggregate value, which seems consistent with what @RaulPPelaez is fixing with the shuffle down. In my test case, going into the ShuffleIndex every thread holds the expected value, but afterwards all threads have the value from the last thread of the first logical warp, not from their own.
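For illustration (this is not CUB's internal code), a broadcast of the last lane's value that stays within the logical warp can be written with the plain shuffle intrinsic, whose `width` argument makes the source-lane index relative to each logical-warp segment rather than to the physical warp:

```c++
// Sketch, assuming LOGICAL_WIDTH is a power of two that divides 32.
// Because of the width argument, source lane (LOGICAL_WIDTH - 1) refers to
// the last lane of the caller's own logical warp, not of the physical warp.
template<int LOGICAL_WIDTH>
__device__ int broadcast_last_lane(int value)
{
    // Pre-CUDA 9 intrinsic, matching the toolkit used in this thread;
    // on newer toolkits this would be __shfl_sync(0xffffffff, ...).
    return __shfl(value, LOGICAL_WIDTH - 1, LOGICAL_WIDTH);
}
```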
Additionally, if I replace line 553:
The issue was not setting up the shfl constant properly. Refactored the shfl scans and reductions to always treat lane_id as relative to the logical warp (not the physical one).
Thanks guys, you're right, we had a bug where we weren't setting the PTX shuffle constant properly, and it didn't show up in tests because we were only testing one subwarp instead of several. A fix is in master and is currently being QA'd for an imminent bugfix release. (As you guys probably know, you want to use CUB scans instead of writing your own via __shfl() because, by going down to PTX, we can leverage the predicate output from shfl to avoid additional instructions.)
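As a rough illustration of that last point (a sketch of the general technique, not CUB's actual source): the PTX shfl instruction can return a predicate saying whether the source lane was in range, so a scan step can guard its add with that predicate instead of computing a separate bounds check. The `shfl_c` operand packs a lane clamp and a segment mask that together define the logical-warp boundary, which is the constant the bug concerned.

```c++
// Illustrative only: one shuffle-up step of an inclusive scan, guarded by
// the shfl predicate. Valid for sm_30-sm_60 targets (CUDA 8 era, as in this
// thread); Volta+ requires the shfl.sync form with a member mask.
__device__ int scan_step_shfl_up(int x, int offset, int shfl_c)
{
    int result;
    asm volatile(
        "{\n"
        "  .reg .s32 r0;\n"
        "  .reg .pred p;\n"
        "  shfl.up.b32 r0|p, %1, %2, %3;\n"  // r0 = value from (lane - offset), p = source in range
        "  @p add.s32 r0, r0, %1;\n"          // add only when the source lane was valid
        "  mov.s32 %0, r0;\n"
        "}\n"
        : "=r"(result)
        : "r"(x), "r"(offset), "r"(shfl_c));
    return result;
}
```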
Fixed in v1.8.0
I am using WarpScan::InclusiveSum and I am noticing something unexpected.
When I specialize the WarpScan template for PTX_ARCH > 210, the warp_aggregate value is wrong, no matter the arch I compile for (I tried a GTX 980 and a 750 Ti) or the CUB version used (> 1.5.4). I have only tested with CUDA 8.0.
Check this sample:
With a logical warp size of 2 (tpp), I would expect the following output:
And that is indeed the output when the WarpScan template is specialized with PTX_ARCH <= 210 and/or the code is compiled with
nvcc -std=c++11 -arch=sm_20 WarpScan.cu
However, if I specialize for PTX_ARCH > 210 and/or compile with -arch=sm_52 (anything above 21, actually), I get the following result:
That is not only different but incorrect.
Also, the code won't even compile when the virtual warp size is 1 for arch <= 21, and the results are incorrect otherwise (I want tpp=1 because in that case the code should be equivalent to another kernel I want to reproduce and improve).
This is weird, because if anything, I would expect WarpScan to fail when compiled for an incorrect architecture (I run on sm_35 and sm_52 GPUs), but it is the other way around!
Diving into the code, I guess the behavior comes down to whether the shared-memory or the __shfl version of WarpScan is used.
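If that reading is right (assumed from the observed behavior, not verified against the sources), the two paths can be compared from user code simply by pinning PTX_ARCH:

```c++
#include <cub/cub.cuh>

constexpr int tpp = 2;
// Assumed mapping, inferred from the behavior above: PTX_ARCH <= 210 selects
// the shared-memory implementation, while PTX_ARCH >= 300 selects the
// __shfl-based one, which is where the wrong warp_aggregate appears.
using WarpScanSmem = cub::WarpScan<int, tpp, 210>;
using WarpScanShfl = cub::WarpScan<int, tpp, 350>;
```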
Am I doing something weird here?
Is there something I am not understanding about the behavior of this utility and this is expected?
Thanks!
EDIT:
Upon further testing, I am seeing that with arch > 210, warp_aggregate takes the value of the sum of the first logical warp. For example:
EDIT 2:
I managed to bypass the issue by adding the following snippet:
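The snippet itself is not shown above; a sketch of what such a bypass could look like, based on the commented-out __shfl_down line in the reproducer earlier in this thread (an assumption, not necessarily the exact code referred to here):

```c++
// Assumed reconstruction of the bypass: re-derive warp_aggregate by shuffling
// the last lane's inclusive sum down to every lane of the logical warp.
template<int tpp>
__device__ int recover_warp_aggregate(int inclusive_output)
{
    int lane  = threadIdx.x % tpp;   // lane within the logical warp
    int delta = tpp - lane - 1;      // distance to the logical warp's last lane
    // Pre-CUDA 9 intrinsic, as used elsewhere in this thread; on newer
    // toolkits this would be __shfl_down_sync(0xffffffff, ...).
    return __shfl_down(inclusive_output, delta, 32);
}
```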