
adds Sort, revamps TensorMath, adds masked* operations #120

Merged · 1 commit · Apr 1, 2015
Conversation

@soumith (Member) commented Mar 20, 2015

Implements the following for #70:

  • maskedSelect
  • maskedCopy
  • maskedFill
  • sort

Adds TensorApply and TensorReduce kernels that work on both contiguous and non-contiguous tensors while maintaining equal performance for the contiguous cases.

This PR hugely improves performance for non-contiguous kernels.

Implemented completely by @wickedfoo

Having some build issues, looking into it

@soumith (Member, Author) commented Mar 20, 2015

Also has a multi-GPU copy performance enhancement by @adamlerer

@wickedfoo (Contributor) commented:

Also FYI re: @adamlerer's change, that's actually a bug fix too. Both the old and new non-contiguous copy kernels were missing the synchronization semantics implied by cudaMemcpyAsync() when copying between GPUs: cudaMemcpyAsync() synchronizes on both the src and dst default streams, whereas a kernel launch is only inserted into the stream of the GPU it runs on.
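
[Editor's note] A minimal sketch, not the cutorch code, of the synchronization a cross-GPU copy kernel has to add explicitly and that cudaMemcpyAsync() provides implicitly. Function names are illustrative, and peer access between the two devices is assumed to be enabled already:

#include <cuda_runtime.h>

// Plain copy kernel; it only lives in the stream of the GPU it is launched on.
__global__ void copyKernel(float* dst, const float* src, long n) {
  long i = (long)blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[i] = src[i];
}

// Hypothetical helper: copy n floats from srcDevice to dstDevice via a kernel,
// reproducing the "synchronize in both src and dst default streams" behavior
// that cudaMemcpyAsync() implies.
void crossDeviceCopy(int srcDevice, const float* src,
                     int dstDevice, float* dst, long n) {
  // Event recorded on the source GPU's default stream: "src is ready".
  cudaSetDevice(srcDevice);
  cudaEvent_t srcReady;
  cudaEventCreateWithFlags(&srcReady, cudaEventDisableTiming);
  cudaEventRecord(srcReady, 0);

  // The destination GPU's default stream waits for that event, then runs the
  // copy kernel (peer access to src is assumed to be enabled beforehand).
  cudaSetDevice(dstDevice);
  cudaStreamWaitEvent(0, srcReady, 0);
  const int threads = 256;
  const int blocks = (int)((n + threads - 1) / threads);
  copyKernel<<<blocks, threads>>>(dst, src, n);

  // Event recorded on the destination GPU: "copy finished".
  cudaEvent_t copyDone;
  cudaEventCreateWithFlags(&copyDone, cudaEventDisableTiming);
  cudaEventRecord(copyDone, 0);

  // The source GPU's default stream waits for the copy before any later work
  // can overwrite src -- the other half of cudaMemcpyAsync's semantics.
  cudaSetDevice(srcDevice);
  cudaStreamWaitEvent(0, copyDone, 0);

  cudaEventDestroy(srcReady);
  cudaEventDestroy(copyDone);
}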

@dominikgrewe (Member) commented:

Thanks guys. That looks really impressive!
Can we leave this up for a couple of days before merging so people can digest it? :)

I just have two questions from a first glance:

  1. For apply, you fix the number of elements processed per thread to 8 so the compiler can unroll the loop. When processing small tensors, that means we probably don't fully utilise the GPU. Have you looked into that?

  2. Functions such as THCudaTensor_pointwiseApply1 return false for some errors and call THError for others. When just returning false the user doesn't really know why something went wrong. Is it better to just stick to calling THError in all cases?

@wickedfoo (Contributor) replied:

  1. For larger tensors it better amortizes the initial setup cost; I tried 1-16 elements per thread in powers of 2, and 8 was about optimal. (A minimal sketch of this kernel structure follows this list.)

I believe, from looking at the SASS while working on this (and on other code), that the compiler interleaves non-interdependent loads in the unrolled loop to some degree, based on its operation-timing heuristics, so that multiple loads can be in flight at once. There may therefore not be much benefit to having more threads perform the load transactions that a smaller number of threads would be doing concurrently anyway. It's a trade-off between increased ILP (fewer interdependencies, more in-flight operations per thread) and greater TLP (more threads, fewer in-flight operations each), though I may not actually be exploiting the ILP, since I don't give the compiler temporary storage for the loads.

For really small tensors, note that if the tensor has fewer than 4096 elements, it is processed by a single block (i.e. blockId == getLastLinearBlockId()), which doesn't use the unrolled path. There, each thread issues one load and loops with a block-wide stride. In that case there could be a perf hole between 512 and 4096 elements, since I could have more warps active.

I will look into tensors of roughly 1-16384 elements, inspect the SASS again and do a more detailed analysis.

  2. Sounds good, though in some cases the caller may have allocated memory that will not be properly cleaned up on that error (there is no RAII-style allocation; we depend on THCudaTensor_free() etc., like the error I pointed out). The pointwise functions allocate no memory, leaving that responsibility to the caller. If anything, I'd argue that the CUDA error check should also be moved out, so that the only failure mode of the pointwise functions is a tensor that is too large or has too many dimensions, which is reported via true/false.
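
[Editor's note] A minimal sketch of the unrolled-apply structure described in point 1, assuming 512 threads per block and 8 elements per thread (so one block covers 4096 elements). This is illustrative only; it is not the actual THCudaTensor_pointwiseApply1 code and omits the non-contiguous indexing and IndexType templating:

#include <cuda_runtime.h>

#define ELEMENTS_PER_THREAD 8

// A full block covers blockDim.x * ELEMENTS_PER_THREAD elements with a fixed
// trip count, so the compiler can unroll the loop and keep several independent
// loads in flight; the last (or only) block falls back to a block-stride loop.
template <typename Op>
__global__ void pointwiseApply1(float* data, long totalElements, Op op) {
  const long elementsPerBlock = (long)blockDim.x * ELEMENTS_PER_THREAD;
  const long start = (long)blockIdx.x * elementsPerBlock;

  if (start + elementsPerBlock <= totalElements) {
    #pragma unroll
    for (int i = 0; i < ELEMENTS_PER_THREAD; ++i) {
      long idx = start + (long)i * blockDim.x + threadIdx.x;
      data[idx] = op(data[idx]);
    }
  } else {
    // Remainder path: one load per iteration, block-wide stride.
    for (long idx = start + threadIdx.x; idx < totalElements; idx += blockDim.x) {
      data[idx] = op(data[idx]);
    }
  }
}

struct AddOne {
  __device__ float operator()(float x) const { return x + 1.0f; }
};

int main() {
  const long n = 1 << 20;
  float* d;
  cudaMalloc(&d, n * sizeof(float));
  cudaMemset(d, 0, n * sizeof(float));

  const int threads = 512;   // 512 threads * 8 elements = 4096 elements per block
  const int blocks = (int)((n + threads * ELEMENTS_PER_THREAD - 1) /
                           (threads * ELEMENTS_PER_THREAD));
  pointwiseApply1<<<blocks, threads>>>(d, n, AddOne());
  cudaDeviceSynchronize();

  cudaFree(d);
  return 0;
}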


@wickedfoo (Contributor) added:

Oh, and another comment about sort(). Some caveats about the implementation:

- Handles sort along any dimension (contiguous or not) if each slice has 1-2048 elements. Thus, a sort(2) on a CudaTensor(128, 128, 128, 128) will work.
- Handles sort along contiguous dimensions if the total number of slices is <= 16, regardless of slice size. Thus, a sort(2) on a CudaTensor(3, 100000000) will work; this is delegated to (number of slices) Thrust kernel launches.
- Cases that fall into neither category are not yet implemented. Thus, a sort(2) on a CudaTensor(16, 5000, 16, 16) will not work (non-contiguous, slices larger than 2048 elements), and a sort(2) on a CudaTensor(500, 5000) will not work (contiguous, more than 16 slices, each slice larger than 2048 elements).

I set that aside to work on other things. It can be fixed sooner or later if important.
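
[Editor's note] For illustration, a self-contained sketch of the second path above (few contiguous slices, one Thrust sort launch per slice). The sizes and host-side layout here are made up, and this is not the cutorch kernel code:

#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/sort.h>
#include <thrust/execution_policy.h>
#include <cstdio>

int main() {
  // Stand-in for sorting a (3, 8) tensor along its last (contiguous) dimension:
  // few slices, so each slice gets its own Thrust sort launch.
  const int numSlices = 3;
  const int sliceSize = 8;
  float h[numSlices * sliceSize] = {
      5, 1, 4, 2, 8, 7, 3, 6,
      9, 0, 2, 7, 5, 5, 1, 3,
      4, 4, 4, 1, 0, 9, 8, 2};
  thrust::device_vector<float> d(h, h + numSlices * sliceSize);

  for (int s = 0; s < numSlices; ++s) {
    // One Thrust kernel launch per slice, as in the contiguous fallback path.
    thrust::sort(thrust::device,
                 d.begin() + s * sliceSize,
                 d.begin() + (s + 1) * sliceSize);
  }

  thrust::host_vector<float> out = d;
  for (int s = 0; s < numSlices; ++s) {
    for (int i = 0; i < sliceSize; ++i) printf("%g ", (float)out[s * sliceSize + i]);
    printf("\n");
  }
  return 0;
}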

@wickedfoo (Contributor) added:

*nb: the pointwise functions allocate no memory, except in the exceedingly weird case where you are writing into a tensor with multiple indices mapping to the same element (say, an unfold-ed tensor), which I'd argue should just be an error in Torch unless you are updating a single element. cunn had a test for this.

Instead, I follow the old behavior: both arguments are made contiguous, the operation is performed, and the result is written back into the overlapping-index tensor using the copy kernel again. The only reason you got the result you expected before is due to several properties:

- there's a write race between multiple GPU threads writing into the one destination memory location;
- an arbitrary choice (dependent upon kernel execution order, etc.) is made as to who wins;
- writing a float in CUDA via a data race does not result in tearing (e.g., the lower half-word coming from one thread and the upper half-word from another).

In the case where all source values for a single output are the same, you get what you expect; where the values differ, the result depends on the implementation and execution order of the non-contiguous copy kernel.
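
[Editor's note] A tiny sketch, not cutorch code, of the race just described: every thread writes the same destination element, the winner is arbitrary, but the stored float is never torn:

#include <cuda_runtime.h>
#include <cstdio>

__global__ void racyWrite(float* dst, const float* src, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) dst[0] = src[i];   // many source indices map to one destination element
}

int main() {
  const int n = 1024;
  float h_src[n];
  for (int i = 0; i < n; ++i) h_src[i] = (float)i;

  float *d_src, *d_dst;
  cudaMalloc(&d_src, n * sizeof(float));
  cudaMalloc(&d_dst, sizeof(float));
  cudaMemcpy(d_src, h_src, n * sizeof(float), cudaMemcpyHostToDevice);

  racyWrite<<<(n + 255) / 256, 256>>>(d_dst, d_src, n);

  float result;
  cudaMemcpy(&result, d_dst, sizeof(float), cudaMemcpyDeviceToHost);
  // If all src values were equal, the answer would be deterministic; here the
  // winning value is an arbitrary element of src, decided by execution order.
  printf("winner: %g\n", result);

  cudaFree(d_src);
  cudaFree(d_dst);
  return 0;
}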

@nicholas-leonard (Member) commented:

This PR only seems to have unit tests for the non-contiguous multi-GPU copy. It should also include tests for masked*, sort and most of the other operations that were added.

@wickedfoo (Contributor) posted benchmarks:

Perf relative to the old Thrust implementation for 1-d contiguous tensors. apply1 here is add(1.0), apply2 is a:add(b), apply3 is torch.cmul(c, a, b). The new code is faster for large tensors but slower for small tensors, to a larger degree than desired (though everything is in the microsecond range, and typical noise is 1-5 us anyway). I'll work on perf for small tensors. (A rough sketch of how timings like these can be collected follows the table.)

operation 1-d size new time in us (K40) old time in us (K40) relative speedup/slowdown
apply1 1 3.616000 2.559000 0.707688x
apply1 2 2.144000 1.856000 0.865672x
apply1 4 1.984000 1.696000 0.854839x
apply1 8 1.984000 1.728000 0.870968x
apply1 16 1.984000 1.728000 0.870968x
apply1 32 1.952000 1.632000 0.836066x
apply1 64 1.952000 1.632000 0.836066x
apply1 128 2.016000 1.728000 0.857143x
apply1 256 2.048000 1.728000 0.84375x
apply1 512 2.080000 1.760000 0.846154x
apply1 1024 2.336000 1.920000 0.821918x
apply1 2048 2.912000 2.848000 0.978022x
apply1 4096 4.095000 3.008000 0.734554x
apply1 8192 5.408000 2.464000 0.455621x
apply1 16384 4.992000 2.144000 0.429487x
apply1 32768 4.831000 2.432000 0.503415x
apply1 65536 4.096000 2.848000 0.695312x
apply1 131072 4.672000 3.616000 0.773973x
apply1 262144 8.671000 5.535000 0.638335x
apply1 524288 21.53500 19.68000 0.913861x
apply1 1048576 38.46300 37.69500 0.980033x
apply1 2097152 72.92600 73.27900 1.00484x
apply1 4194304 144.4120 145.1170 1.00488x
apply1 8388608 284.1530 289.2410 1.01791x
apply1 16777216 565.0740 578.7060 1.02412x
apply1 33554432 1.12e+03 1.16e+03 1.03571x
apply1 67108864 2.25e+03 2.35e+03 1.04444x
apply1 134217728 4.48e+03 4.78e+03 1.06696x
apply2 1 3.936000 2.368000 0.601626x
apply2 2 2.176000 1.792000 0.823529x
apply2 4 2.080000 1.696000 0.815385x
apply2 8 2.080000 1.728000 0.830769x
apply2 16 2.048000 1.728000 0.84375x
apply2 32 2.048000 1.728000 0.84375x
apply2 64 2.080000 1.728000 0.830769x
apply2 128 2.080000 1.760000 0.846154x
apply2 256 2.080000 1.760000 0.846154x
apply2 512 2.208000 2.688000 1.21739x
apply2 1024 2.528000 2.304000 0.911392x
apply2 2048 3.232000 2.304000 0.712871x
apply2 4096 4.448000 2.400000 0.539568x
apply2 8192 5.760000 1.952000 0.338889x
apply2 16384 5.247000 2.144000 0.408614x
apply2 32768 5.184000 2.848000 0.549383x
apply2 65536 4.768000 3.968000 0.832215x
apply2 131072 5.952000 6.527000 1.09661x
apply2 262144 14.68800 14.52700 0.989039x
apply2 524288 31.19900 29.63100 0.949742x
apply2 1048576 58.11000 58.20700 1.00167x
apply2 2097152 109.3410 113.3090 1.03629x
apply2 4194304 215.0030 227.7380 1.05923x
apply2 8388608 424.1500 450.6770 1.06254x
apply2 16777216 853.3230 905.5130 1.06116x
apply2 33554432 1.71e+03 1.78e+03 1.04094x
apply2 67108864 3.41e+03 3.57e+03 1.04692x
apply2 134217728 6.81e+03 7.13e+03 1.04699x
apply3 1 3.968000 2.848000 0.717742x
apply3 2 2.432000 1.888000 0.776316x
apply3 4 2.368000 1.856000 0.783784x
apply3 8 2.304000 1.727000 0.749566x
apply3 16 2.272000 1.760000 0.774648x
apply3 32 2.272000 1.728000 0.760563x
apply3 64 2.304000 1.728000 0.75x
apply3 128 2.272000 1.760000 0.774648x
apply3 256 2.336000 1.824000 0.780822x
apply3 512 2.368000 1.920000 0.810811x
apply3 1024 2.687000 2.144000 0.797916x
apply3 2048 3.360000 3.264000 0.971429x
apply3 4096 4.576000 3.296000 0.72028x
apply3 8192 5.920000 2.720000 0.459459x
apply3 16384 5.504000 2.496000 0.453488x
apply3 32768 5.440000 2.911000 0.53511x
apply3 65536 4.896000 3.712000 0.75817x
apply3 131072 7.968000 7.648000 0.959839x
apply3 262144 19.26300 16.19100 0.840523x
apply3 524288 31.48800 29.02300 0.921716x
apply3 1048576 58.55900 55.80700 0.953005x
apply3 2097152 111.6460 109.3100 0.979077x
apply3 4194304 217.2750 222.0110 1.0218x
apply3 8388608 429.5570 431.3490 1.00417x
apply3 16777216 945.2560 867.2420 0.917468x
apply3 33554432 1.71e+03 1.77e+03 1.03509x
apply3 67108864 3.44e+03 3.63e+03 1.05523x
apply3 134217728 6.83e+03 7.61e+03 1.1142x
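
[Editor's note] The benchmark harness itself is not shown in this thread; the following is a hedged sketch of how per-size launch timings like the table above can be gathered with CUDA events. The kernel, thread count and iteration count here are assumptions for illustration only:

#include <cuda_runtime.h>
#include <cstdio>

// Stand-in for a:add(1.0) on a 1-d contiguous tensor.
__global__ void apply1Add(float* data, float val, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] += val;
}

// Average time of one kernel launch, in microseconds, measured with CUDA events.
float timeKernelUs(float* d_data, int n, int iters) {
  cudaEvent_t start, stop;
  cudaEventCreate(&start);
  cudaEventCreate(&stop);
  const int threads = 256;
  const int blocks = (n + threads - 1) / threads;

  apply1Add<<<blocks, threads>>>(d_data, 1.0f, n);   // warm-up launch
  cudaEventRecord(start);
  for (int i = 0; i < iters; ++i)
    apply1Add<<<blocks, threads>>>(d_data, 1.0f, n);
  cudaEventRecord(stop);
  cudaEventSynchronize(stop);

  float ms = 0.0f;
  cudaEventElapsedTime(&ms, start, stop);
  cudaEventDestroy(start);
  cudaEventDestroy(stop);
  return ms * 1000.0f / iters;
}

int main() {
  for (int n = 1; n <= (1 << 20); n <<= 4) {
    float* d;
    cudaMalloc(&d, n * sizeof(float));
    cudaMemset(d, 0, n * sizeof(float));
    printf("n = %8d : %.3f us per launch\n", n, timeKernelUs(d, n, 100));
    cudaFree(d);
  }
  return 0;
}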

Perf relative to the old newContiguous/Thrust/copy-to-non-contiguous/free code, for non-contiguous 2-d tensors with transpose(1, 2) applied to all arguments. Same operations as above. (As an optimization, I could detect that for apply1 the layout is contiguous after some permutation, and that for apply2/3 all arguments have the same size/stride arrays and are contiguous after some permutation, and then do the application in the reordered view; that would map these to the contiguous case. For now, this benchmark forces the non-contiguous path for both old and new code. A sketch of such a permutation check follows the table.)

operation 2-d size new time in us (K40) old time in us (K40) relative speedup/slowdown
apply1_noncontig 2x2 34.093856811523 264.16778564453 7.74825x
apply1_noncontig 4x4 31.948089599609 188.82751464844 5.91045x
apply1_noncontig 8x8 23.126602172852 168.0850982666 7.26804x
apply1_noncontig 16x16 23.126602172852 126.12342834473 5.45361x
apply1_noncontig 32x32 30.994415283203 166.89300537109 5.38462x
apply1_noncontig 64x64 40.054321289062 148.05793762207 3.69643x
apply1_noncontig 128x128 42.915344238281 135.18333435059 3.15x
apply1_noncontig 256x256 55.074691772461 148.05793762207 2.68831x
apply1_noncontig 512x512 107.0499420166 1950.0255584717 18.216x
apply1_noncontig 1024x1024 334.02442932129 1926.8989562988 5.76874x
apply1_noncontig 2048x2048 1431.941986084 2746.1051940918 1.91775x
apply1_noncontig 4096x4096 7370.9487915039 10210.037231445 1.38517x
apply1_noncontig 8192x8192 50721.168518066 72425.842285156 1.42792x
apply2_noncontig 2x2 42.200088500977 211.95411682129 5.0226x
apply2_noncontig 4x4 22.88818359375 190.97328186035 8.34375x
apply2_noncontig 8x8 23.126602172852 172.13821411133 7.4433x
apply2_noncontig 16x16 23.126602172852 171.89979553223 7.43299x
apply2_noncontig 32x32 48.87580871582 159.97886657715 3.27317x
apply2_noncontig 64x64 48.87580871582 149.96528625488 3.06829x
apply2_noncontig 128x128 55.074691772461 156.87942504883 2.84848x
apply2_noncontig 256x256 74.148178100586 926.97143554688 12.5016x
apply2_noncontig 512x512 160.93254089355 1886.1293792725 11.72x
apply2_noncontig 1024x1024 740.05126953125 1858.9496612549 2.51192x
apply2_noncontig 2048x2048 3013.8492584229 2702.9514312744 0.896844x
apply2_noncontig 4096x4096 12011.051177979 10941.982269287 0.910993x
apply2_noncontig 8192x8192 81535.816192627 75384.140014648 0.924552x
apply3_noncontig 2x2 36.001205444336 272.98927307129 7.58278x
apply3_noncontig 4x4 23.841857910156 207.90100097656 8.72x
apply3_noncontig 8x8 36.001205444336 248.90899658203 6.91391x
apply3_noncontig 16x16 33.140182495117 180.95970153809 5.46043x
apply3_noncontig 32x32 44.822692871094 171.18453979492 3.81915x
apply3_noncontig 64x64 50.067901611328 168.0850982666 3.35714x
apply3_noncontig 128x128 55.074691772461 202.89421081543 3.68398x
apply3_noncontig 256x256 77.009201049805 1084.0892791748 14.0774x
apply3_noncontig 512x512 216.00723266602 2681.9705963135 12.4161x
apply3_noncontig 1024x1024 936.98501586914 2667.9039001465 2.84733x
apply3_noncontig 2048x2048 2810.001373291 3664.0167236328 1.30392x
apply3_noncontig 4096x4096 12379.884719849 13031.959533691 1.05267x
apply3_noncontig 8192x8192 82146.883010864 89496.13571167 1.08946x
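
[Editor's note] A hedged host-side sketch of the "contiguous after some permutation" check mentioned before the table. The function name and the 8-dimension cap are made up for illustration; this is not the cutorch implementation:

#include <algorithm>
#include <array>
#include <cstdio>

// A tensor is "contiguous after some permutation" if, when its dimensions are
// ordered by descending stride, each stride equals the product of the sizes of
// all inner dimensions.
bool isContiguousUnderPermutation(const long* size, const long* stride, int nDim) {
  std::array<int, 8> perm{};          // this sketch supports up to 8 dims
  for (int i = 0; i < nDim; ++i) perm[i] = i;
  std::sort(perm.begin(), perm.begin() + nDim,
            [&](int a, int b) { return stride[a] > stride[b]; });

  long expected = 1;
  for (int i = nDim - 1; i >= 0; --i) {   // innermost (smallest stride) first
    int d = perm[i];
    if (size[d] != 1 && stride[d] != expected) return false;
    expected *= size[d];
  }
  return true;
}

int main() {
  // A 4x5 contiguous tensor transposed to 5x4: size = {5, 4}, stride = {1, 5}.
  long size[] = {5, 4}, stride[] = {1, 5};
  printf("%d\n", isContiguousUnderPermutation(size, stride, 2));  // prints 1
  return 0;
}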

#ifndef TH_CUDA_TENSOR_SORT_INC
#define TH_CUDA_TENSOR_SORT_INC

#include "deeplearning/torch/cutorch/lib/THC/THCTensor.h"
Review comment (Member): Fix path.

const IndexType startLinearIndex =
getStartLinearIndex<IndexType>(blockId);

if (blockId == getLastLinearBlockId<IndexType>()) {
Review comment (Contributor): for massive tensors with >536 million elements, virtual (with stride 0) or physical, this logic is wrong, fixing...

@dominikgrewe (Member) commented:

What's the status of this PR? Would be great to get this in.

@soumith force-pushed the goodies branch 3 times, most recently from f004da6 to daf2c51, on March 30, 2015 at 22:13
@soumith (Member, Author) commented Mar 30, 2015

Tomorrow. Writing tests.

@soumith (Member, Author) commented Mar 31, 2015

added some tests and fixes for review. will squash the commits before merging.


-- contiguous, no result tensor, cuda mask
local x = torch.randn(n_row, n_col):float()
local mask = torch.DoubleTensor():rand(n_row*n_col):mul(2):floor():byte():resize(n_row,n_col)
Review comment (Member): Nit: this can be shortened to torch.DoubleTensor(n_row, n_col):uniform():round():byte()

@soumith (Member, Author) commented Mar 31, 2015

Okay, TensorCopy was implemented according to the docs, but the docs were apparently wrong, so I removed it.

Wrote a simple unit test for sort to start with, but sort does not seem to agree with the CPU results in some cases.

To avoid delaying this PR any further, I'll factor Sort out of this PR and merge tomorrow if everyone's okay. cc: @dominikgrewe

@soumith (Member, Author) commented Apr 1, 2015

okay, ready for squash & merge now. unit-tested and good to go:

  • Faster tensor math (especially for non-contiguous cases)
  • maskedSelect
  • maskedFill
  • TensorApply and TensorReduce kernels

cutorch.withDevice(dstDevice, function() cutorch.synchronize() end)

local t2_max = t2:max()
assert(t2_max == 1, "bad copy, transposeSrc= " .. transposeSrc ..
Review comment (Member): Should this be tester:assert? There are a couple of other places that simply call assert.

@dominikgrewe (Member) commented:

Just one small comment on the testing code, otherwise this looks good to me!

Thanks a lot for breaking up the monolithic THC.cu. That was overdue :)

…elect and maskedFill operations (and tests).

Also adds generic Reduce and Apply kernels that can be reused.

soumith added a commit that referenced this pull request on Apr 1, 2015: adds Sort, revamps TensorMath, adds masked* operations

@soumith merged commit c9b2253 into master on Apr 1, 2015
@soumith (Member, Author) commented Apr 1, 2015

fixed the asserts and squashed. merged. thanks for reviewing.

@soumith deleted the goodies branch on April 1, 2015 at 16:49
@soumith mentioned this pull request on Feb 28, 2016