From 41006b8b96c7b91ac7175b60322b2c61b3e2af9d Mon Sep 17 00:00:00 2001 From: soumith Date: Mon, 30 Mar 2015 15:06:52 -0700 Subject: [PATCH] revamps TensorMath to remove sync points at many places, adds maskedSelect and maskedFill operations (and tests). Also adds generic Reduce and Apply kernels that can be reused. --- FFI.lua | 1 + Tensor.c | 15 - Tensor.lua | 30 + TensorMath.lua | 20 +- init.c | 10 + lib/THC/CMakeLists.txt | 26 +- lib/THC/THC.cu | 11 - lib/THC/THCApply.cu | 10 + lib/THC/THCApply.cuh | 582 ++++++ lib/THC/THCGeneral.c | 6 + lib/THC/THCGeneral.h | 1 + lib/THC/THCReduce.cuh | 314 ++++ lib/THC/THCReduceApplyUtils.cu | 123 ++ lib/THC/THCReduceApplyUtils.cuh | 236 +++ lib/THC/THCTensorCopy.cu | 395 +--- lib/THC/THCTensorIndex.cu | 240 +++ lib/THC/THCTensorMasked.cu | 125 ++ lib/THC/THCTensorMath.cu | 2210 ++--------------------- lib/THC/THCTensorMath.h | 7 + lib/THC/THCTensorMath2.cu | 569 ++++++ lib/THC/THCTensorMathBlas.cu | 380 ++++ lib/THC/THCTensorMathCompare.cu | 113 ++ lib/THC/THCTensorMathCompareT.cu | 101 ++ lib/THC/THCTensorMathPairwise.cu | 96 + lib/THC/THCTensorMathPointwise.cu | 157 ++ lib/THC/THCTensorMathScan.cu | 213 +++ lib/THC/THCTensorMathTransformReduce.cu | 219 +++ rocks/cutorch-scm-1.rockspec | 4 +- test/test.lua | 271 ++- torch/generic/Tensor.c | 46 +- 30 files changed, 4016 insertions(+), 2515 deletions(-) delete mode 100644 lib/THC/THC.cu create mode 100644 lib/THC/THCApply.cu create mode 100644 lib/THC/THCApply.cuh create mode 100644 lib/THC/THCReduce.cuh create mode 100644 lib/THC/THCReduceApplyUtils.cu create mode 100644 lib/THC/THCReduceApplyUtils.cuh create mode 100644 lib/THC/THCTensorIndex.cu create mode 100644 lib/THC/THCTensorMasked.cu create mode 100644 lib/THC/THCTensorMath2.cu create mode 100644 lib/THC/THCTensorMathBlas.cu create mode 100644 lib/THC/THCTensorMathCompare.cu create mode 100644 lib/THC/THCTensorMathCompareT.cu create mode 100644 lib/THC/THCTensorMathPairwise.cu create mode 100644 lib/THC/THCTensorMathPointwise.cu create mode 100644 lib/THC/THCTensorMathScan.cu create mode 100644 lib/THC/THCTensorMathTransformReduce.cu diff --git a/FFI.lua b/FFI.lua index 8050cad5..e4fdf20d 100644 --- a/FFI.lua +++ b/FFI.lua @@ -7,6 +7,7 @@ typedef struct THCState { struct THCRNGState* rngState; struct THCBlasState* blasState; + struct cudaDeviceProp* deviceProperties; } THCState; typedef struct THCudaStorage diff --git a/Tensor.c b/Tensor.c index 9e5b5b46..f9ff39f0 100644 --- a/Tensor.c +++ b/Tensor.c @@ -5,21 +5,6 @@ /* everything is as the generic Storage.c, except few things (see below) */ -static void THCudaTensor_maskedFill(THCState *state, THCudaTensor *tensor, THByteTensor *mask, float value) -{ - THError("not yet implemented for CUDA"); -} - -static void THCudaTensor_maskedCopy(THCState *state, THCudaTensor *tensor, THByteTensor *mask, THCudaTensor* src) -{ - THError("not yet implemented for CUDA"); -} - -void THCudaTensor_maskedSelect(THCState *state, THCudaTensor *tensor, THCudaTensor* src, THByteTensor *mask) -{ - THError("not yet implemented for CUDA"); -} - #define real float #define Real Cuda diff --git a/Tensor.lua b/Tensor.lua index c2f3c411..b5916134 100644 --- a/Tensor.lua +++ b/Tensor.lua @@ -30,14 +30,44 @@ local function Tensor__float(self,type) return self:type('torch.FloatTensor') end +local function Tensor__byte(self,type) + return self:type('torch.ByteTensor') +end + +local function Tensor__char(self,type) + return self:type('torch.CharTensor') +end + +local function Tensor__int(self,type) + return 
self:type('torch.IntTensor') +end + +local function Tensor__short(self,type) + return self:type('torch.ShortTensor') +end + +local function Tensor__Long(self,type) + return self:type('torch.LongTensor') +end + rawset(torch.getmetatable('torch.DoubleTensor'), 'cuda', Tensor__cuda) rawset(torch.getmetatable('torch.FloatTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.ByteTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.CharTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.IntTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.ShortTensor'), 'cuda', Tensor__cuda) +rawset(torch.getmetatable('torch.LongTensor'), 'cuda', Tensor__cuda) rawset(torch.getmetatable('torch.CudaTensor'), 'cuda', Tensor__cuda) rawset(torch.getmetatable('torch.CudaTensor'), 'type', Tensor__type) rawset(torch.getmetatable('torch.CudaTensor'), 'typeAs', Tensor__typeAs) rawset(torch.getmetatable('torch.CudaTensor'), 'double', Tensor__double) rawset(torch.getmetatable('torch.CudaTensor'), 'float', Tensor__float) +rawset(torch.getmetatable('torch.CudaTensor'), 'byte', Tensor__byte) +rawset(torch.getmetatable('torch.CudaTensor'), 'char', Tensor__char) +rawset(torch.getmetatable('torch.CudaTensor'), 'int', Tensor__int) +rawset(torch.getmetatable('torch.CudaTensor'), 'short', Tensor__short) +rawset(torch.getmetatable('torch.CudaTensor'), 'long', Tensor__long) do local metatable = torch.getmetatable('torch.CudaTensor') diff --git a/TensorMath.lua b/TensorMath.lua index 008fe6cc..cec46b5a 100644 --- a/TensorMath.lua +++ b/TensorMath.lua @@ -169,7 +169,7 @@ end local function lastdim(argn) return function(arg) - return string.format("THCudaTensor_nDimension(%s)", arg.args[argn]:carg()) + return string.format("THCudaTensor_nDimension(cutorch_getstate(L), %s)", arg.args[argn]:carg()) end end @@ -289,6 +289,24 @@ wrap("addcdiv", {name=Tensor}, {name=Tensor}}) +wrap("maskedFill", + cname("maskedFill"), + {{name=Tensor, returned=true, method={default='nil'}}, + {name=Tensor}, + {name=real}}) + +wrap("maskedCopy", + cname("maskedCopy"), + {{name=Tensor, returned=true, method={default='nil'}}, + {name=Tensor}, + {name=Tensor}}) + +wrap("maskedSelect", + cname("maskedSelect"), + {{name=Tensor, returned=true, default=true}, + {name=Tensor}, + {name=Tensor}}) + do local Tensor = Tensor local real = real diff --git a/init.c b/init.c index 48f3c287..45903389 100644 --- a/init.c +++ b/init.c @@ -40,6 +40,15 @@ static int cutorch_getDeviceCount(lua_State *L) return 1; } +static int cutorch_getMemoryUsage(lua_State *L) { + size_t freeBytes = 0; + size_t totalBytes = 0; + THCudaCheck(cudaMemGetInfo(&freeBytes, &totalBytes)); + lua_pushnumber(L, freeBytes); + lua_pushnumber(L, totalBytes); + return 2; +} + static int cutorch_setDevice(lua_State *L) { THCState *state = cutorch_getstate(L); @@ -158,6 +167,7 @@ static const struct luaL_Reg cutorch_stuff__ [] = { {"deviceReset", cutorch_deviceReset}, {"getDeviceCount", cutorch_getDeviceCount}, {"getDeviceProperties", cutorch_getDeviceProperties}, + {"getMemoryUsage", cutorch_getMemoryUsage}, {"setDevice", cutorch_setDevice}, {"seed", cutorch_seed}, {"seedAll", cutorch_seedAll}, diff --git a/lib/THC/CMakeLists.txt b/lib/THC/CMakeLists.txt index ed597e43..92dfef14 100644 --- a/lib/THC/CMakeLists.txt +++ b/lib/THC/CMakeLists.txt @@ -1,7 +1,28 @@ SET(src THCGeneral.c THCStorage.c THCStorageCopy.c THCTensor.c THCTensorCopy.c) -SET(src-cuda THC.cu) +SET(src-cuda + THCReduceApplyUtils.cu + THCBlas.cu + THCStorage.cu + THCStorageCopy.cu + 
THCTensor.cu + THCTensorCopy.cu + THCTensorMath2.cu + THCTensorMathBlas.cu + THCTensorMathCompare.cu + THCTensorMathCompareT.cu + THCTensorMath.cu + THCTensorMathPairwise.cu + THCTensorMathPointwise.cu + THCTensorMathScan.cu + THCTensorMathTransformReduce.cu + THCTensorMasked.cu + THCTensorIndex.cu + THCTensorConv.cu + THCTensorRandom.cu + THCApply.cu + ) CUDA_ADD_LIBRARY(THC SHARED ${src} ${src-cuda}) CUDA_ADD_CUBLAS_TO_TARGET(THC) @@ -23,4 +44,7 @@ INSTALL(FILES THCTensorRandom.h THCTensorMath.h THCTensorConv.h + THCApply.cuh + THCReduce.cuh + THCReduceApplyUtils.cuh DESTINATION "${Torch_INSTALL_INCLUDE_SUBDIR}/THC") diff --git a/lib/THC/THC.cu b/lib/THC/THC.cu deleted file mode 100644 index 6b1c2adc..00000000 --- a/lib/THC/THC.cu +++ /dev/null @@ -1,11 +0,0 @@ - -/* thrust library does not allow multiple files */ - -#include "THCBlas.cu" -#include "THCStorage.cu" -#include "THCStorageCopy.cu" -#include "THCTensor.cu" -#include "THCTensorCopy.cu" -#include "THCTensorMath.cu" -#include "THCTensorConv.cu" -#include "THCTensorRandom.cu" diff --git a/lib/THC/THCApply.cu b/lib/THC/THCApply.cu new file mode 100644 index 00000000..3ee9d51e --- /dev/null +++ b/lib/THC/THCApply.cu @@ -0,0 +1,10 @@ +#include "THCApply.cuh" + +// Implementation of copyIgnoringOverlaps, defined after pointwiseApply2. +void THCudaTensor_copyIgnoringOverlaps(THCState* state, + THCudaTensor* dst, + THCudaTensor* src) { + THCudaTensor_pointwiseApply2(state, dst, src, CopyOp(), + ReadOnly, // ignore overwrites + ReadOnly); +} diff --git a/lib/THC/THCApply.cuh b/lib/THC/THCApply.cuh new file mode 100644 index 00000000..130287e8 --- /dev/null +++ b/lib/THC/THCApply.cuh @@ -0,0 +1,582 @@ +#ifndef THC_APPLY_INC +#define THC_APPLY_INC + +#include "THCTensorCopy.h" +#include "THCReduceApplyUtils.cuh" + +// +// This file contains pointwise operation functions and kernels that +// work on both contiguous and non-contiguous tensor arguments of +// arbitrary (up to MAX_CUTORCH_DIMS) dimensioned arguments without +// copying or temporary storage. +// + +// Threads per block for our apply kernel +#define THC_APPLY_THREADS_PER_BLOCK 32 * 16 + +// Called when we are copying into an overlapping index `dst`, but +// we don't care which writer wins. Hacky but it works. 
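// (Illustrative note: a destination has "overlapping indices" when several
// logical indices map to the same storage element, for example a view
// created with a zero stride; in that case we simply let an arbitrary
// writer win.)
//
// A minimal usage sketch for the apply entry points defined below; the
// functor and wrapper names here are hypothetical, shown only for
// illustration and assuming the signatures declared in this header:
//
//   struct AddScalarOp {
//     explicit AddScalarOp(float v) : val(v) {}
//     __device__ __forceinline__ void operator()(float* x) { *x += val; }
//     float val;
//   };
//
//   void addScalarInPlace(THCState* state, THCudaTensor* t, float v) {
//     if (!THCudaTensor_pointwiseApply1(state, t, AddScalarOp(v))) {
//       THArgCheck(false, 1, CUTORCH_DIM_WARNING);
//     }
//   }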
+void THCudaTensor_copyIgnoringOverlaps(THCState* state, + THCudaTensor* dst, + THCudaTensor* src); + +template +#if __CUDA_ARCH__ >= 350 +__launch_bounds__(32 * 16, 4) +#endif +__global__ void +THCudaTensor_pointwiseApply1(TensorInfo a, + IndexType totalElements, + Op op) { + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = + IndexToOffset::get(linearIndex, a); + + op(&a.data[aOffset]); + } +} + +template +#if __CUDA_ARCH__ >= 350 +__launch_bounds__(32 * 16, 4) +#endif +__global__ void +THCudaTensor_pointwiseApply2(TensorInfo a, + TensorInfo b, + IndexType totalElements, + Op op) { + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = + IndexToOffset::get(linearIndex, a); + + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = + IndexToOffset::get(linearIndex, b); + + op(&a.data[aOffset], &b.data[bOffset]); + } +} + +template +#if __CUDA_ARCH__ >= 350 +__launch_bounds__(32 * 16, 4) +#endif +__global__ void +THCudaTensor_pointwiseApply3(TensorInfo a, + TensorInfo b, + TensorInfo c, + IndexType totalElements, + Op op) { + for (IndexType linearIndex = blockIdx.x * blockDim.x + threadIdx.x; + linearIndex < totalElements; + linearIndex += gridDim.x * blockDim.x) { + // Convert `linearIndex` into an offset of `a` + const IndexType aOffset = + IndexToOffset::get(linearIndex, a); + + // Convert `linearIndex` into an offset of `b` + const IndexType bOffset = + IndexToOffset::get(linearIndex, b); + + // Convert `linearIndex` into an offset of `c` + const IndexType cOffset = + IndexToOffset::get(linearIndex, c); + + op(&a.data[aOffset], &b.data[bOffset], &c.data[cOffset]); + } +} + +inline dim3 getApplyBlock() { + return dim3(THC_APPLY_THREADS_PER_BLOCK); +} + +inline bool getApplyGrid(THCState* state, long totalElements, dim3& grid) { + int curDevice = -1; + cudaGetDevice(&curDevice); + + if (curDevice == -1) { + return false; + } + + // Assume a reasonable number of SMs if no state is available + int numSM = + state ? state->deviceProperties[curDevice].multiProcessorCount : 15; + + // 16 warps per block * 4 per SM gives 64 warps per SM at maximum, + // which seems to be a good sweetspot for latency hiding + grid = dim3(min(DIVUP(totalElements, (long long) THC_APPLY_THREADS_PER_BLOCK), + 4LL * numSM)); + return true; +} + +template +bool THCudaTensor_pointwiseApply1(THCState* state, + THCudaTensor* a, + const Op& op, + TensorArgType aType = ReadWrite) { + long totalElements = THCudaTensor_nElement(state, a); + + if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS) { + return false; + } + + if (THCudaTensor_nDimension(state, a) == 0) { + // Zero-dim tensor; do nothing + return true; + } + + const dim3 block = getApplyBlock(); + + dim3 grid; + if (!getApplyGrid(state, totalElements, grid)) { + return false; + } + + // If tensor args have overlapping indices and are read/write, then + // we must expand the tensor to a contiguous form first, since + // otherwise there are conflicting writes. Upon copying back to the + // non-contiguous form, there will be conflicting writes, but at + // least with copy, one of the updaters will win atomically. 
This is + // a sketchy property of the old system as well (writing into all + // indices of a tensor with overlapping indices should probably be + // an error, since it is unclear which one should win), but we will + // preserve this last-writer-wins (in arbitrary copy order) behavior. + THCudaTensor* oldA = NULL; + + if (aType == ReadWrite && THC_overlappingIndices(state, a)) { + // Must perform in contiguous space + oldA = a; + a = THCudaTensor_newContiguous(state, a); + } + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. +#define HANDLE_CASE(TYPE, A) \ + THCudaTensor_pointwiseApply1 \ + <<>>(aInfo, (TYPE) totalElements, op); + +#define HANDLE_A_CASE(TYPE, A) \ + { \ + if (aInfo.isContiguous()) { \ + HANDLE_CASE(TYPE, -2); \ + } else { \ + switch (A) { \ + case 1: \ + HANDLE_CASE(TYPE, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, 2); \ + break; \ + case 3: \ + HANDLE_CASE(TYPE, 3); \ + break; \ + default: \ + HANDLE_CASE(TYPE, -1); \ + break; \ + } \ + } \ + } + + // Can we use 32-bit integer math in the kernel (the linear ID for the copy + // and the resulting non-linear offset is all computable using 32-bit math?) + // We also use unsigned index math in the kernel, as signed div/mod has + // additional overhead. + if (THC_canUse32BitIndexMath(state, a)) { + TensorInfo aInfo(state, a); + + HANDLE_A_CASE(unsigned int, aInfo.dims); + } else { + TensorInfo aInfo(state, a); + + // For large tensors, we only compile the completely contiguous + // version and the completely generic version, to reduce + // compilation time. + if (aInfo.isContiguous()) { + THCudaTensor_pointwiseApply1 + <<>>(aInfo, (unsigned long) totalElements, op); + } else { + THCudaTensor_pointwiseApply1 + <<>>(aInfo, (unsigned long) totalElements, op); + } + } +#undef HANDLE_CASE +#undef HANDLE_A_CASE + + if (oldA) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldA contiguous. + THCudaTensor_copyIgnoringOverlaps(state, oldA, a); + THCudaTensor_free(state, a); + a = oldA; + } + + return true; +} + +template +bool THCudaTensor_pointwiseApply2(THCState* state, + THCudaTensor* a, + THCudaTensor* b, + const Op& op, + TensorArgType aType = ReadWrite, + TensorArgType bType = ReadOnly) { + long totalElements = THCudaTensor_nElement(state, a); + + if (totalElements != THCudaTensor_nElement(state, b)) { + return false; + } + + if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS || + THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS) { + return false; + } + + if (THCudaTensor_nDimension(state, a) == 0) { + // Zero-dim tensor; do nothing + return true; + } + + const dim3 block = getApplyBlock(); + + dim3 grid; + if (!getApplyGrid(state, totalElements, grid)) { + return false; + } + + // If tensor args have overlapping indices and are read/write, then + // we must expand the tensor to a contiguous form first, since + // otherwise there are conflicting writes. 
Upon copying back to the + // non-contiguous form, there will be conflicting writes, but at + // least with copy, one of the updaters will win atomically. This is + // a sketchy property of the old system as well (writing into all + // indices of a tensor with overlapping indices should probably be + // an error, since it is unclear which one should win), but we will + // preserve this last-writer-wins (in arbitrary copy order) behavior. + THCudaTensor* oldA = NULL; + THCudaTensor* oldB = NULL; + + if (aType == ReadWrite && THC_overlappingIndices(state, a)) { + // Must perform in contiguous space + oldA = a; + a = THCudaTensor_newContiguous(state, a); + } + if (bType == ReadWrite && THC_overlappingIndices(state, b)) { + // Must perform in contiguous space + oldB = b; + b = THCudaTensor_newContiguous(state, b); + } + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. +#define HANDLE_CASE(TYPE, A, B) \ + THCudaTensor_pointwiseApply2 \ + <<>>(aInfo, bInfo, (TYPE) totalElements, op); \ + +#define HANDLE_B_CASE(TYPE, A, B) \ + { \ + if (bInfo.isContiguous()) { \ + HANDLE_CASE(TYPE, A, -2); \ + } else { \ + switch (B) { \ + case 1: \ + HANDLE_CASE(TYPE, A, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, A, 2); \ + break; \ + case 3: \ + HANDLE_CASE(TYPE, A, 3); \ + break; \ + default: \ + HANDLE_CASE(TYPE, A, -1); \ + break; \ + } \ + } \ + } + +#define HANDLE_A_CASE(TYPE, A, B) \ + { \ + if (aInfo.isContiguous()) { \ + HANDLE_B_CASE(TYPE, -2, B); \ + } else { \ + switch (A) { \ + case 1: \ + HANDLE_B_CASE(TYPE, 1, B); \ + break; \ + case 2: \ + HANDLE_B_CASE(TYPE, 2, B); \ + break; \ + case 3: \ + HANDLE_B_CASE(TYPE, 3, B); \ + break; \ + default: \ + HANDLE_B_CASE(TYPE, -1, B); \ + break; \ + } \ + } \ + } + + if (THC_canUse32BitIndexMath(state, a) && + THC_canUse32BitIndexMath(state, b)) { + TensorInfo aInfo(state, a); + TensorInfo bInfo(state, b); + + HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims); + } else { + TensorInfo aInfo(state, a); + TensorInfo bInfo(state, b); + + // For large tensors, we only compile the completely contiguous + // version and the completely generic version, to reduce + // compilation time. + if (aInfo.isContiguous() && bInfo.isContiguous()) { + THCudaTensor_pointwiseApply2 + <<>>(aInfo, bInfo, (unsigned long) totalElements, op); + } else { + THCudaTensor_pointwiseApply2 + <<>>(aInfo, bInfo, (unsigned long) totalElements, op); + } + } +#undef HANDLE_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + if (oldA) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldA contiguous. + THCudaTensor_copyIgnoringOverlaps(state, oldA, a); + THCudaTensor_free(state, a); + a = oldA; + } + + if (oldB) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldB contiguous. 
+ THCudaTensor_copyIgnoringOverlaps(state, oldB, b); + THCudaTensor_free(state, b); + b = oldB; + } + + return true; +} + +template +bool THCudaTensor_pointwiseApply3(THCState* state, + THCudaTensor* a, + THCudaTensor* b, + THCudaTensor* c, + const Op& op, + TensorArgType aType = ReadWrite, + TensorArgType bType = ReadOnly, + TensorArgType cType = ReadOnly) { + long totalElements = THCudaTensor_nElement(state, a); + + if (totalElements != THCudaTensor_nElement(state, b) || + totalElements != THCudaTensor_nElement(state, c)) { + return false; + } + + if (THCudaTensor_nDimension(state, a) > MAX_CUTORCH_DIMS || + THCudaTensor_nDimension(state, b) > MAX_CUTORCH_DIMS || + THCudaTensor_nDimension(state, c) > MAX_CUTORCH_DIMS) { + return false; + } + + if (THCudaTensor_nDimension(state, a) == 0) { + // Zero-dim tensor; do nothing + return true; + } + + const dim3 block = getApplyBlock(); + + dim3 grid; + if (!getApplyGrid(state, totalElements, grid)) { + return false; + } + + // If tensor args have overlapping indices and are read/write, then + // we must expand the tensor to a contiguous form first, since + // otherwise there are conflicting writes. Upon copying back to the + // non-contiguous form, there will be conflicting writes, but at + // least with copy, one of the updaters will win atomically. This is + // a sketchy property of the old system as well (writing into all + // indices of a tensor with overlapping indices should probably be + // an error, since it is unclear which one should win), but we will + // preserve this last-writer-wins (in arbitrary copy order) behavior. + THCudaTensor* oldA = NULL; + THCudaTensor* oldB = NULL; + THCudaTensor* oldC = NULL; + + if (aType == ReadWrite && THC_overlappingIndices(state, a)) { + // Must perform in contiguous space + oldA = a; + a = THCudaTensor_newContiguous(state, a); + } + + if (bType == ReadWrite && THC_overlappingIndices(state, b)) { + // Must perform in contiguous space + oldB = b; + b = THCudaTensor_newContiguous(state, b); + } + + if (cType == ReadWrite && THC_overlappingIndices(state, c)) { + // Must perform in contiguous space + oldC = c; + c = THCudaTensor_newContiguous(state, c); + } + +#define HANDLE_CASE(TYPE, A, B, C) \ + THCudaTensor_pointwiseApply3 \ + <<>>(aInfo, bInfo, cInfo, (TYPE) totalElements, op); \ + +#define HANDLE_C_CASE(TYPE, A, B, C) \ + { \ + if (cInfo.isContiguous()) { \ + HANDLE_CASE(TYPE, A, B, -2); \ + } else { \ + switch (C) { \ + case 1: \ + HANDLE_CASE(TYPE, A, B, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, A, B, 2); \ + break; \ + case 3: \ + HANDLE_CASE(TYPE, A, B, 3); \ + break; \ + default: \ + HANDLE_CASE(TYPE, A, B, -1); \ + break; \ + } \ + } \ + } + +#define HANDLE_B_CASE(TYPE, A, B, C) \ + { \ + if (bInfo.isContiguous()) { \ + HANDLE_C_CASE(TYPE, A, -2, C); \ + } else { \ + switch (B) { \ + case 1: \ + HANDLE_C_CASE(TYPE, A, 1, C); \ + break; \ + case 2: \ + HANDLE_C_CASE(TYPE, A, 2, C); \ + break; \ + case 3: \ + HANDLE_C_CASE(TYPE, A, 3, C); \ + break; \ + default: \ + HANDLE_C_CASE(TYPE, A, -1, C); \ + break; \ + } \ + } \ + } + +#define HANDLE_A_CASE(TYPE, A, B, C) \ + { \ + if (aInfo.isContiguous()) { \ + HANDLE_B_CASE(TYPE, -2, B, C); \ + } else { \ + switch (A) { \ + case 1: \ + HANDLE_B_CASE(TYPE, 1, B, C); \ + break; \ + case 2: \ + HANDLE_B_CASE(TYPE, 2, B, C); \ + break; \ + case 3: \ + HANDLE_B_CASE(TYPE, 3, B, C); \ + break; \ + default: \ + HANDLE_B_CASE(TYPE, -1, B, C); \ + break; \ + } \ + } \ + } + + if (THC_canUse32BitIndexMath(state, a) && + 
THC_canUse32BitIndexMath(state, b) && + THC_canUse32BitIndexMath(state, c)) { + TensorInfo aInfo(state, a); + TensorInfo bInfo(state, b); + TensorInfo cInfo(state, c); + + HANDLE_A_CASE(unsigned int, aInfo.dims, bInfo.dims, cInfo.dims); + } else { + TensorInfo aInfo(state, a); + TensorInfo bInfo(state, b); + TensorInfo cInfo(state, c); + + // For large tensors, we only compile the completely contiguous + // version and the completely generic version, to reduce + // compilation time. + if (aInfo.isContiguous() && bInfo.isContiguous() && cInfo.isContiguous()) { + THCudaTensor_pointwiseApply3 + <<>>(aInfo, bInfo, cInfo, + (unsigned long) totalElements, op); + } else { + THCudaTensor_pointwiseApply3 + <<>>(aInfo, bInfo, cInfo, + (unsigned long) totalElements, op); + } + } +#undef HANDLE_CASE +#undef HANDLE_C_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + if (oldA) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldA contiguous. + THCudaTensor_copyIgnoringOverlaps(state, oldA, a); + THCudaTensor_free(state, a); + a = oldA; + } + + if (oldB) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldB contiguous. + THCudaTensor_copyIgnoringOverlaps(state, oldB, b); + THCudaTensor_free(state, b); + b = oldB; + } + + if (oldC) { + // Ignore overlaps when copying back; if we use THCudaTensor_copy + // instead, it will recursively try and invoke ourselves to make + // oldC contiguous. + THCudaTensor_copyIgnoringOverlaps(state, oldC, c); + THCudaTensor_free(state, c); + c = oldC; + } + + return true; +} + +#undef THC_APPLY_THREADS_PER_BLOCK + +#endif // THC_APPLY_INC diff --git a/lib/THC/THCGeneral.c b/lib/THC/THCGeneral.c index 8acd499b..8ce4cbc1 100644 --- a/lib/THC/THCGeneral.c +++ b/lib/THC/THCGeneral.c @@ -17,10 +17,15 @@ void THCudaInit(THCState* state) state->blasState = (THCBlasState*)malloc(sizeof(THCBlasState)); THCudaBlas_init(state, count, device); + state->deviceProperties = + (struct cudaDeviceProp*)malloc(count * sizeof(struct cudaDeviceProp)); + int i,j; for(i=0; i < count; ++i) { THCudaCheck(cudaSetDevice(i)); + THCudaCheck(cudaGetDeviceProperties(&state->deviceProperties[i], i)); + for (j=0; j < count; ++j) { if(i != j) @@ -50,6 +55,7 @@ void THCudaShutdown(THCState* state) THCRandom_shutdown(state); free(state->blasState); free(state->rngState); + free(state->deviceProperties); THCudaBlas_shutdown(state); } diff --git a/lib/THC/THCGeneral.h b/lib/THC/THCGeneral.h index 557b9088..1538267b 100644 --- a/lib/THC/THCGeneral.h +++ b/lib/THC/THCGeneral.h @@ -32,6 +32,7 @@ typedef struct THCState { struct THCRNGState* rngState; struct THCBlasState* blasState; + struct cudaDeviceProp* deviceProperties; } THCState; THC_API void THCudaBlas_init(THCState *state, int num_devices, int current_device); diff --git a/lib/THC/THCReduce.cuh b/lib/THC/THCReduce.cuh new file mode 100644 index 00000000..b31a5a86 --- /dev/null +++ b/lib/THC/THCReduce.cuh @@ -0,0 +1,314 @@ +#ifndef THC_REDUCE_INC +#define THC_REDUCE_INC + +// +// This file contains dimension reduction operation functions and +// kernels that work on both contiguous and non-contiguous tensor +// arguments of arbitrary (up to MAX_CUTORCH_DIMS) dimensioned +// arguments without copying or temporary storage. 
+// + +#include "THCReduceApplyUtils.cuh" + +// Threads per thread block +#define THC_NONCONTIG_REDUCE_BLOCK_SIZE 32 * 16 + +template +__device__ __forceinline__ IndexType getReduceNoncontigDimSliceIndex() { + // Each thread handles one slice + return getLinearBlockId() * THC_NONCONTIG_REDUCE_BLOCK_SIZE + threadIdx.x; +} + +// Kernel that handles an entire reduction of a slice of a tensor per each thread +template +#if __CUDA_ARCH__ >= 350 +__launch_bounds__(32 * 16, 4) +#endif +__global__ void +THCudaTensor_reduceNoncontigDim(TensorInfo out, + TensorInfo in, + IndexType reductionStride, + IndexType reductionSize, + IndexType totalSlices, + float init, + ModifyOp modifyOp, + ReduceOp reduceOp) { + const IndexType sliceIndex = getReduceNoncontigDimSliceIndex(); + + if (sliceIndex >= totalSlices) { + return; + } + + // Each thread picks a point in `out` and `in` for which it is + // producing the reduction + const IndexType outOffset = + IndexToOffset::get(sliceIndex, out); + const IndexType inBaseOffset = + IndexToOffset::get(sliceIndex, in); + + // For each point in reductionSize, reduce into `r` + IndexType inOffset = inBaseOffset; + float r = init; + + for (IndexType i = 0; i < reductionSize; ++i) { + r = reduceOp(r, modifyOp(in.data[inOffset])); + inOffset += reductionStride; + } + + // Write out reduced value + out.data[outOffset] = r; +} + +template +__device__ __forceinline__ IndexType getReduceContigDimSliceIndex() { + // Each block handles one slice + return getLinearBlockId(); +} + +// Kernel that handles an entire reduction of a slice of a tensor per +// each block +template +__global__ void +THCudaTensor_reduceContigDim(TensorInfo out, + TensorInfo in, + IndexType reductionSize, + IndexType totalSlices, + float init, + ModifyOp modifyOp, + ReduceOp reduceOp) { + const IndexType sliceIndex = getReduceContigDimSliceIndex(); + + if (sliceIndex >= totalSlices) { + return; + } + + // Get the offset in `out` for the reduction + const IndexType outOffset = + IndexToOffset::get(sliceIndex, out); + + // Get the base offset in `in` for this block's reduction + const IndexType inBaseOffset = + IndexToOffset::get(sliceIndex, in); + + // Each thread in the block will reduce some subset of elements in + // the slice. The elements are guaranteed contiguous starting at + // `inBaseOffset`. + float r = init; + for (IndexType i = threadIdx.x; i < reductionSize; i += blockDim.x) { + r = reduceOp(r, modifyOp(in.data[inBaseOffset + i])); + } + + // Reduce within the block + extern __shared__ float smem[]; + smem[threadIdx.x] = r; + + // First warp will perform reductions across warps + __syncthreads(); + if ((threadIdx.x / 32) == 0) { + r = init; + for (int i = 0; i < blockDim.x; i += 32) { + r = reduceOp(r, smem[i + threadIdx.x]); + } + + // Each lane participating writes out a value + smem[threadIdx.x] = r; + } + + // First thread will perform reductions across the block + __syncthreads(); + if (threadIdx.x == 0) { + r = init; +#pragma unroll + for (int i = 0; i < 32; ++i) { + r = reduceOp(r, smem[i]); + } + + // Write out reduced value + out.data[outOffset] = r; + } +} + +inline dim3 getNoncontigReduceBlock() { + return dim3(THC_NONCONTIG_REDUCE_BLOCK_SIZE); +} + +inline dim3 getContigReduceBlock(long numSlices, long reductionSize) { + // If the number of slices is low but the reduction dimension size + // is high, then we should increase block size for greater parallelism. + // Aim for at least 32 warps per SM (assume 15 SMs; don't bother + // inquiring the real number for now). 
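// Worked example of the heuristic below: with 60 output slices (about 4
// blocks per SM on the assumed 15 SMs), the cascade settles on maxWarps = 8;
// a reduction dimension of 10,000 elements then needs DIVUP(10000, 32) = 313
// warps, which is clamped to 8, giving 8 * 32 = 256 threads per block and
// roughly 32 warps resident per SM.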
+ int maxWarps = 4; // better occupancy if many blocks are around + // For numSlices > 15 * 8, there are > 32 warps active per SM. + if (numSlices < 15 * 8) { + maxWarps = 8; + if (numSlices < 15 * 4) { + maxWarps = 16; + if (numSlices < 15 * 2) { + maxWarps = 32; + } + } + } + + // Scale up block size based on the reduction dimension size + long warpsInReductionSize = DIVUP(reductionSize, 32L); + int numWarps = + warpsInReductionSize > (long) maxWarps ? maxWarps : (int) warpsInReductionSize; + return dim3(numWarps * 32); +} + +inline bool getNoncontigReduceGrid(long elements, dim3& grid) { + // One output point per thread + return THC_getGridFromTiles(DIVUP(elements, THC_NONCONTIG_REDUCE_BLOCK_SIZE), grid); +} + +inline bool getContigReduceGrid(long elements, dim3& grid) { + // One output point per block + return THC_getGridFromTiles(elements, grid); +} + +// Performs a reduction out[..., 0, ...] = reduce_i(modify(in[..., i, ...])) for +// all in where i and the out's 0 are indexed at dimension `dim` +template +bool THCudaTensor_reduceDim(THCState* state, + THCudaTensor* out, + THCudaTensor* in, + const ModifyOp& modifyOp, + const ReduceOp& reduceOp, + float init, + int dim) { + long inElements = THCudaTensor_nElement(state, in); + + long reductionSize = THCudaTensor_size(state, in, dim); + long reductionStride = THCudaTensor_stride(state, in, dim); + long outElements = inElements / reductionSize; + + if (THCudaTensor_nDimension(state, out) > MAX_CUTORCH_DIMS || + THCudaTensor_nDimension(state, in) > MAX_CUTORCH_DIMS) { + return false; + } + + if (THCudaTensor_nDimension(state, in) == 0) { + // Zero-dim tensor; do nothing + return true; + } + + // Is the reduction dimension contiguous? If so, then we can use a + // shared memory reduction kernel to increase performance. + bool contigReduction = (reductionStride == 1); + + dim3 block; + dim3 grid; + int smemSize = 0; // contiguous reduction uses smem + if (contigReduction) { + if (!getContigReduceGrid(outElements, grid)) { + return false; + } + + block = getContigReduceBlock(outElements, reductionSize); + smemSize = sizeof(float) * block.x; + } else { + if (!getNoncontigReduceGrid(outElements, grid)) { + return false; + } + + block = getNoncontigReduceBlock(); + } + + // Resize out to correspond to the reduced size + THLongStorage* sizes = THCudaTensor_newSizeOf(state, in); + THLongStorage_set(sizes, dim, 1); + THCudaTensor_resize(state, out, sizes, NULL); + THLongStorage_free(sizes); + + // It is possible that the tensor dimensions are able to be collapsed, + // and thus we can reduce the actual code complexity of the copy by + // exploiting this knowledge statically, since the div/mod is the + // most expensive part of the operation, more so than memory accesses. + // For instance, when copying a non-contiguous to a contiguous tensor + // (or vice versa), the contiguous tensor can be collapsed to one + // dimension, and the loop to translate the linear index to the array + // index can be similarly collapsed. That is what this unrolling is for. 
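// A sketch of how this generic entry point can be reused by the math
// routines; the functor and wrapper names are hypothetical and only
// illustrate the ModifyOp/ReduceOp contract (modify maps one element,
// reduce folds two partial results):
//
//   struct IdOp {
//     __device__ __forceinline__ float operator()(float v) const { return v; }
//   };
//   struct AddOp {
//     __device__ __forceinline__ float operator()(float a, float b) const {
//       return a + b;
//     }
//   };
//
//   // out = sum of `in` along dimension `dim`
//   void sumAlongDim(THCState* state, THCudaTensor* out, THCudaTensor* in, int dim) {
//     if (!THCudaTensor_reduceDim(state, out, in, IdOp(), AddOp(), 0.0f, dim)) {
//       THArgCheck(false, 2, CUTORCH_DIM_WARNING);
//     }
//   }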
+#define HANDLE_CASE(TYPE, OUT, IN) \ + if (contigReduction) { \ + THCudaTensor_reduceContigDim \ + <<>>(outInfo, inInfo, reductionSize, \ + (TYPE) outElements, init, modifyOp, reduceOp); \ + } else { \ + THCudaTensor_reduceNoncontigDim \ + <<>>(outInfo, inInfo, reductionStride, reductionSize, \ + (TYPE) outElements, init, modifyOp, reduceOp); \ + } \ + +#define HANDLE_IN_CASE(TYPE, OUT, IN) \ + { \ + if (inInfo.isContiguous()) { \ + HANDLE_CASE(TYPE, OUT, -2); \ + } else { \ + switch (IN) { \ + case 1: \ + HANDLE_CASE(TYPE, OUT, 1); \ + break; \ + case 2: \ + HANDLE_CASE(TYPE, OUT, 2); \ + break; \ + case 3: \ + HANDLE_CASE(TYPE, OUT, 3); \ + break; \ + default: \ + HANDLE_CASE(TYPE, OUT, -1); \ + break; \ + } \ + } \ + } + +#define HANDLE_OUT_CASE(TYPE, OUT, IN) \ + { \ + if (outInfo.isContiguous()) { \ + HANDLE_IN_CASE(TYPE, -2, IN); \ + } else { \ + switch (OUT) { \ + case 1: \ + HANDLE_IN_CASE(TYPE, 1, IN); \ + break; \ + case 2: \ + HANDLE_IN_CASE(TYPE, 2, IN); \ + break; \ + case 3: \ + HANDLE_IN_CASE(TYPE, 3, IN); \ + break; \ + default: \ + HANDLE_IN_CASE(TYPE, -1, IN); \ + break; \ + } \ + } \ + } + + if (THC_canUse32BitIndexMath(state, out) && + THC_canUse32BitIndexMath(state, in)) { + TensorInfo outInfo(state, out); + TensorInfo inInfo(state, in, dim); + + HANDLE_OUT_CASE(unsigned int, outInfo.dims, inInfo.dims); + } else { + TensorInfo outInfo(state, out); + TensorInfo inInfo(state, in, dim); + + // For large tensors, we only compile the completely contiguous + // version and the completely generic version, to reduce + // compilation time. + if (outInfo.isContiguous() && inInfo.isContiguous()) { + HANDLE_CASE(unsigned long, -2, -2); + } else { + HANDLE_CASE(unsigned long, -1, -1); + } + } +#undef HANDLE_CASE +#undef HANDLE_B_CASE +#undef HANDLE_A_CASE + + return true; +} + +#undef THC_NONCONTIG_REDUCE_BLOCK_SIZE + +#endif // THC_REDUCE_INC diff --git a/lib/THC/THCReduceApplyUtils.cu b/lib/THC/THCReduceApplyUtils.cu new file mode 100644 index 00000000..c5c05f86 --- /dev/null +++ b/lib/THC/THCReduceApplyUtils.cu @@ -0,0 +1,123 @@ +#include "THCReduceApplyUtils.cuh" + +#include +#include + +// Maximum size per grid dimension that we assume (compute capability >= 2.0) +#define MAX_GRID_SIZE 65535L + +bool THC_canUse32BitIndexMath(THCState* state, THCudaTensor* t) { + long elements = THCudaTensor_nElement(state, t); + if (elements >= UINT_MAX) { + return false; + } + + long offset = 0; + long linearId = elements - 1; + + for (int i = THCudaTensor_nDimension(state, t) - 1; i >= 0; --i) { + long curDimIndex = linearId % THCudaTensor_size(state, t, i); + long curDimOffset = curDimIndex * THCudaTensor_stride(state, t, i); + offset += curDimOffset; + linearId /= THCudaTensor_size(state, t, i); + } + + if (offset >= UINT_MAX) { + return false; + } + + return true; +} + +bool THC_getGridFromTiles(long gridTiles, dim3& grid) { + if (gridTiles > MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE) { + return false; + } + + long gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + long gridY = 1; + long gridZ = 1; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = DIVUP(gridTiles, MAX_GRID_SIZE); + gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + + if (gridTiles > MAX_GRID_SIZE) { + gridTiles = DIVUP(gridTiles, MAX_GRID_SIZE); + gridZ = gridTiles > MAX_GRID_SIZE ? 
MAX_GRID_SIZE : gridTiles; + } + } + + grid = dim3(gridX, gridY, gridZ); + return true; +} + +namespace { + +struct SizeAndStride { + long size; + long stride; +}; + +int compareSizeAndStride(const void* a, const void* b) { + const SizeAndStride* aS = (const SizeAndStride*) a; + const SizeAndStride* bS = (const SizeAndStride*) b; + + return aS->stride < bS->stride; +} + +} + +bool THC_overlappingIndices(THCState* state, THCudaTensor* t) { + // In this function, we don't care about permutations of the + // size/stride arrays (transpositions). + // We order the size/stride arrays by stride, skipping dimensions of + // size 1. Strides of dimensions of size 1 don't matter, since there + // is only one addressing point in them. + // In this reordered view, the tensor is contiguous if + // stride[dim] == size[dim + 1] * stride[dim + 1] for all `dim`. + // The tensor has holes if + // stride[dim] > size[dim + 1] * stride[dim + 1] for one or more + // `dim`. + // The tensor has overlaps if + // stride[dim] < size[dim + 1] * stride[dim + 1] for one or more + // `dim`, or the innermost stride is 0. + + // Extract size/stride arrays; only consider size >1 dims. + SizeAndStride info[MAX_CUTORCH_DIMS]; + + int dims = THCudaTensor_nDimension(state, t); + int nonSize1Dims = 0; + for (int i = 0; i < dims; ++i) { + long size = THCudaTensor_size(state, t, i); + if (size > 1) { + info[nonSize1Dims].size = size; + info[nonSize1Dims].stride = THCudaTensor_stride(state, t, i); + ++nonSize1Dims; + } + } + + if (nonSize1Dims == 0) { + // no overlap + return false; + } + + // Ascending order (innermost dimension in sorted view is at [0]) + qsort(info, nonSize1Dims, sizeof(SizeAndStride), compareSizeAndStride); + + // Base case: innermost dimension must have stride >= 1 + if (info[nonSize1Dims - 1].stride < 1) { + return true; + } + + // Subsequent dimensions, if any + for (int i = nonSize1Dims - 2; i >= 0; --i) { + if (info[i].stride < info[i + 1].size * info[i + 1].stride) { + // There are overlaps + return true; + } + } + + // Tensor has holes or is contiguous + return false; +} diff --git a/lib/THC/THCReduceApplyUtils.cuh b/lib/THC/THCReduceApplyUtils.cuh new file mode 100644 index 00000000..70304191 --- /dev/null +++ b/lib/THC/THCReduceApplyUtils.cuh @@ -0,0 +1,236 @@ +#ifndef THC_REDUCE_APPLY_UTILS_INC +#define THC_REDUCE_APPLY_UTILS_INC + +#include +#include +#include "THGeneral.h" +#include "THCGeneral.h" +#include "THCTensor.h" + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +// Maximum number of dimensions allowed for cutorch +#define MAX_CUTORCH_DIMS 25 + +// Warning string for tensor arguments that are too large or have too +// many dimensions +#define CUTORCH_STR(X) #X +#define CUTORCH_DIM_WARNING "tensor too large or too many (>" \ + CUTORCH_STR(MAX_CUTORCH_DIMS) ") dimensions" + +// Enum that indicates whether tensor arguments are read/write or +// read-only +enum TensorArgType { ReadWrite, ReadOnly }; + +// Copy operator for the pointwise apply kernel +template +struct CopyOp { + __device__ __forceinline__ void operator()(T* dst, T* src) { +#if __CUDA_ARCH__ >= 350 + *dst = __ldg(src); +#else + *dst = *src; +#endif + } +}; + +// CUDA kernel argument that defines tensor layout +template +struct TensorInfo { + // Extracts size/stride information for the kernel. + // Successive dimensions can be collapsed if the size/strides match + // up and thus there are no holes between the dimensions. This is used + // to reduce the complexity of the problem. 
+ // The optional `reduceDim` indicates a reduction dimension for the + // given tensor, so that the output size for this dimension will be 1. + TensorInfo(THCState* state, THCudaTensor* t, int reduceDim = -1); + + // Contiguous tensors of more than one dimension are collapsed down + // to one tensor + __host__ __device__ inline bool isContiguous() const { + return (dims == 1 && strides[0] == 1); + } + + float* data; + IndexType sizes[MAX_CUTORCH_DIMS]; + IndexType strides[MAX_CUTORCH_DIMS]; + int dims; +}; + +template +TensorInfo::TensorInfo(THCState* state, + THCudaTensor* t, + int reduceDim) + : data(NULL), dims(0) { + int origDims = THCudaTensor_nDimension(state, t); + assert(origDims <= MAX_CUTORCH_DIMS); + assert(reduceDim < origDims); + + data = THCudaTensor_data(state, t); + + // Count the number of successive dimensions that can be collapsed, from + // innermost to outermost. + int numCollapsed = 0; + + // Find the innermost dimension not of size 1, since dimensions of size 1 are + // collapsible. + int firstNonOneDim = -1; + + for (int i = origDims - 1; i >= 0; --i) { + if (THCudaTensor_size(state, t, i) != 1 && i != reduceDim) { + firstNonOneDim = i; + break; + } + } + + // Special case: if all dimensions are of size 1, then this is a + // single-point tensor that we still have to operate on. Reduce to a + // single point. + if (firstNonOneDim == -1) { + dims = 1; + sizes[0] = 1; + strides[0] = 1; + return; + } + + // Skip the leading size 1 dims + numCollapsed += origDims - 1 - firstNonOneDim; + + // Now, to determine the other collapsible dims. These are the size/strides + // of the previous inner non-collapsible dim we encounter. + long sizeInner = THCudaTensor_size(state, t, firstNonOneDim); + long strideInner = THCudaTensor_stride(state, t, firstNonOneDim); + + for (int i = firstNonOneDim - 1; i >= 0; --i) { + long sizeOuter = (i == reduceDim) ? 1 : THCudaTensor_size(state, t, i); + long strideOuter = THCudaTensor_stride(state, t, i); + + // The next outermost dimension can be skipped if size 1 + if (sizeOuter == 1) { + ++numCollapsed; + continue; + } + + // If the next outermost dimension is contiguous with the + // previous non-collapsed one, collapse it + if (strideOuter == strideInner * sizeInner) { + ++numCollapsed; + + // This is the run of collapsed dimensions' size + sizeInner = sizeInner * sizeOuter; + continue; + } + + // Otherwise, this new outer dimension at `i` cannot be collapsed + // and is different from the previous. + sizeInner = sizeOuter; + strideInner = strideOuter; + } + + assert(numCollapsed < origDims); + dims = origDims - numCollapsed; + + // Determine the sizes of the collapsed dimensions. + int collapsedIndex = origDims - numCollapsed - 1; + sizes[collapsedIndex] = THCudaTensor_size(state, t, firstNonOneDim); + strides[collapsedIndex] = THCudaTensor_stride(state, t, firstNonOneDim); + + for (int i = firstNonOneDim - 1; i >= 0; --i) { + long sizeOuter = (i == reduceDim) ? 1 : THCudaTensor_size(state, t, i); + long strideOuter = THCudaTensor_stride(state, t, i); + + if (sizeOuter == 1) { + // skip + continue; + } + + if (strideOuter == sizes[collapsedIndex] * strides[collapsedIndex]) { + // collapse + sizes[collapsedIndex] *= sizeOuter; + continue; + } + + // Otherwise, strides don't match; dimension `i` is not collapsible. 
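// (Worked example of the collapsing performed here: a contiguous 2x1x3
// tensor with strides {3, 3, 1} collapses to a single dimension of size 6
// and stride 1, since the size-1 middle dimension is skipped and the outer
// stride 3 equals sizeInner * strideInner = 3 * 1; by contrast, a
// transposed 3x2 view with strides {1, 3} cannot be collapsed, so dims
// stays at 2.)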
+ --collapsedIndex; + assert(collapsedIndex >= 0); + sizes[collapsedIndex] = sizeOuter; + strides[collapsedIndex] = strideOuter; + } + + // We must have filled all the dimensions we're looking for + assert(collapsedIndex == 0); +} + +// Translate a linear index for the apply to a float* offset; +// specialized on `Dims` to reduce nvcc compilation time +template +struct IndexToOffset { + static __host__ __device__ IndexType get( + IndexType linearId, + const TensorInfo& info) { + IndexType offset = 0; + + // Use static dims + for (int i = Dims - 1; i >= 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + + if (i > 0) { + linearId /= info.sizes[i]; + } + } + + return offset; + } +}; + +template +struct IndexToOffset { + static __forceinline__ __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + return linearId; + } +}; + +template +struct IndexToOffset { + static __forceinline__ __host__ __device__ IndexType + get(IndexType linearId, const TensorInfo& info) { + IndexType offset = 0; + + // Use dynamic dims + for (int i = info.dims - 1; i >= 0; --i) { + IndexType curDimIndex = linearId % info.sizes[i]; + IndexType curDimOffset = curDimIndex * info.strides[i]; + offset += curDimOffset; + + linearId /= info.sizes[i]; + } + + return offset; + } +}; + +template +__device__ __forceinline__ IndexType getLinearBlockId() { + return blockIdx.z * gridDim.y * gridDim.x + + blockIdx.y * gridDim.x + + blockIdx.x; +} + +// Returns true if all linear ID -> offset math can be performed using 32 bit +// unsigned math, which is faster than 64 bit math +bool THC_canUse32BitIndexMath(THCState* state, THCudaTensor* t); + +// Produces a grid with at least one point per tile +bool THC_getGridFromTiles(long gridTiles, dim3& grid); + +// Determines if the given tensor has overlapping data points (i.e., +// is there more than one index into the tensor that references the +// same piece of data)? +bool THC_overlappingIndices(THCState* state, THCudaTensor* t); + +#endif // THC_REDUCE_APPLY_UTILS_INC diff --git a/lib/THC/THCTensorCopy.cu b/lib/THC/THCTensorCopy.cu index 71b28e8d..55b7be0b 100644 --- a/lib/THC/THCTensorCopy.cu +++ b/lib/THC/THCTensorCopy.cu @@ -1,254 +1,18 @@ -#include "THGeneral.h" -#include "THCGeneral.h" -#include "THCTensor.h" -#include +#include "THCApply.cuh" -#ifndef DIVUP -#define DIVUP(x, y) (((x) + (y) - 1) / (y)) -#endif - -// backward-compatible LDG -#if __CUDA_ARCH__ >= 350 -#define LDG(x) (__ldg(x)) -#else -#define LDG(x) (*(x)) -#endif - -// Maximum elements per thread that we will copy -#define ELEMENTS_PER_THREAD 8L - -// Threads per thread block -#define THREADS_PER_BLOCK 32 * 4 - -// Maximum size per grid dimension that we assume (compute capability >= 2.0) -#define MAX_GRID_SIZE 65535L - -// Maximum number of dimensions allowed for cutorch -#define MAX_DIMS 25 - -template -struct TensorInfo { - float* data; - IndexType sizes[MAX_DIMS]; - IndexType strides[MAX_DIMS]; - int dims; -}; - -// This function extracts size/stride information for the kernel. -// Successive dimensions can be collapsed if the size/strides match -// up and thus there are no holes between the dimensions. This is used -// to reduce the complexity of the problem. 
-template -TensorInfo -THCudaTensor_computeTensorInfo(THCState *state, THCudaTensor* t) { - int dims = THCudaTensor_nDimension(state, t); - assert(dims <= MAX_DIMS); - - TensorInfo info; - info.data = THCudaTensor_data(state, t); - - // Count the number of successive dimensions that can be collapsed, from - // innermost to outermost. - int numCollapsed = 0; - - // Find the innermost dimension not of size 1, since dimensions of size 1 are - // collapsible. - int firstNonOneDim = -1; - - for (int i = dims - 1; i >= 0; --i) { - if (THCudaTensor_size(state, t, i) != 1) { - firstNonOneDim = i; - break; - } - } - - // We guarantee that we are never called with only dimensions of size 1. - assert(firstNonOneDim >= 0); - - // Skip the leading size 1 dims - numCollapsed += dims - 1 - firstNonOneDim; - - // Now, to determine the other collapsible dims. These are the size/strides - // of the previous inner non-collapsible dim we encounter. - long sizeInner = THCudaTensor_size(state, t, firstNonOneDim); - long strideInner = THCudaTensor_stride(state, t, firstNonOneDim); - - for (int i = firstNonOneDim - 1; i >= 0; --i) { - long sizeOuter = THCudaTensor_size(state, t, i); - long strideOuter = THCudaTensor_stride(state, t, i); - - // The next outermost dimension can be skipped if size 1 - if (sizeOuter == 1) { - ++numCollapsed; - continue; - } - - // If the next outermost dimension is contiguous with the - // previous non-collapsed one, collapse it - if (strideOuter == strideInner * sizeInner) { - ++numCollapsed; - - // This is the run of collapsed dimensions' size - sizeInner = sizeInner * sizeOuter; - continue; - } - - // Otherwise, this new outer dimension at `i` cannot be collapsed - // and is different from the previous. - sizeInner = sizeOuter; - strideInner = strideOuter; - } - - assert(numCollapsed < dims); - info.dims = dims - numCollapsed; - - // Determine the sizes of the collapsed dimensions. - int collapsedIndex = dims - numCollapsed - 1; - info.sizes[collapsedIndex] = THCudaTensor_size(state, t, firstNonOneDim); - info.strides[collapsedIndex] = THCudaTensor_stride(state, t, firstNonOneDim); - - for (int i = firstNonOneDim - 1; i >= 0; --i) { - long sizeOuter = THCudaTensor_size(state, t, i); - long strideOuter = THCudaTensor_stride(state, t, i); - - if (sizeOuter == 1) { - // skip - continue; - } - - if (strideOuter == - info.sizes[collapsedIndex] * info.strides[collapsedIndex]) { - // collapse - info.sizes[collapsedIndex] *= sizeOuter; - continue; - } - - // Otherwise, strides don't match; dimension `i` is not collapsible. - --collapsedIndex; - assert(collapsedIndex >= 0); - info.sizes[collapsedIndex] = sizeOuter; - info.strides[collapsedIndex] = strideOuter; - } - - // We must have filled all the dimensions we're looking for - assert(collapsedIndex == 0); - - // Fill out the remainder dims for sanity. 
- for (int i = dims - numCollapsed; i < MAX_DIMS; ++i) { - info.sizes[i] = 1; - info.strides[i] = info.strides[dims - numCollapsed - 1] * - info.sizes[dims - numCollapsed - 1]; - } - - return info; -} - -// Returns true if all linear ID -> offset math can be performed using 32 bit -// unsigned math -bool -canUse32BitCopyMath(THCState *state, THCudaTensor* t) { - long elements = THCudaTensor_nElement(state, t); - if (elements >= UINT_MAX) { - return false; - } - - long offset = 0; - long linearId = elements - 1; - - for (int i = THCudaTensor_nDimension(state, t) - 1; i >= 0; --i) { - long curDimIndex = linearId % THCudaTensor_size(state, t, i); - long curDimOffset = curDimIndex * THCudaTensor_stride(state, t, i); - offset += curDimOffset; - linearId /= THCudaTensor_size(state, t, i); - } - - if (offset >= UINT_MAX) { - return false; - } - - return true; -} - -// Translate a linear ID for the copy to a float offset -template -__forceinline__ __device__ IndexType -linearIdToOffset(IndexType linearId, const TensorInfo& info) { - IndexType offset = 0; - - if (Dims == -1) { - // Use dynamic dims - for (int i = info.dims - 1; i >= 0; --i) { - IndexType curDimIndex = linearId % info.sizes[i]; - IndexType curDimOffset = curDimIndex * info.strides[i]; - offset += curDimOffset; - - linearId /= info.sizes[i]; - } - } else { - // Use static dims - for (int i = Dims - 1; i >= 0; --i) { - IndexType curDimIndex = linearId % info.sizes[i]; - IndexType curDimOffset = curDimIndex * info.strides[i]; - offset += curDimOffset; - - if (i > 0) { - linearId /= info.sizes[i]; - } - } - } - - return offset; -} - -// Both `src` and `dst` have the same number of total elements, which are copied -// based on a linear id. -template -#if __CUDA_ARCH__ >= 350 -__launch_bounds__(32 * 4, 16) -#endif -__global__ void -THCudaTensor_kernel_copy(TensorInfo dst, - TensorInfo src, - IndexType totalElements) { - const IndexType linearBlockId = - blockIdx.z * gridDim.y * gridDim.x + - blockIdx.y * gridDim.x + - blockIdx.x; - - const IndexType startLinearId = - linearBlockId * THREADS_PER_BLOCK * ELEMENTS_PER_THREAD; - - IndexType endLinearId = - (linearBlockId + 1) * THREADS_PER_BLOCK * ELEMENTS_PER_THREAD; - endLinearId = endLinearId < totalElements ? 
endLinearId : totalElements; - - for (IndexType linearId = startLinearId + threadIdx.x; - linearId < endLinearId; - linearId += THREADS_PER_BLOCK) { - // Convert `linearId` into an offset of `src` - const IndexType srcOffset = - linearIdToOffset(linearId, src); - - // Convert `linearId` into an offset of `dst` - const IndexType dstOffset = - linearIdToOffset(linearId, dst); - - dst.data[dstOffset] = LDG(&src.data[srcOffset]); - } +static inline int curGPU() { + int curDev; + THCudaCheck(cudaGetDevice(&curDev)); + return curDev; } THC_API void -THCudaTensor_copy(THCState *state, THCudaTensor* dst, THCudaTensor* src) { +THCudaTensor_copy(THCState* state, THCudaTensor* dst, THCudaTensor* src) { long totalElements = THCudaTensor_nElement(state, dst); THArgCheck(totalElements == THCudaTensor_nElement(state, src), 2, "sizes do not match"); - THArgCheck(THCudaTensor_nDimension(state, dst) <= MAX_DIMS, 2, - "Copy only supported for <= 25 dimensions"); - THArgCheck(THCudaTensor_nDimension(state, src) <= MAX_DIMS, 3, - "Copy only supported for <= 25 dimensions"); - if (THCudaTensor_nDimension(state, dst) == 0) { // Zero-dim tensor; copy nothing return; @@ -260,9 +24,9 @@ THCudaTensor_copy(THCState *state, THCudaTensor* dst, THCudaTensor* src) { // -FIXME: if both tensors have matching size and stride arrays, and no // holes within (in other words, there is some permutation that can be applied // to the size/strides such that the resulting tensor is contiguous). - bool memcpyEligible = - (THCudaTensor_isContiguous(state, dst) && THCudaTensor_isContiguous(state, src)) || - (totalElements == 1); + bool srcContig = THCudaTensor_isContiguous(state, src); + bool dstContig = THCudaTensor_isContiguous(state, dst); + bool memcpyEligible = (srcContig && dstContig) || (totalElements == 1); if (memcpyEligible) { THCudaCheck(cudaMemcpyAsync(THCudaTensor_data(state, dst), @@ -270,110 +34,49 @@ THCudaTensor_copy(THCState *state, THCudaTensor* dst, THCudaTensor* src) { totalElements * sizeof(float), cudaMemcpyDeviceToDevice)); } else { - // We always work with a THREADS_PER_BLOCK-sized thread block, - // and assume a max sized grid dimension of MAX_GRID_SIZE. - // Each thread will process up to ELEMENTS_PER_THREAD elements. - const dim3 block(THREADS_PER_BLOCK); - - long gridTiles = DIVUP(totalElements, block.x * ELEMENTS_PER_THREAD); - THArgCheck(gridTiles <= MAX_GRID_SIZE * MAX_GRID_SIZE * MAX_GRID_SIZE, 2, - "tensor too large"); - - long gridX = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; - long gridY = 1; - long gridZ = 1; + int oldDev = curGPU(); + int srcDev = THCudaTensor_getDevice(state, src); + int dstDev = THCudaTensor_getDevice(state, dst); - if (gridTiles > MAX_GRID_SIZE) { - gridTiles = DIVUP(gridTiles, MAX_GRID_SIZE); - gridY = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; - - if (gridTiles > MAX_GRID_SIZE) { - gridTiles = DIVUP(gridTiles, MAX_GRID_SIZE); - gridZ = gridTiles > MAX_GRID_SIZE ? MAX_GRID_SIZE : gridTiles; + if (srcDev == dstDev) { + if (oldDev != srcDev) { + THCudaCheck(cudaSetDevice(srcDev)); } - } - - dim3 grid(gridX, gridY, gridZ); - // It is possible that the tensor dimensions are able to be collapsed, - // and thus we can reduce the actual code complexity of the copy by - // exploiting this knowledge statically, since the div/mod is the - // most expensive part of the operation, more so than memory accesses. 
- // For instance, when copying a non-contiguous to a contiguous tensor - // (or vice versa), the contiguous tensor can be collapsed to one - // dimension, and the loop to translate the linear index to the array - // index can be similarly collapsed. That is what this unrolling is for. -#define HANDLE_CASE(TYPE, DST, SRC) \ - THCudaTensor_kernel_copy \ - <<>>(dstInfo, srcInfo, (TYPE) totalElements); \ - -#define HANDLE_SRC_CASE(TYPE, DST, SRC) \ - { \ - switch (SRC) { \ - case 1: \ - HANDLE_CASE(TYPE, DST, 1); \ - break; \ - case 2: \ - HANDLE_CASE(TYPE, DST, 2); \ - break; \ - case 3: \ - HANDLE_CASE(TYPE, DST, 3); \ - break; \ - case 4: \ - HANDLE_CASE(TYPE, DST, 4); \ - break; \ - case 5: \ - HANDLE_CASE(TYPE, DST, 5); \ - break; \ - default: \ - HANDLE_CASE(TYPE, -1, -1); \ - break; \ - } \ + bool succ = + THCudaTensor_pointwiseApply2(state, dst, src, CopyOp()); + THArgCheck(succ, 2, CUTORCH_DIM_WARNING); + } else { // multi-gpu + // empirically, running the kernel on the device that holds the + // non-contiguous tensor is faster by 5-10x + int copyDev = dstContig ? srcDev : dstDev; + int remoteDev = dstContig ? dstDev : srcDev; + + // synchronize remote device before copy + cudaEvent_t dataReady; + THCudaCheck(cudaSetDevice(remoteDev)); + THCudaCheck(cudaEventCreate(&dataReady)); + THCudaCheck(cudaEventRecord(dataReady)); + THCudaCheck(cudaSetDevice(copyDev)); + THCudaCheck(cudaStreamWaitEvent(NULL, dataReady, 0)); + THCudaCheck(cudaEventDestroy(dataReady)); + + bool succ = + THCudaTensor_pointwiseApply2(state, dst, src, CopyOp()); + THArgCheck(succ, 2, CUTORCH_DIM_WARNING); + + // synchronize remote device after copy + cudaEvent_t doneCopying; + THCudaCheck(cudaEventCreate(&doneCopying)); + THCudaCheck(cudaEventRecord(doneCopying)); + THCudaCheck(cudaSetDevice(remoteDev)); + THCudaCheck(cudaStreamWaitEvent(NULL, doneCopying, 0)); + THCudaCheck(cudaEventDestroy(doneCopying)); } -#define HANDLE_DST_CASE(TYPE, DST, SRC) \ - case DST: \ - HANDLE_SRC_CASE(TYPE, DST, SRC); \ - break; - - // Can we use 32-bit integer math in the kernel (the linear ID for the copy - // and the resulting non-linear offset is all computable using 32-bit math?) - // We also use unsigned index math in the kernel, as signed div/mod has - // additional overhead. 
- if (canUse32BitCopyMath(state, src) && canUse32BitCopyMath(state, dst)) { - TensorInfo dstInfo = - THCudaTensor_computeTensorInfo(state, dst); - TensorInfo srcInfo = - THCudaTensor_computeTensorInfo(state, src); - - switch (dstInfo.dims) { - HANDLE_DST_CASE(unsigned int, 1, srcInfo.dims); - HANDLE_DST_CASE(unsigned int, 2, srcInfo.dims); - HANDLE_DST_CASE(unsigned int, 3, srcInfo.dims); - HANDLE_DST_CASE(unsigned int, 4, srcInfo.dims); - HANDLE_DST_CASE(unsigned int, 5, srcInfo.dims); - default: - HANDLE_DST_CASE(unsigned int, -1, srcInfo.dims); - } - } else { - TensorInfo dstInfo = - THCudaTensor_computeTensorInfo(state, dst); - TensorInfo srcInfo = - THCudaTensor_computeTensorInfo(state, src); - - switch (dstInfo.dims) { - HANDLE_DST_CASE(unsigned long, 1, srcInfo.dims); - HANDLE_DST_CASE(unsigned long, 2, srcInfo.dims); - HANDLE_DST_CASE(unsigned long, 3, srcInfo.dims); - HANDLE_DST_CASE(unsigned long, 4, srcInfo.dims); - HANDLE_DST_CASE(unsigned long, 5, srcInfo.dims); - default: - HANDLE_DST_CASE(unsigned long, -1, srcInfo.dims); - } + if (curGPU() != oldDev) { + THCudaCheck(cudaSetDevice(oldDev)); } -#undef HANDLE_CASE -#undef HANDLE_SRC_CASE -#undef HANDLE_DST_CASE } cudaError errcode = cudaGetLastError(); @@ -382,9 +85,3 @@ THCudaTensor_copy(THCState *state, THCudaTensor* dst, THCudaTensor* src) { } } -#undef DIVUP -#undef LDG -#undef ELEMENTS_PER_THREAD -#undef THREADS_PER_BLOCK -#undef MAX_GRID_SIZE -#undef MAX_DIMS diff --git a/lib/THC/THCTensorIndex.cu b/lib/THC/THCTensorIndex.cu new file mode 100644 index 00000000..9dd0d2d2 --- /dev/null +++ b/lib/THC/THCTensorIndex.cu @@ -0,0 +1,240 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +__global__ void THCudaTensor_kernel_indexFill( + float *tensor, long* stride, float *index, long src_nDim, + int dim, long idx_size, long tensor_size, long size_dim, float val +) +{ + int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; + + long flat_size = tensor_size / idx_size; + + if (thread_idx < flat_size) + { + long coeff = 0; + for (int i=0; i dim) + { + coeff = leftover / stride[d]; + leftover -= coeff * stride[d]; + srcIdx += coeff * stride[d]; + } + } + tensor[srcIdx + (int)((index[i])-1)*stride[dim]] = val; + } + } +} + +__global__ void THCudaTensor_kernel_indexCopy( + float *res, float *src, long* res_stride, float *index, + long res_nDim, int dim, long idx_size, long src_size, long size_dim +) +{ + int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; + + long flat_size = src_size / idx_size; + + if (thread_idx < flat_size) + { + long coeff = 0; + for (int i=0; i dim) + { + coeff = leftover / res_stride[d]; + leftover -= coeff * res_stride[d]; + targetIdx += coeff * res_stride[d]; + resIdx += coeff * res_stride[d]; + } + } + res[resIdx + ((int)(index[i])-1)*res_stride[dim]] = src[targetIdx + i*res_stride[dim]]; + } + } +} + +void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THLongTensor *indices, THCudaTensor *src) +{ + THCudaTensor *indices_; + long *stride_; + long nIndex = indices->size[0]; + long nRes; + + THArgCheck(indices->nDimension == 1, 3, "expecting vector of indices"); + THArgCheck(dim < src->nDimension, 4, "Indexing dim is out of bounds"); + 
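  /* Note (added for clarity): indexCopy writes whole slices of `src` into
     `res_` along dimension `dim`, so the number of indices must match
     src->size[dim]; the checks below enforce this before any copying. */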
THArgCheck(src->nDimension > 0, 2, "Source tensor is empty"); + THArgCheck(nIndex == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices"); + + src = THCudaTensor_newContiguous(state, src); + indices_ = THCudaTensor_newWithSize1d(state, nIndex); + THCudaTensor_copyLong(state, indices_, indices); + + nRes = THCudaTensor_nElement(state, res_); + dim3 nthreads(16, 16); + dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); + + THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long))); + THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice)); + + THCudaTensor_kernel_indexCopy<<>>( + THCudaTensor_data(state, res_), THCudaTensor_data(state, src), + stride_, THCudaTensor_data(state, indices_), + res_->nDimension, dim, nIndex, + THCudaTensor_nElement(state, src), res_->size[dim] + ); + + THCudaCheck(cudaFree(stride_)); + THCudaTensor_free(state, indices_); + THCudaTensor_free(state, src); +} + + +void THCudaTensor_indexFill(THCState *state, THCudaTensor *res_, int dim, THLongTensor *indices, float val) +{ + THCudaTensor *indices_; + long *stride_; + long nIndex = indices->size[0]; + long nRes; + + THArgCheck(indices->nDimension == 1, 3, "Index is supposed to be a vector"); + THArgCheck(dim < res_->nDimension,4,"Indexing dim is out of bounds"); + THArgCheck(res_->nDimension > 0, 2, "Source tensor is empty"); + + indices_ = THCudaTensor_newWithSize1d(state, nIndex); + THCudaTensor_copyLong(state, indices_, indices); + + nRes = THCudaTensor_nElement(state, res_) / res_->size[dim] * nIndex; + + + dim3 nthreads(16, 16); + dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); + + THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long))); + THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice)); + + THCudaTensor_kernel_indexFill<<>>( + THCudaTensor_data(state, res_), stride_, THCudaTensor_data(state, indices_), + res_->nDimension, dim, nIndex, nRes, res_->size[dim], val + ); + + THCudaCheck(cudaFree(stride_)); + THCudaTensor_free(state, indices_); +} + +__global__ void THCudaTensor_kernel_indexSelect( + float *tensor, float *src, long* src_stride, float *index, + long src_nDim, int dim, long idx_size, long tensor_size, long size_dim +) +{ + int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; + + long flat_size = tensor_size / idx_size; + + if (thread_idx < flat_size) + { + long coeff = 0; + for (int i=0; i dim) + { + coeff = leftover / src_stride[d]; + leftover -= coeff * src_stride[d]; + targetIdx += coeff * src_stride[d]; + srcIdx += coeff * src_stride[d]; + } + } + tensor[targetIdx + i*src_stride[dim]] = src[srcIdx + ((int)(index[i])-1)*src_stride[dim]]; + } + } +} + + +void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor *src, int dim, THLongTensor *indices) +{ + THLongStorage *newSize; + THCudaTensor *indices_; + long *stride_; + long nIndex = indices->size[0]; + long nRes; + + THArgCheck(indices->nDimension == 1, 3, "expecting vector of indices"); + THArgCheck(dim < src->nDimension, 4, "Indexing dim is out of bounds"); + THArgCheck(src->nDimension > 0, 2, "Source tensor is empty"); + + newSize = THLongStorage_newWithSize(src->nDimension); + THLongStorage_rawCopy(newSize, src->size); + newSize->data[dim] = nIndex; + THCudaTensor_resize(state, res_, newSize, NULL); + THLongStorage_free(newSize); + + indices_ = THCudaTensor_newWithSize1d(state, nIndex); + 
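  /* Note (added for clarity): the CPU long indices are staged into a float
     CUDA tensor so the kernel can read them on the device; the values follow
     Lua's 1-based convention, hence the (index[i] - 1) adjustment inside the
     kernel. */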
THCudaTensor_copyLong(state, indices_, indices); + + nRes = THCudaTensor_nElement(state, res_); + dim3 nthreads(16, 16); + dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); + + THCudaCheck(cudaMalloc((void**)&stride_, src->nDimension * sizeof(long))); + THCudaCheck(cudaMemcpy(stride_, src->stride, src->nDimension * sizeof(long), cudaMemcpyHostToDevice)); + + THCudaTensor_kernel_indexSelect<<>>( + THCudaTensor_data(state, res_), THCudaTensor_data(state, src), + stride_, THCudaTensor_data(state, indices_), + src->nDimension, dim, indices->size[0], nRes, src->size[dim] + ); + + THCudaCheck(cudaFree(stride_)); + THCudaTensor_free(state, indices_); +} diff --git a/lib/THC/THCTensorMasked.cu b/lib/THC/THCTensorMasked.cu new file mode 100644 index 00000000..2fa4891a --- /dev/null +++ b/lib/THC/THCTensorMasked.cu @@ -0,0 +1,125 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +struct TensorMaskedFillOp { + TensorMaskedFillOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* t, float* mask) { + // Really mask should be `0` or `1` but we can't propagate errors here. + const float maskVal = *mask; + if (maskVal != 0.0f) { + *t = value; + } + } + + float value; +}; + +void THCudaTensor_maskedFill(THCState* state, + THCudaTensor *tensor, THCudaTensor *mask, float value) +{ + THArgCheck(THCudaTensor_nElement(state, tensor) == + THCudaTensor_nElement(state, mask), + 2, "sizes do not match"); + + if (!THCudaTensor_pointwiseApply2(state, tensor, mask, TensorMaskedFillOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +void THCudaTensor_maskedCopy(THCState* state, + THCudaTensor *tensor, THCudaTensor *mask, THCudaTensor *src) +{ + THError("maskedCopy is not yet implemented for CUDA"); +} + +struct TensorMaskedSelectOp { + TensorMaskedSelectOp(float* t) : out(t) {} + __device__ __forceinline__ void operator()(float* mask, float* maskPrefixSum, float* in) { + // Really mask should be `0` or `1` but we can't propagate errors here. 
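    // Note (added for clarity): maskPrefixSum holds the exclusive prefix sum
    // of the mask, so for a selected element it equals the number of selected
    // elements that precede it, that is, its destination offset in `out`.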
+ if (*mask != 0.0f) { + out[(int) *maskPrefixSum] = *in; + } + } + + float* out; +}; + +void THCudaTensor_maskedSelect(THCState* state, + THCudaTensor *tensor, THCudaTensor *src, THCudaTensor *mask) +{ + THArgCheck(THCudaTensor_nElement(state, mask) == THCudaTensor_nElement(state, src), + 2, "sizes do not match"); + + // Determine our output size + THCudaTensor* contigMask = THCudaTensor_newContiguous(state, mask); + int totalElements = (int) THCudaTensor_sumall(state, contigMask); + THCudaTensor_resize1d(state, tensor, totalElements); + + // Use a prefix sum to determine the output locations of the masked elements + THCudaTensor* maskPrefixSum = THCudaTensor_new(state); + THCudaTensor_resizeAs(state, maskPrefixSum, mask); + + thrust::device_ptr + maskData(THCudaTensor_data(state, contigMask)); + thrust::device_ptr + maskPrefixSumData(THCudaTensor_data(state, maskPrefixSum)); + thrust::exclusive_scan(maskData, + maskData + THCudaTensor_nElement(state, contigMask), + maskPrefixSumData); + + // Then copy over the masked elements at their desired output index + bool status = THCudaTensor_pointwiseApply3( + state, contigMask, maskPrefixSum, + src, TensorMaskedSelectOp(THCudaTensor_data(state, tensor))); + + THCudaTensor_free(state, contigMask); + THCudaTensor_free(state, maskPrefixSum); + + if (!status) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +void THCudaTensor_maskedFillByte(THCState* state, THCudaTensor *tensor, THByteTensor *mask, float value) +{ + THLongStorage* maskSize = THByteTensor_newSizeOf(mask); + THCudaTensor* maskCuda = THCudaTensor_newWithSize(state, maskSize, NULL); + THLongStorage_free(maskSize); + THCudaTensor_copyByte(state, maskCuda, mask); + THCudaTensor_maskedFill(state, tensor, maskCuda, value); + THCudaTensor_free(state, maskCuda); +} + +void THCudaTensor_maskedCopyByte(THCState* state, THCudaTensor *tensor, THByteTensor *mask, THCudaTensor *src) +{ + THError("maskedCopyByte is not yet implemented for CUDA"); +} + +void THCudaTensor_maskedSelectByte(THCState* state, THCudaTensor *tensor, THCudaTensor *src, THByteTensor *mask) +{ + THLongStorage* maskSize = THByteTensor_newSizeOf(mask); + THCudaTensor* maskCuda = THCudaTensor_newWithSize(state, maskSize, NULL); + THLongStorage_free(maskSize); + THCudaTensor_copyByte(state, maskCuda, mask); + THCudaTensor_maskedSelect(state, tensor, src, maskCuda); + THCudaTensor_free(state, maskCuda); +} diff --git a/lib/THC/THCTensorMath.cu b/lib/THC/THCTensorMath.cu index ac9b40a6..fafd0651 100644 --- a/lib/THC/THCTensorMath.cu +++ b/lib/THC/THCTensorMath.cu @@ -3,6 +3,8 @@ #include "THCBlas.h" #include "THCTensorCopy.h" #include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" #include #include @@ -10,27 +12,39 @@ #include #include -#define NB_THREADS_PER_BLOCK 256 - #ifndef DIVUP #define DIVUP(x, y) (((x) + (y) - 1) / (y)) #endif -void THCudaTensor_fill(THCState *state, THCudaTensor *self_, float value) -{ - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); +struct TensorFillOp { + TensorFillOp(float v) : val(v) {} + __device__ __forceinline__ void operator()(float* v) { *v = val; } + + const float val; +}; - thrust::fill(self_data, self_data+THCudaTensor_nElement(state, self), value); +void THCudaTensor_fill(THCState* state, THCudaTensor *self_, float value) +{ + if (!THCudaTensor_pointwiseApply1(state, self_, TensorFillOp(value))) { + THArgCheck(false, 1, CUTORCH_DIM_WARNING); + } - 
THCudaTensor_freeCopyTo(state, self, self_); + THCudaCheck(cudaGetLastError()); } void THCudaTensor_zero(THCState *state, THCudaTensor *self_) { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - THCudaCheck(cudaMemset(THCudaTensor_data(state, self), 0, sizeof(float)*THCudaTensor_nElement(state, self))); - THCudaTensor_freeCopyTo(state, self, self_); + if (THCudaTensor_isContiguous(state, self_)) { + THCudaCheck(cudaMemsetAsync(THCudaTensor_data(state, self_), + 0, + sizeof(float) * THCudaTensor_nElement(state, self_))); + } else { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorFillOp(0))) { + THArgCheck(false, 1, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); } void THCudaTensor_zeros(THCState *state, THCudaTensor *r_, THLongStorage *size) @@ -56,188 +70,82 @@ long THCudaTensor_numel(THCState *state, THCudaTensor *t) return THCudaTensor_nElement(state, t); } - -struct addvalue_functor -{ - const float value; - - addvalue_functor(float value_) : value(value_) {} - - __host__ __device__ float operator()(const float& x) const - { - return (x+value); - } -}; - -void THCudaTensor_add(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THCudaTensor_resizeAs(state, self_, src_); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - THCudaTensor *src = THCudaTensor_newContiguous(state, src_); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, addvalue_functor(value)); - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -struct mulvalue_functor -{ - const float value; - mulvalue_functor(float value_) : value(value_) {} - __host__ __device__ float operator()(const float& x) const - { - return (x*value); +struct TensorCPowOp { + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = powf(*out, *in); } -}; - -void THCudaTensor_mul(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THCudaTensor_resizeAs(state, self_, src_); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - THCudaTensor *src = THCudaTensor_newContiguous(state, src_); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, mulvalue_functor(value)); - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -struct divvalue_functor -{ - const float value; - divvalue_functor(float value_) : value(value_) {} - __host__ __device__ float operator()(const float& x) const - { - return (x/value); + __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { + *out = powf(*in1, *in2); } }; -void THCudaTensor_div(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) -{ - THCudaTensor_resizeAs(state, self_, src_); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - THCudaTensor *src = THCudaTensor_newContiguous(state, src_); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, divvalue_functor(value)); - - THCudaTensor_free(state, src); - 
THCudaTensor_freeCopyTo(state, self, self_); -} - -void THCudaTensor_cadd(THCState *state, THCudaTensor *self_, THCudaTensor* src1, float value, THCudaTensor *src2) +void THCudaTensor_cpow(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) { - THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size do not match"); - { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); + THArgCheck(THCudaTensor_nElement(state, src1) == + THCudaTensor_nElement(state, src2), 3, "sizes do not match"); - if (self_ != src1) { - src1 = THCudaTensor_newContiguous(state, src1); - THCudaTensor_copy(state, self, src1); - THCudaTensor_free(state, src1); + if (self_ == src1) { + // self = pow(self, src2) + if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorCPowOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); } + } else { + THCudaTensor_resizeAs(state, self_, src1); - src2 = THCudaTensor_newContiguous(state, src2); - - THCudaBlas_axpy(state, - THCudaTensor_nElement(state, self), value, - THCudaTensor_data(state, src2), 1, - THCudaTensor_data(state, self), 1); - - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); + // self = pow(src1, src2) + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorCPowOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } } -} - -void THCudaTensor_cmul(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size do not match"); - { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src1_data(THCudaTensor_data(state, src1)); - thrust::device_ptr src2_data(THCudaTensor_data(state, src2)); - thrust::transform(src2_data, src2_data+size, src1_data, self_data, thrust::multiplies()); + THCudaCheck(cudaGetLastError()); +} - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); +struct TensorDivOp { + __device__ __forceinline__ void + operator()(float* out, float* in) { + *out /= *in; } -} -struct cpow_functor -{ - __host__ __device__ float operator()(const float& a, const float& b) const - { - return pow(a, b); + __device__ __forceinline__ void + operator()(float* out, float* in1, float* in2) { + *out = *in1 / *in2; } }; -void THCudaTensor_cpow(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +void THCudaTensor_cdiv(THCState* state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) { - THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size does not match"); + THArgCheck(THCudaTensor_nElement(state, src1) == + THCudaTensor_nElement(state, src2), 3, "sizes do not match"); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src1_data(THCudaTensor_data(state, src1)); - thrust::device_ptr 
src2_data(THCudaTensor_data(state, src2)); + if (self_ == src1) { + // self *= src2 + if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorDivOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src1); - thrust::transform(src1_data, src1_data + size, src2_data, self_data, cpow_functor()); + // self = src1 * src2 + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorDivOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); + THCudaCheck(cudaGetLastError()); } -void THCudaTensor_cdiv(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size does not match"); - { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src1_data(THCudaTensor_data(state, src1)); - thrust::device_ptr src2_data(THCudaTensor_data(state, src2)); - - thrust::transform(src1_data, src1_data+size, src2_data, self_data, thrust::divides()); +struct TensorAddCMulOp { + TensorAddCMulOp(float v) : val(v) {} - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); + __device__ __forceinline__ void + operator()(float* out, float* in1, float* in2) { + *out += val * *in1 * *in2; } -} - -__global__ void THCudaTensor_kernel_addcmul(float *data, float value, float *src1, float *src2, long size) -{ - long k = (((blockIdx.y * gridDim.x) + blockIdx.x) * blockDim.x) + threadIdx.x; - - if(k < size) - data[k] += value*src1[k]*src2[k]; -} + float val; +}; void THCudaTensor_addcmul(THCState *state, THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2) { @@ -247,38 +155,27 @@ void THCudaTensor_addcmul(THCState *state, THCudaTensor *self_, THCudaTensor *t, THCudaTensor_copy(state, self_, t); } THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size do not match"); - { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - - int nBlockPerRow, nBlockPerColumn, nThreadPerBlock; - THCudaGetGridSize(&nBlockPerRow, &nBlockPerColumn, &nThreadPerBlock, size); - dim3 threads(nThreadPerBlock); - dim3 grid(nBlockPerRow, nBlockPerColumn); - THCudaTensor_kernel_addcmul<<>>(THCudaTensor_data(state, self), value, THCudaTensor_data(state, src1), THCudaTensor_data(state, src2), size); + THArgCheck(THCudaTensor_nElement(state, src1) == + THCudaTensor_nElement(state, src2), 3, "sizes do not match"); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) - THError(cudaGetErrorString(errcode)); - - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorAddCMulOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); } + + THCudaCheck(cudaGetLastError()); } -__global__ void 
THCudaTensor_kernel_addcdiv(float *data, float value, float *src1, float *src2, long size) -{ - long k = (((blockIdx.y * gridDim.x) + blockIdx.x) * blockDim.x) + threadIdx.x; +struct TensorAddCDivOp { + TensorAddCDivOp(float v) : val(v) {} - if(k < size) - data[k] += value*src1[k]/src2[k]; -} + __device__ __forceinline__ void + operator()(float* out, float* in1, float* in2) { + *out += val * *in1 / *in2; + } + float val; +}; void THCudaTensor_addcdiv(THCState *state, THCudaTensor *self_, THCudaTensor *t, float value, THCudaTensor *src1, THCudaTensor *src2) { @@ -289,47 +186,13 @@ void THCudaTensor_addcdiv(THCState *state, THCudaTensor *self_, THCudaTensor *t, } THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size do not match"); - { - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - - int nBlockPerRow, nBlockPerColumn, nThreadPerBlock; - THCudaGetGridSize(&nBlockPerRow, &nBlockPerColumn, &nThreadPerBlock, size); - dim3 threads(nThreadPerBlock); - dim3 grid(nBlockPerRow, nBlockPerColumn); - - THCudaTensor_kernel_addcdiv<<>>(THCudaTensor_data(state, self), value, THCudaTensor_data(state, src1), THCudaTensor_data(state, src2), size); - - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) - THError(cudaGetErrorString(errcode)); + THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "sizes do not match"); - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorAddCDivOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); } -} - -float THCudaTensor_dot(THCState *state, THCudaTensor *self, THCudaTensor *src) -{ - THArgCheck(THCudaTensor_nElement(state, self) == THCudaTensor_nElement(state, src), 2, "size do not match"); - - { - self = THCudaTensor_newContiguous(state, self); - src = THCudaTensor_newContiguous(state, src); - - float result = THCudaBlas_dot(state, - THCudaTensor_nElement(state, self), - THCudaTensor_data(state, self), 1, - THCudaTensor_data(state, src), 1); - THCudaTensor_free(state, src); - THCudaTensor_free(state, self); - return result; - } + THCudaCheck(cudaGetLastError()); } float THCudaTensor_minall(THCState *state, THCudaTensor *self) @@ -376,8 +239,6 @@ float THCudaTensor_prodall(THCState *state, THCudaTensor *self) return result; } - - struct dim4 { unsigned arr[4]; @@ -388,1885 +249,20 @@ struct dim4 { __host__ __device__ unsigned& operator[](const unsigned& idx) { return arr[idx]; } }; - - -/* Reduce one of the outer dimensions of a tensor - * - * For an n-d tensor where the reduction is *not* along the innermost dimension: - * - * - blockIdx.x loops over the dimensions on the outside of the reduced dimension. - * - blockIdx.y and threadIdx.x loop over the dimensions on the inside of the - * reduced dimension. - * - Each thread sequentially reduces one row of the reduced dimension. - * - * Reduction along the innermost dimension is handled in a separate kernel. 
- */ -template -__global__ void THCudaTensor_kernel_transformReduceOuterDim(float *tgt, float *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - UnaryFunction unary_op, float init, BinaryFunction binary_op) -{ - - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { - float *src = src_ + orow * row_size * num_irows + irow; - float acc = init; - - for (unsigned col = 0; col < row_size; ++col) { - acc = binary_op(acc, unary_op(*src)); - src += num_irows; - } - tgt[orow * num_irows + irow] = float(acc); - } - } -} - - - -template -__host__ void THCudaTensor_transformReduceOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, - long rdim, UnaryFunction unary_op, float init, BinaryFunction binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - unsigned num_orows = 1; - for (unsigned dim = 0; dim < rdim; dim++) { - num_orows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, rdim); - unsigned num_irows = 1; - for (unsigned dim = rdim + 1; dim < ndim; dim++) { - num_irows *= THCudaTensor_size(state, src, dim); - } - - dim3 threads(min(512, num_irows)); - unsigned maxGridDim = 1024; - dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x))); - - THCudaTensor_kernel_transformReduceOuterDim<<>>(THCudaTensor_data(state, tgt), - THCudaTensor_data(state, src), num_orows, num_irows, row_size, unary_op, init, binary_op); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - - - -/* Reduce the innermost dimension of a tensor - * - * For an n-d tensor where the reduction is along the innermost dimension: - * - * - blockIdx.x loops over all the outer rows. - * - Threads in the same block reduce one inner row in parallel. - * - * Reduction along other dimensions is handled in a separate kernel. - */ -template -__global__ void THCudaTensor_kernel_transformReduceInnermostDim(float *tgt, float *src_, - unsigned num_rows, unsigned row_size, UnaryFunction unary_op, float init, BinaryFunction binary_op) -{ - __shared__ float sbuf[32][16]; - - for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; - float acc = init; - if (row < num_rows) { - float *src = src_ + row * row_size; - // Sequential reduction within a thread. - for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) { - float val = unary_op(src[col]); - acc = binary_op(acc, val); - } - } - - sbuf[threadIdx.y][threadIdx.x] = acc; - - // Reduce intermediate values to single value. 
- float* line = &sbuf[threadIdx.y][0]; - for (unsigned s = 8; s > 1; s >>= 1) { - if (row < num_rows && threadIdx.x < s) { - line[threadIdx.x] = binary_op(line[threadIdx.x], line[threadIdx.x + s]); - } - __syncthreads(); - } - - if (row < num_rows && threadIdx.x == 0) { - tgt[row] = binary_op(line[0], line[1]); - } - __syncthreads(); - } -} - -template -__host__ void THCudaTensor_transformReduceInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, - UnaryFunction unary_op, float init, BinaryFunction binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - unsigned num_rows = 1; - for (unsigned dim = 0; dim < ndim - 1; dim++) { - num_rows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, ndim - 1); - - dim3 threads(16, 32); - dim3 grid(min(1024, DIVUP(num_rows, threads.y))); - - THCudaTensor_kernel_transformReduceInnermostDim<<>>(THCudaTensor_data(state, tgt), - THCudaTensor_data(state, src), num_rows, row_size, unary_op, init, binary_op); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - - -template -void THCudaTensor_transformReduceDim(THCState *state, THCudaTensor *self_, THCudaTensor *src, - long dimension, UnaryFunction unary_op, float init, BinaryFunction binary_op) -{ - THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "dimension out of range"); - - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, self_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if(dimension == THCudaTensor_nDimension(state, src)-1) { - THCudaTensor_transformReduceInnermostDim(state, self, src, unary_op, init, binary_op); - } else { - THCudaTensor_transformReduceOuterDim(state, self, src, dimension, unary_op, init, binary_op); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -template -void THCudaTensor_reduceDim(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, float init, BinaryFunction binary_op) -{ - THCudaTensor_transformReduceDim(state, self_, src, dimension, thrust::identity(), init, binary_op); -} - - -void THCudaTensor_sum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) -{ - return THCudaTensor_reduceDim(state, self, src, dimension, 0.0f, thrust::plus()); -} - -void THCudaTensor_prod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) -{ - return THCudaTensor_reduceDim(state, self, src, dimension, 1.0f, thrust::multiplies()); -} - - -/* Perform an inclusive scan along an outer dimension of a tensor. - * - * - num_orows is the size of the flattened outer dimensions; - * - num_irows is the size of the flattened inner dimensions; - * - row_size is the size of the dimension along which to compute the variance; - * - * The dimensions to the outside and inside of the specified dimension are considered as flattened. - * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened - * outer dimensions, which contains several "inner rows"). - * Each thread processes a single inner row at a time. 
- */ -template -__global__ void THCudaTensor_kernel_scanOuterDim(float *tgt_, float *src_, - unsigned num_orows, unsigned num_irows, unsigned row_size, - float init, BinaryOp binary_op) -{ - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { - float *src = src_ + orow * row_size * num_irows + irow; - float *tgt = tgt_ + orow * row_size * num_irows + irow; - float acc = init; - - for (unsigned col = 0; col < row_size; ++col) { - acc = binary_op(acc, *src); - *tgt = acc; - - src += num_irows; - tgt += num_irows; - } - } - } -} - -template -__host__ void THCudaTensor_scanOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension, - float init, BinaryOp binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - // Treat all outer dimensions (i.e. dim < dimension) as one. - unsigned num_orows = 1; - for (unsigned dim = 0; dim < dimension; dim++) { - num_orows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, dimension); - // Treat all inner dimensions (i.e. dim > dimension) as one. - unsigned num_irows = 1; - for (unsigned dim = dimension + 1; dim < ndim; dim++) { - num_irows *= THCudaTensor_size(state, src, dim); - } - - dim3 threads(min(512, num_irows)); - unsigned maxGridDim = 1024; - dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x))); - - THCudaTensor_kernel_scanOuterDim<<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size, init, binary_op); - cudaError errcode = cudaGetLastError(); - if (errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - - -/* Perform an inclusive scan along the innermost dimension of a tensor. - * - * - num_rows is the size of the flattened outer dimensions; - * - row_size is the size of the innermost dimension; - * - * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is - * considered as having 'num_rows' rows of size 'row_size'. - * Each thread block processes one or more sets of contiguous rows (processing multiple rows - * per thread block is quicker than processing a single row, especially for short rows). - */ -template -__global__ void THCudaTensor_kernel_scanInnermostDim(float *tgt_, float *src_, - unsigned num_rows, unsigned row_size, - float init, BinaryFunction binary_op) -{ - __shared__ float sbuf[num_threads_y][2 * num_threads_x]; - - float* row_buf = sbuf[threadIdx.y]; - - for (unsigned block_row = blockIdx.x * blockDim.y; - block_row < num_rows; - block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; - float block_total = init; - - float *row_src = src_ + row * row_size; - float *row_tgt = tgt_ + row * row_size; - - // Perform scan on one block at a time, keeping track of the total value of - // all blocks processed so far. - for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { - // Load data into shared memory (two values per thread). 
- unsigned col1 = block_col + threadIdx.x; - unsigned col2 = block_col + num_threads_x + threadIdx.x; - if (row < num_rows) { - if (col1 < row_size) { - row_buf[threadIdx.x] = row_src[col1]; - } else { - row_buf[threadIdx.x] = init; - } - - if (col2 < row_size) { - row_buf[num_threads_x + threadIdx.x] = row_src[col2]; - } else { - row_buf[num_threads_x + threadIdx.x] = init; - } - - // Add the total value of all previous blocks to the first value of this block. - if (threadIdx.x == 0) { - row_buf[0] = binary_op(row_buf[0], block_total); - } - } - __syncthreads(); - - // Parallel reduction (up-sweep). - for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { - if (row < num_rows && threadIdx.x < s) { - unsigned offset = (2 * threadIdx.x + 1) * d - 1; - row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); - } - __syncthreads(); - } - - // Down-sweep. - for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { - if (row < num_rows && threadIdx.x < s - 1) { - unsigned offset = 2 * (threadIdx.x + 1) * d - 1; - row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); - } - __syncthreads(); - } - - // Write back to output. - if (row < num_rows) { - if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; - if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; - } - block_total = row_buf[2 * num_threads_x - 1]; - __syncthreads(); - } - } -} - -template -__host__ void THCudaTensor_scanInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, float init, BinaryFunction binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - // Treat all outer dimensions as a single dimension. - unsigned num_rows = 1; - for (unsigned dim = 0; dim < ndim - 1; dim++) { - num_rows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, ndim - 1); - - dim3 threads(16, 32); - dim3 grid(min(1024, DIVUP(num_rows, threads.y))); - - THCudaTensor_kernel_scanInnermostDim<16, 32><<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size, init, binary_op); - cudaError errcode = cudaGetLastError(); - if (errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - -template -void THCudaTensor_scanDim(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, float init, BinaryFunction binary_op) -{ - THCudaTensor_resizeAs(state, self_, src); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCudaTensor_scanInnermostDim(state, self, src, init, binary_op); - } else { - THCudaTensor_scanOuterDim(state, self, src, dimension, init, binary_op); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) -{ - return THCudaTensor_scanDim(state, self, src, dimension, 0.0f, thrust::plus()); -} - -void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) -{ - return THCudaTensor_scanDim(state, self, src, dimension, 1.0f, thrust::multiplies()); -} - -/* A set of reduction kernels that take in binary ops on thrust pairs (of value, index). - These are useful when you not only have to do a reduction, but you might have - to preserve the location of contention (for example min/max operations). 
- The structure of the kernels follows the structure of the reduction kernels. -*/ -template -__global__ void THCudaTensor_kernel_transformReduceOuterDimIndex(float *tgt1, float *tgt2, - float *src_, - unsigned num_orows, - unsigned num_irows, - unsigned row_size, - thrust::pair init, - BinaryFunction binary_op) -{ - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { - float *src = src_ + orow * row_size * num_irows + irow; - thrust::pair acc = init; - - for (unsigned col = 0; col < row_size; ++col) { - acc = binary_op(thrust::make_pair(*src, col+1), acc); // i+1 for 1-indexing - src += num_irows; - } - tgt1[orow * num_irows + irow] = acc.first; - tgt2[orow * num_irows + irow] = acc.second; - } - } -} - -template -__host__ void THCudaTensor_transformReduceOuterDimIndex(THCState *state, THCudaTensor *tgt1, THCudaTensor *tgt2, - THCudaTensor *src, - long rdim, thrust::pair init, - BinaryFunction binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - unsigned num_orows = 1; - for (unsigned dim = 0; dim < rdim; dim++) { - num_orows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, rdim); - unsigned num_irows = 1; - for (unsigned dim = rdim + 1; dim < ndim; dim++) { - num_irows *= THCudaTensor_size(state, src, dim); - } - - dim3 threads(min(512, num_irows)); - unsigned maxGridDim = 1024; - dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x))); - - THCudaTensor_kernel_transformReduceOuterDimIndex<<>>( - THCudaTensor_data(state, tgt1), THCudaTensor_data(state, tgt2), - THCudaTensor_data(state, src), num_orows, num_irows, row_size, init, binary_op); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - -/* Reduce the innermost dimension of a tensor (on thrust::pair functors which are (value, index)) - * - * For an n-d tensor (n <= 4) where the reduction is along the innermost dimension: - * - * - block.x is the innermost dimension, i.e. dimension 0; - * - block.y and grid.y make up dimension 1; and - * - grid.x and grid z are the remaining two outer dimensions (if any) - * - * Reduction along other dimensions is handled in a separate kernel. - */ -template -__global__ void THCudaTensor_kernel_transformReduceInnermostDimIndex( - float *tgt1, float* tgt2, float *src_, - unsigned num_rows, unsigned row_size, - thrust::pair init, BinaryFunction binary_op) -{ - __shared__ float sbuf[32][16]; - __shared__ float ibuf[32][16]; - - for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; - thrust::pair acc = init; - if (row < num_rows) { - float *src = src_ + row * row_size; - // Sequential reduction within a thread. - for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) { - acc = binary_op(thrust::make_pair(src[col], col+1), acc); - } - } - - sbuf[threadIdx.y][threadIdx.x] = acc.first; - ibuf[threadIdx.y][threadIdx.x] = acc.second; - - // Reduce intermediate values to single value. 
- float* sline = &sbuf[threadIdx.y][0]; - float* iline = &ibuf[threadIdx.y][0]; - for (unsigned s = 8; s > 0; s >>= 1) { - if (row < num_rows && threadIdx.x < s) { - thrust::pair arg1 = thrust::make_pair(sline[threadIdx.x], iline[threadIdx.x]); - thrust::pair arg2 = thrust::make_pair(sline[threadIdx.x + s], iline[threadIdx.x + s]); - thrust::pair res = binary_op(arg1, arg2); - sline[threadIdx.x] = res.first; - iline[threadIdx.x] = res.second; - } - __syncthreads(); - } - - if (row < num_rows && threadIdx.x == 0) { - tgt1[row] = sline[0]; - tgt2[row] = iline[0]; - } - __syncthreads(); - } -} - -template -__host__ void THCudaTensor_transformReduceInnermostDimIndex( - THCState *state, THCudaTensor *tgt1, THCudaTensor *tgt2, THCudaTensor *src, - thrust::pair init, BinaryFunction binary_op) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - unsigned num_rows = 1; - for (unsigned dim = 0; dim < ndim - 1; dim++) { - num_rows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, ndim - 1); - - dim3 threads(16, 32); - dim3 grid(min(1024, DIVUP(num_rows, threads.y))); - - THCudaTensor_kernel_transformReduceInnermostDimIndex<<>>( - THCudaTensor_data(state, tgt1), THCudaTensor_data(state, tgt2), - THCudaTensor_data(state, src), num_rows, row_size, init, binary_op); - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - -template -void THCudaTensor_reduceDimIndex(THCState *state, THCudaTensor *tgt1_, THCudaTensor *tgt2_, THCudaTensor *src, - long dimension, thrust::pair init, - BinaryFunction binary_op) -{ - THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "dimension out of range"); - - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, tgt1_, dim, NULL); - THCudaTensor_resize(state, tgt2_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *tgt1 = THCudaTensor_newContiguous(state, tgt1_); - THCudaTensor *tgt2 = THCudaTensor_newContiguous(state, tgt2_); - src = THCudaTensor_newContiguous(state, src); - - if(dimension == THCudaTensor_nDimension(state, src)-1) { - THCudaTensor_transformReduceInnermostDimIndex(state, tgt1, tgt2, src, init, binary_op); - } else { - THCudaTensor_transformReduceOuterDimIndex(state, tgt1, tgt2, src, dimension, init, binary_op); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, tgt1, tgt1_); - THCudaTensor_freeCopyTo(state, tgt2, tgt2_); -} - -struct maxvalue_functor +void THCudaTensor_sum(THCState* state, THCudaTensor *self, THCudaTensor *src, long dimension) { - __host__ __device__ thrust::pair operator()(const thrust::pair &a, - const thrust::pair &b) - { - if (a.first > b.first) return a; - else return b; - } -}; + THCudaTensor_reduceDim( + state, self, src, + thrust::identity(), thrust::plus(), 0.0f, dimension); -void THCudaTensor_max(THCState *state, THCudaTensor *values, THCudaTensor *indices, THCudaTensor *src, long dimension) -{ - const float minfloat32 = -3.402823466e+38f; - thrust::pair init = thrust::make_pair(minfloat32, -1); - return THCudaTensor_reduceDimIndex(state, values, indices, src, dimension, init, - maxvalue_functor()); + THCudaCheck(cudaGetLastError()); } -struct minvalue_functor +void THCudaTensor_prod(THCState* state, THCudaTensor *self, THCudaTensor *src, long dimension) { - __host__ __device__ thrust::pair operator()(const thrust::pair &a, - const thrust::pair &b) - { - if (a.first < b.first) return a; 
- else return b; - } -}; - -void THCudaTensor_min(THCState *state, THCudaTensor *values, THCudaTensor *indices, THCudaTensor *src, long dimension) -{ - const float maxfloat32 = 3.402823466e+38f; - thrust::pair init = thrust::make_pair(maxfloat32, -1); - return THCudaTensor_reduceDimIndex(state, values, indices, src, dimension, init, - minvalue_functor()); -} - - -void THCudaTensor_addmv(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *mat, THCudaTensor *vec) -{ - if( (mat->nDimension != 2) || (vec->nDimension != 1) ) - THError("matrix and vector expected"); - - if( mat->size[1] != vec->size[0] ) - THError("size mismatch"); - - if(t->nDimension != 1) - THError("size mismatch"); - - if(t->size[0] != mat->size[0]) - THError("size mismatch"); - - if(r_ != t) - { - THCudaTensor_resizeAs(state, r_, t); - THCudaTensor_copy(state, r_, t); - } - - if(mat->stride[0] == 1) - { - THCudaBlas_gemv(state, 'n', mat->size[0], mat->size[1], - alpha, THCudaTensor_data(state, mat), mat->stride[1], - THCudaTensor_data(state, vec), vec->stride[0], - beta, THCudaTensor_data(state, r_), r_->stride[0]); - } - else if(mat->stride[1] == 1) - { - THCudaBlas_gemv(state, 't', mat->size[1], mat->size[0], - alpha, THCudaTensor_data(state, mat), mat->stride[0], - THCudaTensor_data(state, vec), vec->stride[0], - beta, THCudaTensor_data(state, r_), r_->stride[0]); - } - else - { - THCudaTensor *cmat = THCudaTensor_newContiguous(state, mat); - - THCudaBlas_gemv(state, 't', mat->size[1], mat->size[0], - alpha, THCudaTensor_data(state, cmat), cmat->stride[0], - THCudaTensor_data(state, vec), vec->stride[0], - beta, THCudaTensor_data(state, r_), r_->stride[0]); - - THCudaTensor_free(state, cmat); - } -} - -void THCudaTensor_addmm(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *m1, THCudaTensor *m2) -{ - char transpose_r, transpose_m1, transpose_m2; - THCudaTensor *r__, *m1_, *m2_; - - if( (m1->nDimension != 2) || (m2->nDimension != 2) ) - THError("matrix and matrix expected"); - - if(t->nDimension != 2) - THError("size mismatch"); - - if( (t->size[0] != m1->size[0]) || (t->size[1] != m2->size[1]) || (m1->size[1] != m2->size[0]) ) - THError("size mismatch"); - - if(t != r_) - { - THCudaTensor_resizeAs(state, r_, t); - THCudaTensor_copy(state, r_, t); - } - - /* r_ */ - if(r_->stride[0] == 1) - { - transpose_r = 'n'; - r__ = r_; - } - else if(r_->stride[1] == 1) - { - THCudaTensor *swap = m2; - m2 = m1; - m1 = swap; - transpose_r = 't'; - r__ = r_; - } - else - { - transpose_r = 'n'; - - r__ = THCudaTensor_newWithSize2d(state, r_->size[1], r_->size[0]); - THCudaTensor_copy(state, r__, r_); - THCudaTensor_transpose(state, r__, NULL, 0, 1); - } - - /* m1 */ - if(m1->stride[(transpose_r == 'n' ? 0 : 1)] == 1) - { - transpose_m1 = 'n'; - m1_ = m1; - } - else if(m1->stride[(transpose_r == 'n' ? 1 : 0)] == 1) - { - transpose_m1 = 't'; - m1_ = m1; - } - else - { - transpose_m1 = (transpose_r == 'n' ? 't' : 'n'); - m1_ = THCudaTensor_newContiguous(state, m1); - } - - /* m2 */ - if(m2->stride[(transpose_r == 'n' ? 0 : 1)] == 1) - { - transpose_m2 = 'n'; - m2_ = m2; - } - else if(m2->stride[(transpose_r == 'n' ? 1 : 0)] == 1) - { - transpose_m2 = 't'; - m2_ = m2; - } - else - { - transpose_m2 = (transpose_r == 'n' ? 't' : 'n'); - m2_ = THCudaTensor_newContiguous(state, m2); - } - - /* do the operation */ - THCudaBlas_gemm(state, - transpose_m1, - transpose_m2, - r__->size[(transpose_r == 'n' ? 0 : 1)], - r__->size[(transpose_r == 'n' ? 
1 : 0)], - m1_->size[(transpose_r == 'n' ? 1 : 0)], - alpha, - THCudaTensor_data(state, m1_), - (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]), - THCudaTensor_data(state, m2_), - (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]), - beta, - THCudaTensor_data(state, r__), - r__->stride[(transpose_r == 'n' ? 1 : 0)]); - - /* free intermediate variables */ - if(m1_ != m1) - THCudaTensor_free(state, m1_); - - if(m2_ != m2) - THCudaTensor_free(state, m2_); - - if(r__ != r_) - THCudaTensor_freeCopyTo(state, r__, r_); -} - -void THCudaTensor_addr(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *vec1, THCudaTensor *vec2) -{ - if( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) - THError("vector and vector expected"); - - if(t->nDimension != 2) - THError("size mismatch"); - - if( (t->size[0] != vec1->size[0]) || (t->size[1] != vec2->size[0]) ) - THError("size mismatch"); - - if(r_ != t) - { - THCudaTensor_resizeAs(state, r_, t); - THCudaTensor_copy(state, r_, t); - } - - if(beta != 1) - THCudaTensor_mul(state, r_, r_, beta); - - if(r_->stride[0] == 1) - { - THCudaBlas_ger(state, vec1->size[0], vec2->size[0], - alpha, THCudaTensor_data(state, vec1), vec1->stride[0], - THCudaTensor_data(state, vec2), vec2->stride[0], - THCudaTensor_data(state, r_), r_->stride[1]); - } - else if(r_->stride[1] == 1) - { - THCudaBlas_ger(state, vec2->size[0], vec1->size[0], - alpha, THCudaTensor_data(state, vec2), vec2->stride[0], - THCudaTensor_data(state, vec1), vec1->stride[0], - THCudaTensor_data(state, r_), r_->stride[0]); - } - else - { - THCudaTensor *cr = THCudaTensor_newClone(state, r_); - - THCudaBlas_ger(state, vec2->size[0], vec1->size[0], - alpha, THCudaTensor_data(state, vec2), vec2->stride[0], - THCudaTensor_data(state, vec1), vec1->stride[0], - THCudaTensor_data(state, cr), cr->stride[0]); - - THCudaTensor_freeCopyTo(state, cr, r_); - } -} - -void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THCudaTensor *t, - float alpha, THCudaTensor *batch1, THCudaTensor *batch2) { - THArgCheck(THCudaTensor_nDimension(state, t) == 3, 4, "expected 3D tensor"); - THArgCheck(THCudaTensor_nDimension(state, batch1) == 3, 6, "expected 3D tensor"); - THArgCheck(THCudaTensor_nDimension(state, batch2) == 3, 7, "expected 3D tensor"); - THArgCheck(THCudaTensor_size(state, t, 0) == THCudaTensor_size(state, batch1, 0), 6, - "equal number of batches expected"); - THArgCheck(THCudaTensor_size(state, t, 0) == THCudaTensor_size(state, batch2, 0), 7, - "equal number of batches expected"); - THArgCheck(THCudaTensor_size(state, t, 1) == THCudaTensor_size(state, batch1, 1), 6, - "wrong matrix size"); - THArgCheck(THCudaTensor_size(state, t, 2) == THCudaTensor_size(state, batch2, 2), 7, - "wrong matrix size"); - THArgCheck(THCudaTensor_size(state, batch1, 2) == THCudaTensor_size(state, batch2, 1), 6, - "wrong matrix size"); - - if (t != result) { - THCudaTensor_resizeAs(state, result, t); - THCudaTensor_copy(state, result, t); - } - - bool transpose_result; - char transpose_batch1, transpose_batch2; - long lda, ldb, ldc; - THCudaTensor *result_, *batch1_, *batch2_; - if (result->stride[1] == 1) - { - transpose_result = false; - result_ = result; - ldc = result_->stride[2]; - } - else if (result->stride[2] == 1) - { - transpose_result = true; - - THCudaTensor *swap = batch2; - batch2 = batch1; - batch1 = swap; - - result_ = result; - ldc = 
result_->stride[1]; - } - else - { - transpose_result = false; - - result_ = THCudaTensor_newWithSize3d(state, result->size[0], result->size[2], result->size[1]); - THCudaTensor_copy(state, result_, result); - THCudaTensor_transpose(state, result_, NULL, 1, 2); - - ldc = result_->stride[2]; - } - - if (batch1->stride[transpose_result ? 2 : 1] == 1) - { - transpose_batch1 = 'n'; - batch1_ = batch1; - lda = batch1_->stride[transpose_result ? 1 : 2]; - } - else if (batch1->stride[transpose_result ? 1 : 2] == 1) - { - transpose_batch1 = 't'; - batch1_ = batch1; - lda = batch1_->stride[transpose_result ? 2 : 1]; - } - else - { - transpose_batch1 = transpose_result ? 'n' : 't'; - batch1_ = THCudaTensor_newContiguous(state, batch1); - lda = batch1_->stride[1]; - } - - if (batch2->stride[transpose_result ? 2 : 1] == 1) - { - transpose_batch2 = 'n'; - batch2_ = batch2; - ldb = batch2_->stride[transpose_result ? 1 : 2]; - } - else if (batch2->stride[transpose_result ? 1 : 2] == 1) - { - transpose_batch2 = 't'; - batch2_ = batch2; - ldb = batch2_->stride[transpose_result ? 2 : 1]; - } - else - { - transpose_batch2 = transpose_result ? 'n' : 't'; - batch2_ = THCudaTensor_newContiguous(state, batch2); - ldb = batch2_->stride[1]; - } - - // Compute pointers to matrices in each batch. - long num_batches = result_->size[0]; - size_t matrices_size = num_batches * sizeof(float*); - const float **matrices1 = (const float **)THAlloc(matrices_size); - const float **matrices2 = (const float **)THAlloc(matrices_size); - float **result_matrices = (float **)THAlloc(matrices_size); - for (int i = 0; i < num_batches; ++i) - { - matrices1[i] = THCudaTensor_data(state, batch1_) + i * batch1_->stride[0]; - matrices2[i] = THCudaTensor_data(state, batch2_) + i * batch2_->stride[0]; - result_matrices[i] = THCudaTensor_data(state, result_) + i * result_->stride[0]; - } - - // Copy pointers to device. - const float **d_matrices1, **d_matrices2; - float **d_result_matrices; - THCudaCheck(cudaMalloc(&d_matrices1, matrices_size)); - THCudaCheck(cudaMalloc(&d_matrices2, matrices_size)); - THCudaCheck(cudaMalloc(&d_result_matrices, matrices_size)); - - THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size, cudaMemcpyHostToDevice)); - THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size, cudaMemcpyHostToDevice)); - THCudaCheck(cudaMemcpyAsync(d_result_matrices, result_matrices, matrices_size, cudaMemcpyHostToDevice)); - - THCudaBlas_gemmBatched( - state, - transpose_batch1, - transpose_batch2, - result_->size[transpose_result ? 2 : 1], - result_->size[transpose_result ? 1 : 2], - batch1_->size[transpose_result ? 
1 : 2], - alpha, - d_matrices1, lda, - d_matrices2, ldb, - beta, - d_result_matrices, ldc, - num_batches); - - cudaFree(d_matrices1); - cudaFree(d_matrices2); - cudaFree(d_result_matrices); - THFree(matrices1); - THFree(matrices2); - THFree(result_matrices); - - if (batch1_ != batch1) - THCudaTensor_free(state, batch1_); - - if (batch2_ != batch2) - THCudaTensor_free(state, batch2_); - - if (result_ != result) - THCudaTensor_freeCopyTo(state, result_, result); -} - -#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC) \ - struct NAME##_functor \ - { \ - __host__ __device__ float operator()(const float& x) const \ - { \ - return CFUNC(x); \ - } \ - }; \ - \ - void THCudaTensor_##NAME(THCState *state, THCudaTensor *self_, THCudaTensor *src) \ - { \ - THCudaTensor_resizeAs(state, self_, src); \ - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); \ - src = THCudaTensor_newContiguous(state, src); \ - long size = THCudaTensor_nElement(state, self); \ - thrust::device_ptr self_data(THCudaTensor_data(state, self)); \ - thrust::device_ptr src_data(THCudaTensor_data(state, src)); \ - \ - thrust::transform(src_data, src_data+size, self_data, NAME##_functor()); \ - \ - THCudaTensor_free(state, src); \ - THCudaTensor_freeCopyTo(state, self, self_); \ - } - -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log, log) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, log1p) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(exp, exp) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(cos, cos) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(acos, acos) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(cosh, cosh) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sin, sin) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(asin, asin) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sinh, sinh) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(tan, tan) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(atan, atan) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(tanh, tanh) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sqrt, sqrt) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(ceil, ceil) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, floor) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(abs, fabs) -IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(round, roundf) - -struct pow_functor -{ - const float value; - - pow_functor(float value_) : value(value_) {} - - __host__ __device__ float operator()(const float& x) const - { - return pow(x, value); - } -}; - -void THCudaTensor_pow(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_resizeAs(state, self_, src); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, pow_functor(value)); - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -struct tpow_functor -{ - const float value; - - tpow_functor(float value_) : value(value_) {} - - __host__ __device__ float operator()(const float& x) const - { - return pow(value, x); - } -}; - -void THCudaTensor_tpow(THCState *state, THCudaTensor *self_, float value, THCudaTensor *src) -{ - THCudaTensor_resizeAs(state, self_, src); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, tpow_functor(value)); - 
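/* Editor's aside (illustrative, not part of the patch): the thrust-based
 * pow/tpow implementations removed here move out of this file. Under the new
 * THCApply scheme such a unary op can be written as a functor for
 * THCudaTensor_pointwiseApply2, in the same style as TensorCPowOp above.
 * `TensorTPowOp` is a hypothetical name; this hunk only shows the old
 * implementation being deleted. */
struct TensorTPowOp {
  TensorTPowOp(float v) : val(v) {}

  __device__ __forceinline__ void operator()(float* out, float* in) {
    /* out := val ^ in, matching tpow's value-as-base semantics. */
    *out = powf(val, *in);
  }

  float val;
};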
- THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -struct atan2_functor -{ - __host__ __device__ float operator()(const float& x, const float& y) const - { - return atan2f(x, y); - } -}; - -void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaTensor *ty) -{ - THCudaTensor_resizeAs(state, self_, tx); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - tx = THCudaTensor_newContiguous(state, tx); - ty = THCudaTensor_newContiguous(state, ty); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr tx_data(THCudaTensor_data(state, tx)); - thrust::device_ptr ty_data(THCudaTensor_data(state, ty)); - - thrust::transform(tx_data, tx_data+size, ty_data, self_data, atan2_functor()); - - THCudaTensor_free(state, tx); - THCudaTensor_free(state, ty); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -struct clamp_functor -{ - const float min_value; - const float max_value; - - clamp_functor(float min_value_, float max_value_) : min_value(min_value_), max_value(max_value_) {} - - __host__ __device__ float operator()(const float& x) const - { - if (x < min_value) { - return min_value; - } - if (x > max_value) { - return max_value; - } - return x; - } -}; - -void THCudaTensor_clamp(THCState *state, THCudaTensor *self_, THCudaTensor *src, float min_value, - float max_value) -{ - THArgCheck(THCudaTensor_nElement(state, self_) == THCudaTensor_nElement(state, src), 2, "sizes do not match"); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, clamp_functor(min_value, - max_value)); - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -struct sign_functor -{ - __device__ float operator()(const float &v) const { - return (v > 0) - (v < 0); - } -}; - - -void THCudaTensor_sign(THCState *state, THCudaTensor *self_, THCudaTensor *src) -{ - THCudaTensor_resizeAs(state, self_, src); - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src = THCudaTensor_newContiguous(state, src); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, sign_functor()); - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -float THCudaTensor_meanall(THCState *state, THCudaTensor *self) -{ - THArgCheck(self->nDimension > 0, 1, "empty Tensor"); - return THCudaTensor_sumall(state, self)/THCudaTensor_nElement(state, self); -} - -void -THCudaTensor_mean(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim) -{ - THCudaTensor_sum(state, self, src, dim); - THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim)); -} - -struct square_functor -{ - const float mean; - - square_functor(float mean_) : mean(mean_) {} - - __host__ __device__ float operator()(const float& x) const - { - return (x-mean)*(x-mean); - } -}; - -float THCudaTensor_varall(THCState *state, THCudaTensor *self) -{ - self = THCudaTensor_newContiguous(state, self); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr 
self_data(THCudaTensor_data(state, self)); - - float mean = THCudaTensor_meanall(state, self); - float result = thrust::transform_reduce(self_data, self_data+size, square_functor(mean), (float)0, thrust::plus()); - - result = result/(THCudaTensor_nElement(state, self)-1); - - THCudaTensor_free(state, self); - return result; -} - -float THCudaTensor_stdall(THCState *state, THCudaTensor *self) -{ - return sqrt(THCudaTensor_varall(state, self)); -} - -// Given the sum of values and the sum of squares, compute the variance or standard deviation. -template -__forceinline__ __device__ float THCudaTensor_computeVar(float sum, float sum2, unsigned row_size) { - if (flag) { - sum /= row_size; - sum2 /= row_size; - sum2 -= sum * sum; - sum2 = (sum2 < 0 ? 0 : sum2); - } - else { - sum /= row_size; - sum2 /= row_size - 1; - sum2 -= ((float)row_size) / ((float)(row_size - 1)) * sum * sum; - sum2 = (sum2 < 0 ? 0 : sum2); - } - if (apply_sqrt) - return sqrt(sum2); - else - return sum2; -} - -/* Compute the variance (or standard deviation) along an outer dimension of a tensor. - * - * - num_orows is the size of the flattened outer dimensions; - * - num_irows is the size of the flattened inner dimensions; - * - row_size is the size of the dimension along which to compute the variance; - * - if flag is set, normalize by `row_size` instead of `row_size - 1` - * - if apply_sqrt is set, compute the standard deviation instead of variance - * - * The dimensions to the outside and inside of the specified dimension are considered as flattened. - * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened - * outer dimensions, which contains several "inner rows"). - * Each thread processes a single inner row at a time. - */ -template -__global__ void THCudaTensor_kernel_varOuterDim(float *tgt, float *src_, unsigned num_orows, unsigned num_irows, unsigned row_size) -{ - for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { - for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { - float *src = src_ + orow * row_size * num_irows + irow; - float sum = 0, sum2 = 0; - - for (unsigned col = 0; col < row_size; ++col) { - float val = *src; - sum += val; - sum2 += val * val; - - src += num_irows; - } - - tgt[orow * num_irows + irow] = THCudaTensor_computeVar(sum, sum2, row_size); - } - } -} - -template -__host__ void THCudaTensor_varOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension, int flag) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - // Treat all outer dimensions (i.e. dim < dimension) as one. - unsigned num_orows = 1; - for (unsigned dim = 0; dim < dimension; dim++) { - num_orows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, dimension); - // Treat all inner dimensions (i.e. dim > dimension) as one. 
- unsigned num_irows = 1; - for (unsigned dim = dimension + 1; dim < ndim; dim++) { - num_irows *= THCudaTensor_size(state, src, dim); - } - - dim3 threads(min(512, num_irows)); - unsigned maxGridDim = 1024; - dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x))); - - if (flag) { - THCudaTensor_kernel_varOuterDim<<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size); - } else { - THCudaTensor_kernel_varOuterDim<<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size); - } - cudaError errcode = cudaGetLastError(); - if (errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - - -/* Compute the variance (or standard deviation) of the innermost dimension of a tensor. - * - * - num_rows is the size of the flattened outer dimensions; - * - row_size is the size of the innermost dimension; - * - if flag is set, normalize by `row_size` instead of `row_size - 1` - * - if apply_sqrt is set, compute the standard deviation instead of variance - * - * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is - * considered as having 'num_rows' rows of size 'row_size'. - * Each thread block processes one or more sets of contiguous rows (processing multiple rows - * per thread block is quicker than processing a single row, especially for short rows). - */ -template -__global__ void THCudaTensor_kernel_varInnermostDim(float *tgt, float *src_, unsigned num_rows, unsigned row_size) -{ - __shared__ float ssum[32][16]; - __shared__ float ssum2[32][16]; - - for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { - unsigned row = block_row + threadIdx.y; - float sum = 0, sum2 = 0; - if (row < num_rows) { - float *src = src_ + row * row_size; - // Sequential reduction within a thread. - for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) { - float val = src[col]; - sum += val; - sum2 += val * val; - } - } - ssum[threadIdx.y][threadIdx.x] = sum; - ssum2[threadIdx.y][threadIdx.x] = sum2; - __syncthreads(); - - // Reduce intermediate values to single value. - for (unsigned s = 8; s > 1; s >>= 1) { - if (row < num_rows && threadIdx.x < s) { - ssum[threadIdx.y][threadIdx.x] += ssum[threadIdx.y][threadIdx.x + s]; - ssum2[threadIdx.y][threadIdx.x] += ssum2[threadIdx.y][threadIdx.x + s]; - } - __syncthreads(); - } - - if (row < num_rows && threadIdx.x == 0) { - sum = ssum[threadIdx.y][0] + ssum[threadIdx.y][1]; - sum2 = ssum2[threadIdx.y][0] + ssum2[threadIdx.y][1]; - tgt[row] = THCudaTensor_computeVar(sum, sum2, row_size); - } - __syncthreads(); - } -} - -template -__host__ void THCudaTensor_varInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, int flag) -{ - unsigned ndim = THCudaTensor_nDimension(state, src); - // Treat all outer dimensions as a single dimension. - unsigned num_rows = 1; - for (unsigned dim = 0; dim < ndim - 1; dim++) { - num_rows *= THCudaTensor_size(state, src, dim); - } - unsigned row_size = THCudaTensor_size(state, src, ndim - 1); - - // From limited testing, 16x32 seemed a good compromise for handling both long and short dimensions. 
- dim3 threads(16, 32); - dim3 grid(min(1024, DIVUP(num_rows, threads.y))); - - if (flag) { - THCudaTensor_kernel_varInnermostDim<<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size); - } else { - THCudaTensor_kernel_varInnermostDim<<>>( - THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size); - } - cudaError errcode = cudaGetLastError(); - if (errcode != cudaSuccess) { - THError(cudaGetErrorString(errcode)); - } -} - -void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) -{ - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, self_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCudaTensor_varInnermostDim(state, self, src, flag); - } else { - THCudaTensor_varOuterDim(state, self, src, dimension, flag); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - -void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) -{ - THLongStorage *dim = THCudaTensor_newSizeOf(state, src); - THLongStorage_set(dim, dimension, 1); - THCudaTensor_resize(state, self_, dim, NULL); - THLongStorage_free(dim); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - src = THCudaTensor_newContiguous(state, src); - - if (dimension == THCudaTensor_nDimension(state, src) - 1) { - THCudaTensor_varInnermostDim(state, self, src, flag); - } else { - THCudaTensor_varOuterDim(state, self, src, dimension, flag); - } - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -template -void THCudaTensor_logicalValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, Op op) -{ - THCudaTensor_resizeAs(state, self_, src); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src = THCudaTensor_newContiguous(state, src); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - thrust::transform(src_data, src_data+size, self_data, op); - - THCudaTensor_free(state, src); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -struct partial_less_functor -{ - const float rhs; - partial_less_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs < rhs;} -}; - - -void THCudaTensor_ltValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, partial_less_functor(value)); -} - - -struct partial_greater_functor -{ - const float rhs; - partial_greater_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs > rhs;} -}; - - -void THCudaTensor_gtValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, partial_greater_functor(value)); -} - - -struct partial_less_equal_functor -{ - const float rhs; - partial_less_equal_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs <= rhs;} -}; - - -void THCudaTensor_leValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, 
partial_less_equal_functor(value)); -} - - -struct partial_greater_equal_functor -{ - const float rhs; - partial_greater_equal_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs >= rhs;} -}; - - -void THCudaTensor_geValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, partial_greater_equal_functor(value)); -} - - -struct partial_equal_functor -{ - const float rhs; - partial_equal_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs == rhs;} -}; - - -void THCudaTensor_eqValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, partial_equal_functor(value)); -} - - -struct partial_not_equal_functor -{ - const float rhs; - partial_not_equal_functor(float rhs) : rhs(rhs) {} - __host__ __device__ bool operator()(const float &lhs) const {return lhs != rhs;} -}; - - -void THCudaTensor_neValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) -{ - THCudaTensor_logicalValue(state, self_, src, partial_not_equal_functor(value)); -} - - -template -void THCudaTensor_logicalTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2, Op op) -{ - THCudaTensor_resizeAs(state, self_, src1); - THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "size do not match"); - - THCudaTensor *self = THCudaTensor_newContiguous(state, self_); - long size = THCudaTensor_nElement(state, self); - src1 = THCudaTensor_newContiguous(state, src1); - src2 = THCudaTensor_newContiguous(state, src2); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src1_data(THCudaTensor_data(state, src1)); - thrust::device_ptr src2_data(THCudaTensor_data(state, src2)); - - thrust::transform(src1_data, src1_data+size, src2_data, self_data, op); - - THCudaTensor_free(state, src1); - THCudaTensor_free(state, src2); - THCudaTensor_freeCopyTo(state, self, self_); -} - - -void THCudaTensor_ltTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::less()); -} - - -void THCudaTensor_gtTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::greater()); -} - - -void THCudaTensor_leTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::less_equal()); -} - - -void THCudaTensor_geTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::greater_equal()); -} - - -void THCudaTensor_eqTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::equal_to()); -} - - -void THCudaTensor_neTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) -{ - THCudaTensor_logicalTensor(state, self_, src1, src2, thrust::not_equal_to()); -} - - -struct norm_functor -{ - const float exponent; - - norm_functor(float exponent_) : exponent(exponent_) {} - - __host__ __device__ float operator()(const float& x) const - { - return pow(fabs(x), exponent); - } -}; - - -float THCudaTensor_normall(THCState *state, THCudaTensor *self, float 
value) -{ - self = THCudaTensor_newContiguous(state, self); - long size = THCudaTensor_nElement(state, self); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - - float result; - if(value == 0.0f) { - result = thrust::transform_reduce(self_data, self_data+size, partial_not_equal_functor(0.0f), (float)0, thrust::plus()); - } else { - result = thrust::transform_reduce(self_data, self_data+size, norm_functor(value), (float)0, thrust::plus()); - result = pow(result, (float)1.0/value); - } - - THCudaTensor_free(state, self); - return result; -} - -void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension) -{ - if (value == 0.0f) { - THCudaTensor_transformReduceDim(state, self, src, dimension, partial_not_equal_functor(0.0f), (float)0, thrust::plus()); - } else { - THCudaTensor_transformReduceDim(state, self, src, dimension, norm_functor(value), (float)0, thrust::plus()); - THCudaTensor_pow(state, self, self, 1/value); - } -} - -__global__ void THCudaTensor_kernel_renorm(float *data, const float value, const long size, const float maxnorm) -{ - __shared__ float buffer[32]; - long tx = threadIdx.x; - long bx = blockIdx.x; - long step = blockDim.x; - float *row = data + size*bx; - - buffer[tx] = 0; - - // get norm of axis - for (long i=tx; i> 1; stride > 0; stride >>= 1) - { - __syncthreads(); - if (tx < stride) - buffer[tx] += buffer[tx+stride]; - } - // clip norms - __syncthreads(); - float norm = pow(buffer[0], 1/value); - if (norm > maxnorm) - { - norm = maxnorm / (norm + 1e-7); - // renormalize - for (long i=tx; isize[0]; - - THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "invalid dimension"); - THArgCheck(value > 0, 2, "non-positive-norm not supported"); - THArgCheck(THCudaTensor_nDimension(state, src) > 1, 1, "need at least 2 dimensions"); - - dim3 grid(data->size[0]); - dim3 threads(32); - - THCudaTensor_kernel_renorm<<>>(THCudaTensor_data(state, data), value, size, maxnorm); - - cudaError errcode = cudaGetLastError(); - if(errcode != cudaSuccess) - THError(cudaGetErrorString(errcode)); - - THCudaTensor_free(state, src_); - self_ = THCudaTensor_newTranspose(state, data, dimension, 0); - THCudaTensor_resizeAs(state, self, self_); - THCudaTensor_freeCopyTo(state, self_, self); - THCudaTensor_free(state, data); -} - -struct dist_functor -{ - const float exponent; - - dist_functor(float exponent_) : exponent(exponent_) {} - - __host__ __device__ float operator()(const float& x, const float& y) const - { - return pow(fabs(x-y), exponent); - } -}; - -float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value) -{ - self = THCudaTensor_newContiguous(state, self); - long size = THCudaTensor_nElement(state, self); - src = THCudaTensor_newContiguous(state, src); - thrust::device_ptr self_data(THCudaTensor_data(state, self)); - thrust::device_ptr src_data(THCudaTensor_data(state, src)); - - float result = thrust::inner_product(self_data, self_data+size, src_data, (float) 0,thrust::plus(), dist_functor(value)); - - THCudaTensor_free(state, src); - THCudaTensor_free(state, self); - - return pow(result, (float)1.0/value); -} - -void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size) -{ - THCudaTensor_resize(state, r_, size, NULL); - THCudaTensor_uniform(state, r_, 0, 1); -} - -void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size) -{ - THCudaTensor_resize(state, r_, size, NULL); - THCudaTensor_normal(state, r_, 0, 1); -} - 
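/* The index kernels deleted below (this patch moves indexing into the new THCTensorIndex.cu)
 * turn a flat thread index into a storage offset by peeling off one coordinate per dimension
 * (the coeff/leftover loops). A minimal host-side sketch of that idea, for illustration only;
 * flatToOffset is a hypothetical helper and is not part of the patch:
 *
 *   static long flatToOffset(long flat, const long *size, const long *stride, int nDim)
 *   {
 *     long offset = 0;
 *     for (int d = nDim - 1; d >= 0; d--) {
 *       long coord = flat % size[d];   // coordinate along dimension d
 *       offset += coord * stride[d];   // step by that dimension's stride
 *       flat /= size[d];
 *     }
 *     return offset;
 *   }
 */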
-__global__ void THCudaTensor_kernel_indexFill( - float *tensor, long* stride, float *index, long src_nDim, - int dim, long idx_size, long tensor_size, long size_dim, float val -) -{ - int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; - - long flat_size = tensor_size / idx_size; - - if (thread_idx < flat_size) - { - long coeff = 0; - for (int i=0; i dim) - { - coeff = leftover / stride[d]; - leftover -= coeff * stride[d]; - srcIdx += coeff * stride[d]; - } - } - tensor[srcIdx + (int)((index[i])-1)*stride[dim]] = val; - } - } -} - -__global__ void THCudaTensor_kernel_indexCopy( - float *res, float *src, long* res_stride, float *index, - long res_nDim, int dim, long idx_size, long src_size, long size_dim -) -{ - int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; - - long flat_size = src_size / idx_size; - - if (thread_idx < flat_size) - { - long coeff = 0; - for (int i=0; i dim) - { - coeff = leftover / res_stride[d]; - leftover -= coeff * res_stride[d]; - targetIdx += coeff * res_stride[d]; - resIdx += coeff * res_stride[d]; - } - } - res[resIdx + ((int)(index[i])-1)*res_stride[dim]] = src[targetIdx + i*res_stride[dim]]; - } - } -} - -void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim, THLongTensor *indices, THCudaTensor *src) -{ - THCudaTensor *indices_; - long *stride_; - long nIndex = indices->size[0]; - long nRes; - - THArgCheck(indices->nDimension == 1, 3, "expecting vector of indices"); - THArgCheck(dim < src->nDimension, 4, "Indexing dim is out of bounds"); - THArgCheck(src->nDimension > 0, 2, "Source tensor is empty"); - THArgCheck(nIndex == src->size[dim], 4, "length of src.size[dim] is not equal to length of indices"); - - src = THCudaTensor_newContiguous(state, src); - indices_ = THCudaTensor_newWithSize1d(state, nIndex); - THCudaTensor_copyLong(state, indices_, indices); - - nRes = THCudaTensor_nElement(state, res_); - dim3 nthreads(16, 16); - dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); - - THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long))); - THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice)); - - THCudaTensor_kernel_indexCopy<<>>( - THCudaTensor_data(state, res_), THCudaTensor_data(state, src), - stride_, THCudaTensor_data(state, indices_), - res_->nDimension, dim, nIndex, - THCudaTensor_nElement(state, src), res_->size[dim] - ); - - THCudaCheck(cudaFree(stride_)); - THCudaTensor_free(state, indices_); - THCudaTensor_free(state, src); -} - - -void THCudaTensor_indexFill(THCState *state, THCudaTensor *res_, int dim, THLongTensor *indices, float val) -{ - THCudaTensor *indices_; - long *stride_; - long nIndex = indices->size[0]; - long nRes; - - THArgCheck(indices->nDimension == 1, 3, "Index is supposed to be a vector"); - THArgCheck(dim < res_->nDimension,4,"Indexing dim is out of bounds"); - THArgCheck(res_->nDimension > 0, 2, "Source tensor is empty"); - - indices_ = THCudaTensor_newWithSize1d(state, nIndex); - THCudaTensor_copyLong(state, indices_, indices); - - nRes = THCudaTensor_nElement(state, res_) / res_->size[dim] * nIndex; - - - dim3 nthreads(16, 16); - dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); - - THCudaCheck(cudaMalloc((void**)&stride_, res_->nDimension * sizeof(long))); - THCudaCheck(cudaMemcpy(stride_, res_->stride, res_->nDimension * sizeof(long), cudaMemcpyHostToDevice)); - - THCudaTensor_kernel_indexFill<<>>( - THCudaTensor_data(state, res_), stride_, 
THCudaTensor_data(state, indices_), - res_->nDimension, dim, nIndex, nRes, res_->size[dim], val - ); - - THCudaCheck(cudaFree(stride_)); - THCudaTensor_free(state, indices_); -} - -__global__ void THCudaTensor_kernel_indexSelect( - float *tensor, float *src, long* src_stride, float *index, - long src_nDim, int dim, long idx_size, long tensor_size, long size_dim -) -{ - int thread_idx = blockIdx.x * blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x; - - long flat_size = tensor_size / idx_size; - - if (thread_idx < flat_size) - { - long coeff = 0; - for (int i=0; i dim) - { - coeff = leftover / src_stride[d]; - leftover -= coeff * src_stride[d]; - targetIdx += coeff * src_stride[d]; - srcIdx += coeff * src_stride[d]; - } - } - tensor[targetIdx + i*src_stride[dim]] = src[srcIdx + ((int)(index[i])-1)*src_stride[dim]]; - } - } -} - - -void THCudaTensor_indexSelect(THCState *state, THCudaTensor *res_, THCudaTensor *src, int dim, THLongTensor *indices) -{ - THLongStorage *newSize; - THCudaTensor *indices_; - long *stride_; - long nIndex = indices->size[0]; - long nRes; - - THArgCheck(indices->nDimension == 1, 3, "expecting vector of indices"); - THArgCheck(dim < src->nDimension, 4, "Indexing dim is out of bounds"); - THArgCheck(src->nDimension > 0, 2, "Source tensor is empty"); - - newSize = THLongStorage_newWithSize(src->nDimension); - THLongStorage_rawCopy(newSize, src->size); - newSize->data[dim] = nIndex; - THCudaTensor_resize(state, res_, newSize, NULL); - THLongStorage_free(newSize); - - indices_ = THCudaTensor_newWithSize1d(state, nIndex); - THCudaTensor_copyLong(state, indices_, indices); - - nRes = THCudaTensor_nElement(state, res_); - dim3 nthreads(16, 16); - dim3 nblocks(ceil((float)nRes / nIndex / (16*16))); - - THCudaCheck(cudaMalloc((void**)&stride_, src->nDimension * sizeof(long))); - THCudaCheck(cudaMemcpy(stride_, src->stride, src->nDimension * sizeof(long), cudaMemcpyHostToDevice)); - - THCudaTensor_kernel_indexSelect<<>>( - THCudaTensor_data(state, res_), THCudaTensor_data(state, src), - stride_, THCudaTensor_data(state, indices_), - src->nDimension, dim, indices->size[0], nRes, src->size[dim] - ); + THCudaTensor_reduceDim( + state, self, src, + thrust::identity(), thrust::multiplies(), 1.0f, dimension); - THCudaCheck(cudaFree(stride_)); - THCudaTensor_free(state, indices_); + THCudaCheck(cudaGetLastError()); } diff --git a/lib/THC/THCTensorMath.h b/lib/THC/THCTensorMath.h index 2ea21365..8d6984fa 100644 --- a/lib/THC/THCTensorMath.h +++ b/lib/THC/THCTensorMath.h @@ -99,5 +99,12 @@ THC_API void THCudaTensor_indexCopy(THCState *state, THCudaTensor *res_, int dim THC_API void THCudaTensor_indexFill(THCState *state, THCudaTensor *tensor, int dim, THLongTensor *index, float val); THC_API void THCudaTensor_indexSelect(THCState *state, THCudaTensor *tensor, THCudaTensor *src, int dim, THLongTensor *index); +THC_API void THCudaTensor_maskedFill(THCState* state, THCudaTensor *tensor, THCudaTensor *mask, float value); +THC_API void THCudaTensor_maskedCopy(THCState* state, THCudaTensor *tensor, THCudaTensor *mask, THCudaTensor *src); +THC_API void THCudaTensor_maskedSelect(THCState* state, THCudaTensor *tensor, THCudaTensor *src, THCudaTensor *mask); + +THC_API void THCudaTensor_maskedFillByte(THCState* state, THCudaTensor *tensor, THByteTensor *mask, float value); +THC_API void THCudaTensor_maskedCopyByte(THCState* state, THCudaTensor *tensor, THByteTensor *mask, THCudaTensor *src); +THC_API void THCudaTensor_maskedSelectByte(THCState* state, THCudaTensor *tensor, 
THCudaTensor *src, THByteTensor *mask); #endif diff --git a/lib/THC/THCTensorMath2.cu b/lib/THC/THCTensorMath2.cu new file mode 100644 index 00000000..c1c71e91 --- /dev/null +++ b/lib/THC/THCTensorMath2.cu @@ -0,0 +1,569 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +struct TensorPowOp { + TensorPowOp(float v) : val(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = powf(*in, val); + } + + __device__ __forceinline__ void operator()(float* v) { + *v = powf(*v, val); + } + + const float val; +}; + +void THCudaTensor_pow(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + if (self_ == src) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorPowOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src); + + if (!THCudaTensor_pointwiseApply2(state, self_, src, TensorPowOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorTPowOp { + TensorTPowOp(float v) : val(v) {} + + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = powf(val, *in); + } + + __device__ __forceinline__ void operator()(float* v) { + *v = powf(val, *v); + } + + const float val; +}; + +void THCudaTensor_tpow(THCState *state, THCudaTensor *self_, float value, THCudaTensor *src) +{ + if (self_ == src) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorTPowOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src); + + if (!THCudaTensor_pointwiseApply2(state, self_, src, TensorTPowOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorATan2Op { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = atan2f(*a, *b); + } +}; + +void THCudaTensor_atan2(THCState *state, THCudaTensor *self_, THCudaTensor *tx, THCudaTensor *ty) +{ + THArgCheck(THCudaTensor_nElement(state, tx) == + THCudaTensor_nElement(state, ty), 3, "sizes do not match"); + THCudaTensor_resizeAs(state, self_, tx); + + if (!THCudaTensor_pointwiseApply3(state, self_, tx, ty, TensorATan2Op())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorClampOp { + TensorClampOp(float min, float max) : minValue(min), maxValue(max) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = max(min(*in, maxValue), minValue); + } + + __device__ __forceinline__ void operator()(float* v) { + *v = max(min(*v, maxValue), minValue); + } + + const float minValue; + const float maxValue; +}; + +void THCudaTensor_clamp(THCState *state, THCudaTensor *self_, THCudaTensor *src, float min_value, + float max_value) +{ + if (self_ == src) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorClampOp(min_value, max_value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src); + + if (!THCudaTensor_pointwiseApply2(state, self_, src, TensorClampOp(min_value, max_value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorSignOp { + __device__ __forceinline__ void 
operator()(float* out, float* in) { + float orig = *in; + *out = (orig > 0) - (orig < 0); + } + + __device__ __forceinline__ void operator()(float* v) { + float orig = *v; + *v = (orig > 0) - (orig < 0); + } +}; + +void THCudaTensor_sign(THCState *state, THCudaTensor *self_, THCudaTensor *src) +{ + if (self_ == src) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorSignOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src); + + if (!THCudaTensor_pointwiseApply2(state, self_, src, TensorSignOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +float THCudaTensor_meanall(THCState *state, THCudaTensor *self) +{ + THArgCheck(self->nDimension > 0, 1, "empty Tensor"); + return THCudaTensor_sumall(state, self)/THCudaTensor_nElement(state, self); +} + +void +THCudaTensor_mean(THCState *state, THCudaTensor *self, THCudaTensor *src, long dim) +{ + THCudaTensor_sum(state, self, src, dim); + THCudaTensor_div(state, self, self, THCudaTensor_size(state, src, dim)); +} + +struct square_functor +{ + const float mean; + + square_functor(float mean_) : mean(mean_) {} + + __host__ __device__ float operator()(const float& x) const + { + return (x-mean)*(x-mean); + } +}; + +float THCudaTensor_varall(THCState *state, THCudaTensor *self) +{ + self = THCudaTensor_newContiguous(state, self); + long size = THCudaTensor_nElement(state, self); + thrust::device_ptr self_data(THCudaTensor_data(state, self)); + + float mean = THCudaTensor_meanall(state, self); + float result = thrust::transform_reduce(self_data, self_data+size, square_functor(mean), (float)0, thrust::plus()); + + result = result/(THCudaTensor_nElement(state, self)-1); + + THCudaTensor_free(state, self); + return result; +} + +float THCudaTensor_stdall(THCState *state, THCudaTensor *self) +{ + return sqrt(THCudaTensor_varall(state, self)); +} + +// Given the sum of values and the sum of squares, compute the variance or standard deviation. +template +__forceinline__ __device__ float THCudaTensor_computeVar(float sum, float sum2, unsigned row_size) { + if (flag) { + sum /= row_size; + sum2 /= row_size; + sum2 -= sum * sum; + sum2 = (sum2 < 0 ? 0 : sum2); + } + else { + sum /= row_size; + sum2 /= row_size - 1; + sum2 -= ((float)row_size) / ((float)(row_size - 1)) * sum * sum; + sum2 = (sum2 < 0 ? 0 : sum2); + } + if (apply_sqrt) + return sqrt(sum2); + else + return sum2; +} + +/* Compute the variance (or standard deviation) along an outer dimension of a tensor. + * + * - num_orows is the size of the flattened outer dimensions; + * - num_irows is the size of the flattened inner dimensions; + * - row_size is the size of the dimension along which to compute the variance; + * - if flag is set, normalize by `row_size` instead of `row_size - 1` + * - if apply_sqrt is set, compute the standard deviation instead of variance + * + * The dimensions to the outside and inside of the specified dimension are considered as flattened. + * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened + * outer dimensions, which contains several "inner rows"). + * Each thread processes a single inner row at a time. 
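+ * Example: for a 4x5x6 src reduced along dimension 1, num_orows == 4, row_size == 5 and num_irows == 6, so tgt holds 4*6 values.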
+ */ +template +__global__ void THCudaTensor_kernel_varOuterDim(float *tgt, float *src_, unsigned num_orows, unsigned num_irows, unsigned row_size) +{ + for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) { + for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) { + float *src = src_ + orow * row_size * num_irows + irow; + float sum = 0, sum2 = 0; + + for (unsigned col = 0; col < row_size; ++col) { + float val = *src; + sum += val; + sum2 += val * val; + + src += num_irows; + } + + tgt[orow * num_irows + irow] = THCudaTensor_computeVar(sum, sum2, row_size); + } + } +} + +template +__host__ void THCudaTensor_varOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension, int flag) +{ + unsigned ndim = THCudaTensor_nDimension(state, src); + // Treat all outer dimensions (i.e. dim < dimension) as one. + unsigned num_orows = 1; + for (unsigned dim = 0; dim < dimension; dim++) { + num_orows *= THCudaTensor_size(state, src, dim); + } + unsigned row_size = THCudaTensor_size(state, src, dimension); + // Treat all inner dimensions (i.e. dim > dimension) as one. + unsigned num_irows = 1; + for (unsigned dim = dimension + 1; dim < ndim; dim++) { + num_irows *= THCudaTensor_size(state, src, dim); + } + + dim3 threads(min(512, num_irows)); + unsigned maxGridDim = 1024; + dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x))); + + if (flag) { + THCudaTensor_kernel_varOuterDim<<>>( + THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size); + } else { + THCudaTensor_kernel_varOuterDim<<>>( + THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size); + } + cudaError errcode = cudaGetLastError(); + if (errcode != cudaSuccess) { + THError(cudaGetErrorString(errcode)); + } +} + + +/* Compute the variance (or standard deviation) of the innermost dimension of a tensor. + * + * - num_rows is the size of the flattened outer dimensions; + * - row_size is the size of the innermost dimension; + * - if flag is set, normalize by `row_size` instead of `row_size - 1` + * - if apply_sqrt is set, compute the standard deviation instead of variance + * + * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is + * considered as having 'num_rows' rows of size 'row_size'. + * Each thread block processes one or more sets of contiguous rows (processing multiple rows + * per thread block is quicker than processing a single row, especially for short rows). + */ +template +__global__ void THCudaTensor_kernel_varInnermostDim(float *tgt, float *src_, unsigned num_rows, unsigned row_size) +{ + __shared__ float ssum[32][16]; + __shared__ float ssum2[32][16]; + + for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) { + unsigned row = block_row + threadIdx.y; + float sum = 0, sum2 = 0; + if (row < num_rows) { + float *src = src_ + row * row_size; + // Sequential reduction within a thread. + for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) { + float val = src[col]; + sum += val; + sum2 += val * val; + } + } + ssum[threadIdx.y][threadIdx.x] = sum; + ssum2[threadIdx.y][threadIdx.x] = sum2; + __syncthreads(); + + // Reduce intermediate values to single value. 
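+    // Tree reduction in shared memory over the 16 per-row partial sums: the active width halves each step (16 -> 8 -> 4 -> 2), and thread 0 adds the remaining pair afterwards.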
+ for (unsigned s = 8; s > 1; s >>= 1) { + if (row < num_rows && threadIdx.x < s) { + ssum[threadIdx.y][threadIdx.x] += ssum[threadIdx.y][threadIdx.x + s]; + ssum2[threadIdx.y][threadIdx.x] += ssum2[threadIdx.y][threadIdx.x + s]; + } + __syncthreads(); + } + + if (row < num_rows && threadIdx.x == 0) { + sum = ssum[threadIdx.y][0] + ssum[threadIdx.y][1]; + sum2 = ssum2[threadIdx.y][0] + ssum2[threadIdx.y][1]; + tgt[row] = THCudaTensor_computeVar(sum, sum2, row_size); + } + __syncthreads(); + } +} + +template +__host__ void THCudaTensor_varInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, int flag) +{ + unsigned ndim = THCudaTensor_nDimension(state, src); + // Treat all outer dimensions as a single dimension. + unsigned num_rows = 1; + for (unsigned dim = 0; dim < ndim - 1; dim++) { + num_rows *= THCudaTensor_size(state, src, dim); + } + unsigned row_size = THCudaTensor_size(state, src, ndim - 1); + + // From limited testing, 16x32 seemed a good compromise for handling both long and short dimensions. + dim3 threads(16, 32); + dim3 grid(min(1024, DIVUP(num_rows, threads.y))); + + if (flag) { + THCudaTensor_kernel_varInnermostDim<<>>( + THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size); + } else { + THCudaTensor_kernel_varInnermostDim<<>>( + THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size); + } + cudaError errcode = cudaGetLastError(); + if (errcode != cudaSuccess) { + THError(cudaGetErrorString(errcode)); + } +} + +void THCudaTensor_var(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) +{ + THLongStorage *dim = THCudaTensor_newSizeOf(state, src); + THLongStorage_set(dim, dimension, 1); + THCudaTensor_resize(state, self_, dim, NULL); + THLongStorage_free(dim); + + THCudaTensor *self = THCudaTensor_newContiguous(state, self_); + src = THCudaTensor_newContiguous(state, src); + + if (dimension == THCudaTensor_nDimension(state, src) - 1) { + THCudaTensor_varInnermostDim(state, self, src, flag); + } else { + THCudaTensor_varOuterDim(state, self, src, dimension, flag); + } + + THCudaTensor_free(state, src); + THCudaTensor_freeCopyTo(state, self, self_); +} + +void THCudaTensor_std(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, int flag) +{ + THLongStorage *dim = THCudaTensor_newSizeOf(state, src); + THLongStorage_set(dim, dimension, 1); + THCudaTensor_resize(state, self_, dim, NULL); + THLongStorage_free(dim); + + THCudaTensor *self = THCudaTensor_newContiguous(state, self_); + src = THCudaTensor_newContiguous(state, src); + + if (dimension == THCudaTensor_nDimension(state, src) - 1) { + THCudaTensor_varInnermostDim(state, self, src, flag); + } else { + THCudaTensor_varOuterDim(state, self, src, dimension, flag); + } + + THCudaTensor_free(state, src); + THCudaTensor_freeCopyTo(state, self, self_); +} + + +struct norm_functor +{ + const float exponent; + + norm_functor(float exponent_) : exponent(exponent_) {} + + __host__ __device__ float operator()(const float& x) const + { + return pow(fabs(x), exponent); + } +}; + +struct partial_not_equal_functor +{ + const float rhs; + partial_not_equal_functor(float rhs) : rhs(rhs) {} + __host__ __device__ bool operator()(const float &lhs) const {return lhs != rhs;} +}; + +float THCudaTensor_normall(THCState *state, THCudaTensor *self, float value) +{ + self = THCudaTensor_newContiguous(state, self); + long size = THCudaTensor_nElement(state, self); + thrust::device_ptr self_data(THCudaTensor_data(state, self)); + + 
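+  // value == 0 counts the non-zero elements; otherwise the result is (sum_i |x_i|^value)^(1/value).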
float result; + if(value == 0.0f) { + result = thrust::transform_reduce(self_data, self_data+size, partial_not_equal_functor(0.0f), (float)0, thrust::plus()); + } else { + result = thrust::transform_reduce(self_data, self_data+size, norm_functor(value), (float)0, thrust::plus()); + result = pow(result, (float)1.0/value); + } + + THCudaTensor_free(state, self); + return result; +} + +void THCudaTensor_norm(THCState *state, THCudaTensor* self, THCudaTensor* src, float value, long dimension) +{ + if (value == 0.0f) { + THCudaTensor_reduceDim(state, self, src, + partial_not_equal_functor(0.0f), thrust::plus(), + 0.0f, dimension); + } else { + THCudaTensor_reduceDim(state, self, src, + norm_functor(value), thrust::plus(), + 0.0f, dimension); + THCudaTensor_pow(state, self, self, 1/value); + } + + THCudaCheck(cudaGetLastError()); +} + +__global__ void THCudaTensor_kernel_renorm(float *data, const float value, const long size, const float maxnorm) +{ + __shared__ float buffer[32]; + long tx = threadIdx.x; + long bx = blockIdx.x; + long step = blockDim.x; + float *row = data + size*bx; + + buffer[tx] = 0; + + // get norm of axis + for (long i=tx; i> 1; stride > 0; stride >>= 1) + { + __syncthreads(); + if (tx < stride) + buffer[tx] += buffer[tx+stride]; + } + // clip norms + __syncthreads(); + float norm = pow(buffer[0], 1/value); + if (norm > maxnorm) + { + norm = maxnorm / (norm + 1e-7); + // renormalize + for (long i=tx; isize[0]; + + THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "invalid dimension"); + THArgCheck(value > 0, 2, "non-positive-norm not supported"); + THArgCheck(THCudaTensor_nDimension(state, src) > 1, 1, "need at least 2 dimensions"); + + dim3 grid(data->size[0]); + dim3 threads(32); + + THCudaTensor_kernel_renorm<<>>(THCudaTensor_data(state, data), value, size, maxnorm); + + cudaError errcode = cudaGetLastError(); + if(errcode != cudaSuccess) + THError(cudaGetErrorString(errcode)); + + THCudaTensor_free(state, src_); + self_ = THCudaTensor_newTranspose(state, data, dimension, 0); + THCudaTensor_resizeAs(state, self, self_); + THCudaTensor_freeCopyTo(state, self_, self); + THCudaTensor_free(state, data); +} + +struct dist_functor +{ + const float exponent; + + dist_functor(float exponent_) : exponent(exponent_) {} + + __host__ __device__ float operator()(const float& x, const float& y) const + { + return pow(fabs(x-y), exponent); + } +}; + +float THCudaTensor_dist(THCState *state, THCudaTensor *self, THCudaTensor *src, float value) +{ + self = THCudaTensor_newContiguous(state, self); + long size = THCudaTensor_nElement(state, self); + src = THCudaTensor_newContiguous(state, src); + thrust::device_ptr self_data(THCudaTensor_data(state, self)); + thrust::device_ptr src_data(THCudaTensor_data(state, src)); + + float result = thrust::inner_product(self_data, self_data+size, src_data, (float) 0,thrust::plus(), dist_functor(value)); + + THCudaTensor_free(state, src); + THCudaTensor_free(state, self); + + return pow(result, (float)1.0/value); +} + +void THCudaTensor_rand(THCState *state, THCudaTensor *r_, THLongStorage *size) +{ + THCudaTensor_resize(state, r_, size, NULL); + THCudaTensor_uniform(state, r_, 0, 1); +} + +void THCudaTensor_randn(THCState *state, THCudaTensor *r_, THLongStorage *size) +{ + THCudaTensor_resize(state, r_, size, NULL); + THCudaTensor_normal(state, r_, 0, 1); +} diff --git a/lib/THC/THCTensorMathBlas.cu b/lib/THC/THCTensorMathBlas.cu new file mode 100644 index 00000000..2f622ac0 --- /dev/null +++ 
b/lib/THC/THCTensorMathBlas.cu @@ -0,0 +1,380 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +float THCudaTensor_dot(THCState *state, THCudaTensor *self, THCudaTensor *src) +{ + THArgCheck(THCudaTensor_nElement(state, self) == THCudaTensor_nElement(state, src), 2, "sizes do not match"); + + { + self = THCudaTensor_newContiguous(state, self); + src = THCudaTensor_newContiguous(state, src); + + float result = THCudaBlas_dot(state, + THCudaTensor_nElement(state, self), + THCudaTensor_data(state, self), 1, + THCudaTensor_data(state, src), 1); + THCudaTensor_free(state, src); + THCudaTensor_free(state, self); + + return result; + } +} + +void THCudaTensor_addmv(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *mat, THCudaTensor *vec) +{ + if( (mat->nDimension != 2) || (vec->nDimension != 1) ) + THError("matrix and vector expected"); + + if( mat->size[1] != vec->size[0] ) + THError("size mismatch"); + + if(t->nDimension != 1) + THError("size mismatch"); + + if(t->size[0] != mat->size[0]) + THError("size mismatch"); + + if(r_ != t) + { + THCudaTensor_resizeAs(state, r_, t); + THCudaTensor_copy(state, r_, t); + } + + if(mat->stride[0] == 1) + { + THCudaBlas_gemv(state, 'n', mat->size[0], mat->size[1], + alpha, THCudaTensor_data(state, mat), mat->stride[1], + THCudaTensor_data(state, vec), vec->stride[0], + beta, THCudaTensor_data(state, r_), r_->stride[0]); + } + else if(mat->stride[1] == 1) + { + THCudaBlas_gemv(state, 't', mat->size[1], mat->size[0], + alpha, THCudaTensor_data(state, mat), mat->stride[0], + THCudaTensor_data(state, vec), vec->stride[0], + beta, THCudaTensor_data(state, r_), r_->stride[0]); + } + else + { + THCudaTensor *cmat = THCudaTensor_newContiguous(state, mat); + + THCudaBlas_gemv(state, 't', mat->size[1], mat->size[0], + alpha, THCudaTensor_data(state, cmat), cmat->stride[0], + THCudaTensor_data(state, vec), vec->stride[0], + beta, THCudaTensor_data(state, r_), r_->stride[0]); + + THCudaTensor_free(state, cmat); + } +} + +void THCudaTensor_addmm(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *m1, THCudaTensor *m2) +{ + char transpose_r, transpose_m1, transpose_m2; + THCudaTensor *r__, *m1_, *m2_; + + if( (m1->nDimension != 2) || (m2->nDimension != 2) ) + THError("matrix and matrix expected"); + + if(t->nDimension != 2) + THError("size mismatch"); + + if( (t->size[0] != m1->size[0]) || (t->size[1] != m2->size[1]) || (m1->size[1] != m2->size[0]) ) + THError("size mismatch"); + + if(t != r_) + { + THCudaTensor_resizeAs(state, r_, t); + THCudaTensor_copy(state, r_, t); + } + + /* r_ */ + if(r_->stride[0] == 1) + { + transpose_r = 'n'; + r__ = r_; + } + else if(r_->stride[1] == 1) + { + THCudaTensor *swap = m2; + m2 = m1; + m1 = swap; + transpose_r = 't'; + r__ = r_; + } + else + { + transpose_r = 'n'; + + r__ = THCudaTensor_newWithSize2d(state, r_->size[1], r_->size[0]); + THCudaTensor_copy(state, r__, r_); + THCudaTensor_transpose(state, r__, NULL, 0, 1); + } + + /* m1 */ + if(m1->stride[(transpose_r == 'n' ? 0 : 1)] == 1) + { + transpose_m1 = 'n'; + m1_ = m1; + } + else if(m1->stride[(transpose_r == 'n' ? 1 : 0)] == 1) + { + transpose_m1 = 't'; + m1_ = m1; + } + else + { + transpose_m1 = (transpose_r == 'n' ? 
't' : 'n'); + m1_ = THCudaTensor_newContiguous(state, m1); + } + + /* m2 */ + if(m2->stride[(transpose_r == 'n' ? 0 : 1)] == 1) + { + transpose_m2 = 'n'; + m2_ = m2; + } + else if(m2->stride[(transpose_r == 'n' ? 1 : 0)] == 1) + { + transpose_m2 = 't'; + m2_ = m2; + } + else + { + transpose_m2 = (transpose_r == 'n' ? 't' : 'n'); + m2_ = THCudaTensor_newContiguous(state, m2); + } + + /* do the operation */ + THCudaBlas_gemm(state, + transpose_m1, + transpose_m2, + r__->size[(transpose_r == 'n' ? 0 : 1)], + r__->size[(transpose_r == 'n' ? 1 : 0)], + m1_->size[(transpose_r == 'n' ? 1 : 0)], + alpha, + THCudaTensor_data(state, m1_), + (transpose_m1 == 'n' ? m1_->stride[(transpose_r == 'n' ? 1 : 0)] : m1_->stride[(transpose_r == 'n' ? 0 : 1)]), + THCudaTensor_data(state, m2_), + (transpose_m2 == 'n' ? m2_->stride[(transpose_r == 'n' ? 1 : 0)] : m2_->stride[(transpose_r == 'n' ? 0 : 1)]), + beta, + THCudaTensor_data(state, r__), + r__->stride[(transpose_r == 'n' ? 1 : 0)]); + + /* free intermediate variables */ + if(m1_ != m1) + THCudaTensor_free(state, m1_); + + if(m2_ != m2) + THCudaTensor_free(state, m2_); + + if(r__ != r_) + THCudaTensor_freeCopyTo(state, r__, r_); +} + +void THCudaTensor_addr(THCState *state, THCudaTensor *r_, float beta, THCudaTensor *t, float alpha, THCudaTensor *vec1, THCudaTensor *vec2) +{ + if( (vec1->nDimension != 1) || (vec2->nDimension != 1) ) + THError("vector and vector expected"); + + if(t->nDimension != 2) + THError("size mismatch"); + + if( (t->size[0] != vec1->size[0]) || (t->size[1] != vec2->size[0]) ) + THError("size mismatch"); + + if(r_ != t) + { + THCudaTensor_resizeAs(state, r_, t); + THCudaTensor_copy(state, r_, t); + } + + if(beta != 1) + THCudaTensor_mul(state, r_, r_, beta); + + if(r_->stride[0] == 1) + { + THCudaBlas_ger(state, vec1->size[0], vec2->size[0], + alpha, THCudaTensor_data(state, vec1), vec1->stride[0], + THCudaTensor_data(state, vec2), vec2->stride[0], + THCudaTensor_data(state, r_), r_->stride[1]); + } + else if(r_->stride[1] == 1) + { + THCudaBlas_ger(state, vec2->size[0], vec1->size[0], + alpha, THCudaTensor_data(state, vec2), vec2->stride[0], + THCudaTensor_data(state, vec1), vec1->stride[0], + THCudaTensor_data(state, r_), r_->stride[0]); + } + else + { + THCudaTensor *cr = THCudaTensor_newClone(state, r_); + + THCudaBlas_ger(state, vec2->size[0], vec1->size[0], + alpha, THCudaTensor_data(state, vec2), vec2->stride[0], + THCudaTensor_data(state, vec1), vec1->stride[0], + THCudaTensor_data(state, cr), cr->stride[0]); + + THCudaTensor_freeCopyTo(state, cr, r_); + } +} + +void THCudaTensor_baddbmm(THCState *state, THCudaTensor *result, float beta, THCudaTensor *t, + float alpha, THCudaTensor *batch1, THCudaTensor *batch2) { + THArgCheck(THCudaTensor_nDimension(state, t) == 3, 4, "expected 3D tensor"); + THArgCheck(THCudaTensor_nDimension(state, batch1) == 3, 6, "expected 3D tensor"); + THArgCheck(THCudaTensor_nDimension(state, batch2) == 3, 7, "expected 3D tensor"); + THArgCheck(THCudaTensor_size(state, t, 0) == THCudaTensor_size(state, batch1, 0), 6, + "equal number of batches expected"); + THArgCheck(THCudaTensor_size(state, t, 0) == THCudaTensor_size(state, batch2, 0), 7, + "equal number of batches expected"); + THArgCheck(THCudaTensor_size(state, t, 1) == THCudaTensor_size(state, batch1, 1), 6, + "wrong matrix size"); + THArgCheck(THCudaTensor_size(state, t, 2) == THCudaTensor_size(state, batch2, 2), 7, + "wrong matrix size"); + THArgCheck(THCudaTensor_size(state, batch1, 2) == THCudaTensor_size(state, batch2, 1), 6, + "wrong 
matrix size"); + + if (t != result) { + THCudaTensor_resizeAs(state, result, t); + THCudaTensor_copy(state, result, t); + } + + bool transpose_result; + char transpose_batch1, transpose_batch2; + long lda, ldb, ldc; + THCudaTensor *result_, *batch1_, *batch2_; + if (result->stride[1] == 1) + { + transpose_result = false; + result_ = result; + ldc = result_->stride[2]; + } + else if (result->stride[2] == 1) + { + transpose_result = true; + + THCudaTensor *swap = batch2; + batch2 = batch1; + batch1 = swap; + + result_ = result; + ldc = result_->stride[1]; + } + else + { + transpose_result = false; + + result_ = THCudaTensor_newWithSize3d(state, result->size[0], result->size[2], result->size[1]); + THCudaTensor_copy(state, result_, result); + THCudaTensor_transpose(state, result_, NULL, 1, 2); + + ldc = result_->stride[2]; + } + + if (batch1->stride[transpose_result ? 2 : 1] == 1) + { + transpose_batch1 = 'n'; + batch1_ = batch1; + lda = batch1_->stride[transpose_result ? 1 : 2]; + } + else if (batch1->stride[transpose_result ? 1 : 2] == 1) + { + transpose_batch1 = 't'; + batch1_ = batch1; + lda = batch1_->stride[transpose_result ? 2 : 1]; + } + else + { + transpose_batch1 = transpose_result ? 'n' : 't'; + batch1_ = THCudaTensor_newContiguous(state, batch1); + lda = batch1_->stride[1]; + } + + if (batch2->stride[transpose_result ? 2 : 1] == 1) + { + transpose_batch2 = 'n'; + batch2_ = batch2; + ldb = batch2_->stride[transpose_result ? 1 : 2]; + } + else if (batch2->stride[transpose_result ? 1 : 2] == 1) + { + transpose_batch2 = 't'; + batch2_ = batch2; + ldb = batch2_->stride[transpose_result ? 2 : 1]; + } + else + { + transpose_batch2 = transpose_result ? 'n' : 't'; + batch2_ = THCudaTensor_newContiguous(state, batch2); + ldb = batch2_->stride[1]; + } + + // Compute pointers to matrices in each batch. + long num_batches = result_->size[0]; + size_t matrices_size = num_batches * sizeof(float*); + const float **matrices1 = (const float **)THAlloc(matrices_size); + const float **matrices2 = (const float **)THAlloc(matrices_size); + float **result_matrices = (float **)THAlloc(matrices_size); + for (int i = 0; i < num_batches; ++i) + { + matrices1[i] = THCudaTensor_data(state, batch1_) + i * batch1_->stride[0]; + matrices2[i] = THCudaTensor_data(state, batch2_) + i * batch2_->stride[0]; + result_matrices[i] = THCudaTensor_data(state, result_) + i * result_->stride[0]; + } + + // Copy pointers to device. + const float **d_matrices1, **d_matrices2; + float **d_result_matrices; + THCudaCheck(cudaMalloc(&d_matrices1, matrices_size)); + THCudaCheck(cudaMalloc(&d_matrices2, matrices_size)); + THCudaCheck(cudaMalloc(&d_result_matrices, matrices_size)); + + THCudaCheck(cudaMemcpyAsync(d_matrices1, matrices1, matrices_size, cudaMemcpyHostToDevice)); + THCudaCheck(cudaMemcpyAsync(d_matrices2, matrices2, matrices_size, cudaMemcpyHostToDevice)); + THCudaCheck(cudaMemcpyAsync(d_result_matrices, result_matrices, matrices_size, cudaMemcpyHostToDevice)); + + THCudaBlas_gemmBatched( + state, + transpose_batch1, + transpose_batch2, + result_->size[transpose_result ? 2 : 1], + result_->size[transpose_result ? 1 : 2], + batch1_->size[transpose_result ? 
1 : 2], + alpha, + d_matrices1, lda, + d_matrices2, ldb, + beta, + d_result_matrices, ldc, + num_batches); + + cudaFree(d_matrices1); + cudaFree(d_matrices2); + cudaFree(d_result_matrices); + THFree(matrices1); + THFree(matrices2); + THFree(result_matrices); + + if (batch1_ != batch1) + THCudaTensor_free(state, batch1_); + + if (batch2_ != batch2) + THCudaTensor_free(state, batch2_); + + if (result_ != result) + THCudaTensor_freeCopyTo(state, result_, result); +} diff --git a/lib/THC/THCTensorMathCompare.cu b/lib/THC/THCTensorMathCompare.cu new file mode 100644 index 00000000..0e8f6010 --- /dev/null +++ b/lib/THC/THCTensorMathCompare.cu @@ -0,0 +1,113 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +template +void THCudaTensor_logicalValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, Op op) +{ + THCudaTensor_resizeAs(state, self_, src); + + if (!THCudaTensor_pointwiseApply2(state, self_, src, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorLTValueOp { + TensorLTValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in < value); + } + + const float value; +}; + +void THCudaTensor_ltValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorLTValueOp(value)); +} + +struct TensorGTValueOp { + TensorGTValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in > value); + } + + const float value; +}; + +void THCudaTensor_gtValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorGTValueOp(value)); +} + +struct TensorLEValueOp { + TensorLEValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in <= value); + } + + const float value; +}; + +void THCudaTensor_leValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorLEValueOp(value)); +} + +struct TensorGEValueOp { + TensorGEValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in >= value); + } + + const float value; +}; + +void THCudaTensor_geValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorGEValueOp(value)); +} + +struct TensorEQValueOp { + TensorEQValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in == value); + } + + const float value; +}; + +void THCudaTensor_eqValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorEQValueOp(value)); +} + +struct TensorNEValueOp { + TensorNEValueOp(float v) : value(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = (*in != value); + } + + const float value; +}; + +void THCudaTensor_neValue(THCState *state, THCudaTensor *self_, THCudaTensor *src, float value) +{ + THCudaTensor_logicalValue(state, self_, src, TensorNEValueOp(value)); +} diff --git a/lib/THC/THCTensorMathCompareT.cu 
b/lib/THC/THCTensorMathCompareT.cu new file mode 100644 index 00000000..4759b704 --- /dev/null +++ b/lib/THC/THCTensorMathCompareT.cu @@ -0,0 +1,101 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +template +void THCudaTensor_logicalTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2, Op op) +{ + THCudaTensor_resizeAs(state, self_, src1); + THArgCheck(THCudaTensor_nElement(state, src1) == THCudaTensor_nElement(state, src2), 3, "sizes do not match"); + + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, op)) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorLTOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a < *b); + } +}; + +struct TensorGTOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a > *b); + } +}; + +struct TensorLEOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a <= *b); + } +}; + +struct TensorGEOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a >= *b); + } +}; + +struct TensorEQOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a == *b); + } +}; + +struct TensorNEOp { + __device__ __forceinline__ void operator()(float* out, float* a, float* b) { + *out = (float) (*a != *b); + } +}; + +void THCudaTensor_ltTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorLTOp()); +} + + +void THCudaTensor_gtTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorGTOp()); +} + + +void THCudaTensor_leTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorLEOp()); +} + + +void THCudaTensor_geTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorGEOp()); +} + + +void THCudaTensor_eqTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorEQOp()); +} + + +void THCudaTensor_neTensor(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THCudaTensor_logicalTensor(state, self_, src1, src2, TensorNEOp()); +} diff --git a/lib/THC/THCTensorMathPairwise.cu b/lib/THC/THCTensorMathPairwise.cu new file mode 100644 index 00000000..a50f6aab --- /dev/null +++ b/lib/THC/THCTensorMathPairwise.cu @@ -0,0 +1,96 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +struct TensorAddConstantOp { + TensorAddConstantOp(float v) : val(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = *in + val; + } + + __device__ 
__forceinline__ void operator()(float* v) { + *v += val; + } + + const float val; +}; + +void THCudaTensor_add(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) +{ + if (self_ == src_) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorAddConstantOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src_); + + if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorAddConstantOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorMulConstantOp { + TensorMulConstantOp(float v) : val(v) {} + __device__ __forceinline__ void operator()(float* out, float* in) { + *out = *in * val; + } + + __device__ __forceinline__ void operator()(float* v) { + *v *= val; + } + + const float val; +}; + +void THCudaTensor_mul(THCState *state, THCudaTensor *self_, THCudaTensor *src_, float value) +{ + if (self_ == src_) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorMulConstantOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src_); + + if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorMulConstantOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + +void THCudaTensor_div(THCState* state, THCudaTensor *self_, THCudaTensor *src_, float value) +{ + THArgCheck(value != 0.0f, 3, "divide by zero"); + + if (self_ == src_) { + if (!THCudaTensor_pointwiseApply1(state, self_, TensorMulConstantOp(1.0f / value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src_); + + if (!THCudaTensor_pointwiseApply2(state, self_, src_, TensorMulConstantOp(1.0f / value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} diff --git a/lib/THC/THCTensorMathPointwise.cu b/lib/THC/THCTensorMathPointwise.cu new file mode 100644 index 00000000..587f59b2 --- /dev/null +++ b/lib/THC/THCTensorMathPointwise.cu @@ -0,0 +1,157 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +#define IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(NAME, CFUNC) \ + struct Tensor##NAME##Op { \ + __device__ __forceinline__ void operator()(float* out, float* in) const { \ + *out = CFUNC(*in); \ + } \ + \ + __device__ __forceinline__ void operator()(float* v) const { \ + *v = CFUNC(*v); \ + } \ + }; \ + \ + void THCudaTensor_##NAME(THCState* state, THCudaTensor* self_, THCudaTensor* src) { \ + if (self_ == src) { \ + if (!THCudaTensor_pointwiseApply1(state, self_, Tensor##NAME##Op())) { \ + THArgCheck(false, 2, CUTORCH_DIM_WARNING); \ + } \ + } else { \ + THCudaTensor_resizeAs(state, self_, src); \ + \ + if (!THCudaTensor_pointwiseApply2(state, self_, src, Tensor##NAME##Op())) { \ + THArgCheck(false, 2, CUTORCH_DIM_WARNING); \ + } \ + } \ + \ + THCudaCheck(cudaGetLastError()); \ + } + +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log, log) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(log1p, log1p) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(exp, exp) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(cos, cos) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(acos, acos) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(cosh, cosh) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sin, sin) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(asin, asin) 
+IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sinh, sinh) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(tan, tan) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(atan, atan) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(tanh, tanh) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(sqrt, sqrt) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(ceil, ceil) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(floor, floor) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(abs, fabs) +IMPLEMENT_CUDA_TENSOR_BASIC_FUNC(round, roundf) + +#undef IMPLEMENT_CUDA_TENSOR_BASIC_FUNC + +struct TensorAddOp { + __device__ __forceinline__ void operator()(float* out, float* in) { + *out += *in; + } + + __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { + *out = *in1 + *in2; + } +}; + +struct TensorCAddOp { + TensorCAddOp(float v) : val(v) {} + + __device__ __forceinline__ void operator()(float* out, float* in) { + *out += val * *in; + } + + __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { + *out = *in1 + val * *in2; + } + + float val; +}; + +void THCudaTensor_cadd(THCState *state, THCudaTensor *self_, THCudaTensor* src1, float value, THCudaTensor *src2) +{ + THArgCheck(THCudaTensor_nElement(state, src1) == + THCudaTensor_nElement(state, src2), 3, "sizes do not match"); + + if (self_ == src1) { + if (value == 1.0f) { + // self += src2 + if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorAddOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + // self += value * src2 + if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorCAddOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + } else { + THCudaTensor_resizeAs(state, self_, src1); + + if (value == 1.0f) { + // self = src1 + src2 + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorAddOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + // self = src1 + value * src2 + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorCAddOp(value))) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + } + + THCudaCheck(cudaGetLastError()); +} + +struct TensorMulOp { + __device__ __forceinline__ void operator()(float* out, float* in) { + *out *= *in; + } + + __device__ __forceinline__ void operator()(float* out, float* in1, float* in2) { + *out = *in1 * *in2; + } +}; + +void THCudaTensor_cmul(THCState *state, THCudaTensor *self_, THCudaTensor *src1, THCudaTensor *src2) +{ + THArgCheck(THCudaTensor_nElement(state, src1) == + THCudaTensor_nElement(state, src2), 3, "sizes do not match"); + + if (self_ == src1) { + // self *= src2 + if (!THCudaTensor_pointwiseApply2(state, self_, src2, TensorMulOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } else { + THCudaTensor_resizeAs(state, self_, src1); + + // self = src1 * src2 + if (!THCudaTensor_pointwiseApply3(state, self_, src1, src2, TensorMulOp())) { + THArgCheck(false, 2, CUTORCH_DIM_WARNING); + } + } + + THCudaCheck(cudaGetLastError()); +} + diff --git a/lib/THC/THCTensorMathScan.cu b/lib/THC/THCTensorMathScan.cu new file mode 100644 index 00000000..76d6d085 --- /dev/null +++ b/lib/THC/THCTensorMathScan.cu @@ -0,0 +1,213 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +/* Perform an inclusive scan along an outer dimension of a tensor. 
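+ * (cumsum and cumprod along any dimension other than the innermost one are dispatched
+ * here by THCudaTensor_scanDim further down in this file.)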
+ *
+ * - num_orows is the size of the flattened outer dimensions;
+ * - num_irows is the size of the flattened inner dimensions;
+ * - row_size is the size of the dimension along which to scan;
+ *
+ * The dimensions to the outside and inside of the specified dimension are considered as flattened.
+ * Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
+ * outer dimensions, which contains several "inner rows").
+ * Each thread processes a single inner row at a time.
+ */
+template<class BinaryOp>
+__global__ void THCudaTensor_kernel_scanOuterDim(float *tgt_, float *src_,
+                                                 unsigned num_orows, unsigned num_irows, unsigned row_size,
+                                                 float init, BinaryOp binary_op)
+{
+  for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+    for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
+      float *src = src_ + orow * row_size * num_irows + irow;
+      float *tgt = tgt_ + orow * row_size * num_irows + irow;
+      float acc = init;
+
+      for (unsigned col = 0; col < row_size; ++col) {
+        acc = binary_op(acc, *src);
+        *tgt = acc;
+
+        src += num_irows;
+        tgt += num_irows;
+      }
+    }
+  }
+}
+
+template<class BinaryOp>
+__host__ void THCudaTensor_scanOuterDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, long dimension,
+                                        float init, BinaryOp binary_op)
+{
+  unsigned ndim = THCudaTensor_nDimension(state, src);
+  // Treat all outer dimensions (i.e. dim < dimension) as one.
+  unsigned num_orows = 1;
+  for (unsigned dim = 0; dim < dimension; dim++) {
+    num_orows *= THCudaTensor_size(state, src, dim);
+  }
+  unsigned row_size = THCudaTensor_size(state, src, dimension);
+  // Treat all inner dimensions (i.e. dim > dimension) as one.
+  unsigned num_irows = 1;
+  for (unsigned dim = dimension + 1; dim < ndim; dim++) {
+    num_irows *= THCudaTensor_size(state, src, dim);
+  }
+
+  dim3 threads(min(512, num_irows));
+  unsigned maxGridDim = 1024;
+  dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x)));
+
+  THCudaTensor_kernel_scanOuterDim<<<grid, threads>>>(
+      THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_orows, num_irows, row_size, init, binary_op);
+  cudaError errcode = cudaGetLastError();
+  if (errcode != cudaSuccess) {
+    THError(cudaGetErrorString(errcode));
+  }
+}
+
+
+/* Perform an inclusive scan along the innermost dimension of a tensor.
+ *
+ * - num_rows is the size of the flattened outer dimensions;
+ * - row_size is the size of the innermost dimension;
+ *
+ * The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
+ * considered as having 'num_rows' rows of size 'row_size'.
+ * Each thread block processes one or more sets of contiguous rows (processing multiple rows
+ * per thread block is quicker than processing a single row, especially for short rows).
+ */
+template<int num_threads_x, int num_threads_y, class BinaryFunction>
+__global__ void THCudaTensor_kernel_scanInnermostDim(float *tgt_, float *src_,
+                                                     unsigned num_rows, unsigned row_size,
+                                                     float init, BinaryFunction binary_op)
+{
+  __shared__ float sbuf[num_threads_y][2 * num_threads_x];
+
+  float* row_buf = sbuf[threadIdx.y];
+
+  for (unsigned block_row = blockIdx.x * blockDim.y;
+       block_row < num_rows;
+       block_row += blockDim.y * gridDim.x) {
+    unsigned row = block_row + threadIdx.y;
+    float block_total = init;
+
+    float *row_src = src_ + row * row_size;
+    float *row_tgt = tgt_ + row * row_size;
+
+    // Perform scan on one block at a time, keeping track of the total value of
+    // all blocks processed so far.
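+    // Each chunk of 2 * num_threads_x elements is scanned in shared memory with a
+    // work-efficient up-sweep/down-sweep; the running carry of everything scanned so
+    // far (block_total) is folded into the chunk's first element before the sweeps.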
+ for (unsigned block_col = 0; block_col < row_size; block_col += 2 * num_threads_x) { + // Load data into shared memory (two values per thread). + unsigned col1 = block_col + threadIdx.x; + unsigned col2 = block_col + num_threads_x + threadIdx.x; + if (row < num_rows) { + if (col1 < row_size) { + row_buf[threadIdx.x] = row_src[col1]; + } else { + row_buf[threadIdx.x] = init; + } + + if (col2 < row_size) { + row_buf[num_threads_x + threadIdx.x] = row_src[col2]; + } else { + row_buf[num_threads_x + threadIdx.x] = init; + } + + // Add the total value of all previous blocks to the first value of this block. + if (threadIdx.x == 0) { + row_buf[0] = binary_op(row_buf[0], block_total); + } + } + __syncthreads(); + + // Parallel reduction (up-sweep). + for (unsigned s = num_threads_x, d = 1; s >= 1; s >>= 1, d <<= 1) { + if (row < num_rows && threadIdx.x < s) { + unsigned offset = (2 * threadIdx.x + 1) * d - 1; + row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); + } + __syncthreads(); + } + + // Down-sweep. + for (unsigned s = 2, d = num_threads_x / 2; d >= 1; s <<= 1, d >>= 1) { + if (row < num_rows && threadIdx.x < s - 1) { + unsigned offset = 2 * (threadIdx.x + 1) * d - 1; + row_buf[offset + d] = binary_op(row_buf[offset], row_buf[offset + d]); + } + __syncthreads(); + } + + // Write back to output. + if (row < num_rows) { + if (col1 < row_size) row_tgt[col1] = row_buf[threadIdx.x]; + if (col2 < row_size) row_tgt[col2] = row_buf[num_threads_x + threadIdx.x]; + } + block_total = row_buf[2 * num_threads_x - 1]; + __syncthreads(); + } + } +} + +template +__host__ void THCudaTensor_scanInnermostDim(THCState *state, THCudaTensor *tgt, THCudaTensor *src, float init, BinaryFunction binary_op) +{ + unsigned ndim = THCudaTensor_nDimension(state, src); + // Treat all outer dimensions as a single dimension. 
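+  // e.g. a 20 x 30 x 40 tensor scanned along its innermost dimension is handled as
+  // 600 rows of 40 elements each.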
+ unsigned num_rows = 1; + for (unsigned dim = 0; dim < ndim - 1; dim++) { + num_rows *= THCudaTensor_size(state, src, dim); + } + unsigned row_size = THCudaTensor_size(state, src, ndim - 1); + + dim3 threads(16, 32); + dim3 grid(min(1024, DIVUP(num_rows, threads.y))); + + THCudaTensor_kernel_scanInnermostDim<16, 32><<>>( + THCudaTensor_data(state, tgt), THCudaTensor_data(state, src), num_rows, row_size, init, binary_op); + cudaError errcode = cudaGetLastError(); + if (errcode != cudaSuccess) { + THError(cudaGetErrorString(errcode)); + } +} + +template +void THCudaTensor_scanDim(THCState *state, THCudaTensor *self_, THCudaTensor *src, long dimension, float init, BinaryFunction binary_op) +{ + THCudaTensor_resizeAs(state, self_, src); + + THCudaTensor *self = THCudaTensor_newContiguous(state, self_); + src = THCudaTensor_newContiguous(state, src); + + if (dimension == THCudaTensor_nDimension(state, src) - 1) { + THCudaTensor_scanInnermostDim(state, self, src, init, binary_op); + } else { + THCudaTensor_scanOuterDim(state, self, src, dimension, init, binary_op); + } + + THCudaTensor_free(state, src); + THCudaTensor_freeCopyTo(state, self, self_); +} + +void THCudaTensor_cumsum(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) +{ + return THCudaTensor_scanDim(state, self, src, dimension, 0.0f, thrust::plus()); +} + +void THCudaTensor_cumprod(THCState *state, THCudaTensor *self, THCudaTensor *src, long dimension) +{ + return THCudaTensor_scanDim(state, self, src, dimension, 1.0f, thrust::multiplies()); +} diff --git a/lib/THC/THCTensorMathTransformReduce.cu b/lib/THC/THCTensorMathTransformReduce.cu new file mode 100644 index 00000000..a3e294f1 --- /dev/null +++ b/lib/THC/THCTensorMathTransformReduce.cu @@ -0,0 +1,219 @@ +#include "THCTensorMath.h" +#include "THCGeneral.h" +#include "THCBlas.h" +#include "THCTensorCopy.h" +#include "THCTensorRandom.h" +#include "THCApply.cuh" +#include "THCReduce.cuh" + +#include +#include +#include +#include +#include + +#ifndef DIVUP +#define DIVUP(x, y) (((x) + (y) - 1) / (y)) +#endif + +/* A set of reduction kernels that take in binary ops on thrust pairs (of value, index). + These are useful when you not only have to do a reduction, but you might have + to preserve the location of contention (for example min/max operations). + The structure of the kernels follows the structure of the reduction kernels. 
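+   Each reduced slice therefore produces two outputs: the winning value (written to tgt1)
+   and its 1-based index within the reduced dimension (written to tgt2), which is what
+   max() and min() along a dimension return.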
+*/
+template<class BinaryFunction>
+__global__ void THCudaTensor_kernel_transformReduceOuterDimIndex(float *tgt1, float *tgt2,
+                                                                 float *src_,
+                                                                 unsigned num_orows,
+                                                                 unsigned num_irows,
+                                                                 unsigned row_size,
+                                                                 thrust::pair<float, float> init,
+                                                                 BinaryFunction binary_op)
+{
+  for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
+    for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
+      float *src = src_ + orow * row_size * num_irows + irow;
+      thrust::pair<float, float> acc = init;
+
+      for (unsigned col = 0; col < row_size; ++col) {
+        acc = binary_op(thrust::make_pair(*src, col+1), acc); // i+1 for 1-indexing
+        src += num_irows;
+      }
+      tgt1[orow * num_irows + irow] = acc.first;
+      tgt2[orow * num_irows + irow] = acc.second;
+    }
+  }
+}
+
+template<class BinaryFunction>
+__host__ void THCudaTensor_transformReduceOuterDimIndex(THCState *state, THCudaTensor *tgt1, THCudaTensor *tgt2,
+                                                        THCudaTensor *src,
+                                                        long rdim, thrust::pair<float, float> init,
+                                                        BinaryFunction binary_op)
+{
+  unsigned ndim = THCudaTensor_nDimension(state, src);
+  unsigned num_orows = 1;
+  for (unsigned dim = 0; dim < rdim; dim++) {
+    num_orows *= THCudaTensor_size(state, src, dim);
+  }
+  unsigned row_size = THCudaTensor_size(state, src, rdim);
+  unsigned num_irows = 1;
+  for (unsigned dim = rdim + 1; dim < ndim; dim++) {
+    num_irows *= THCudaTensor_size(state, src, dim);
+  }
+
+  dim3 threads(min(512, num_irows));
+  unsigned maxGridDim = 1024;
+  dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x)));
+
+  THCudaTensor_kernel_transformReduceOuterDimIndex<<<grid, threads>>>(
+      THCudaTensor_data(state, tgt1), THCudaTensor_data(state, tgt2),
+      THCudaTensor_data(state, src), num_orows, num_irows, row_size, init, binary_op);
+  cudaError errcode = cudaGetLastError();
+  if(errcode != cudaSuccess) {
+    THError(cudaGetErrorString(errcode));
+  }
+}
+
+/* Reduce the innermost dimension of a tensor (on thrust::pair functors which are (value, index))
+ *
+ * For an n-d tensor (n <= 4) where the reduction is along the innermost dimension:
+ *
+ * - block.x is the innermost dimension, i.e. dimension 0;
+ * - block.y and grid.y make up dimension 1; and
+ * - grid.x and grid.z are the remaining two outer dimensions (if any)
+ *
+ * Reduction along other dimensions is handled in a separate kernel.
+ */
+template<class BinaryFunction>
+__global__ void THCudaTensor_kernel_transformReduceInnermostDimIndex(
+  float *tgt1, float* tgt2, float *src_,
+  unsigned num_rows, unsigned row_size,
+  thrust::pair<float, float> init, BinaryFunction binary_op)
+{
+  __shared__ float sbuf[32][16];
+  __shared__ float ibuf[32][16];
+
+  for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) {
+    unsigned row = block_row + threadIdx.y;
+    thrust::pair<float, float> acc = init;
+    if (row < num_rows) {
+      float *src = src_ + row * row_size;
+      // Sequential reduction within a thread.
+      for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
+        acc = binary_op(thrust::make_pair(src[col], col+1), acc);
+      }
+    }
+
+    sbuf[threadIdx.y][threadIdx.x] = acc.first;
+    ibuf[threadIdx.y][threadIdx.x] = acc.second;
+
+    // Reduce intermediate values to single value.
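+    // Tree reduction over the 16 threadIdx.x lanes: the values and their indices sit in
+    // separate shared-memory buffers but are always combined as (value, index) pairs, so
+    // the index of the winning element survives the reduction.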
+    float* sline = &sbuf[threadIdx.y][0];
+    float* iline = &ibuf[threadIdx.y][0];
+    for (unsigned s = 8; s > 0; s >>= 1) {
+      if (row < num_rows && threadIdx.x < s) {
+        thrust::pair<float, float> arg1 = thrust::make_pair(sline[threadIdx.x], iline[threadIdx.x]);
+        thrust::pair<float, float> arg2 = thrust::make_pair(sline[threadIdx.x + s], iline[threadIdx.x + s]);
+        thrust::pair<float, float> res = binary_op(arg1, arg2);
+        sline[threadIdx.x] = res.first;
+        iline[threadIdx.x] = res.second;
+      }
+      __syncthreads();
+    }
+
+    if (row < num_rows && threadIdx.x == 0) {
+      tgt1[row] = sline[0];
+      tgt2[row] = iline[0];
+    }
+    __syncthreads();
+  }
+}
+
+template<class BinaryFunction>
+__host__ void THCudaTensor_transformReduceInnermostDimIndex(
+  THCState *state, THCudaTensor *tgt1, THCudaTensor *tgt2, THCudaTensor *src,
+  thrust::pair<float, float> init, BinaryFunction binary_op)
+{
+  unsigned ndim = THCudaTensor_nDimension(state, src);
+  unsigned num_rows = 1;
+  for (unsigned dim = 0; dim < ndim - 1; dim++) {
+    num_rows *= THCudaTensor_size(state, src, dim);
+  }
+  unsigned row_size = THCudaTensor_size(state, src, ndim - 1);
+
+  dim3 threads(16, 32);
+  dim3 grid(min(1024, DIVUP(num_rows, threads.y)));
+
+  THCudaTensor_kernel_transformReduceInnermostDimIndex<<<grid, threads>>>(
+      THCudaTensor_data(state, tgt1), THCudaTensor_data(state, tgt2),
+      THCudaTensor_data(state, src), num_rows, row_size, init, binary_op);
+  cudaError errcode = cudaGetLastError();
+  if(errcode != cudaSuccess) {
+    THError(cudaGetErrorString(errcode));
+  }
+}
+
+template<class BinaryFunction>
+void THCudaTensor_reduceDimIndex(THCState *state, THCudaTensor *tgt1_, THCudaTensor *tgt2_, THCudaTensor *src,
+                                 long dimension, thrust::pair<float, float> init,
+                                 BinaryFunction binary_op)
+{
+  THArgCheck(dimension >= 0 && dimension < THCudaTensor_nDimension(state, src), 3, "dimension out of range");
+
+  THLongStorage *dim = THCudaTensor_newSizeOf(state, src);
+  THLongStorage_set(dim, dimension, 1);
+  THCudaTensor_resize(state, tgt1_, dim, NULL);
+  THCudaTensor_resize(state, tgt2_, dim, NULL);
+  THLongStorage_free(dim);
+
+  THCudaTensor *tgt1 = THCudaTensor_newContiguous(state, tgt1_);
+  THCudaTensor *tgt2 = THCudaTensor_newContiguous(state, tgt2_);
+  src = THCudaTensor_newContiguous(state, src);
+
+  if(dimension == THCudaTensor_nDimension(state, src)-1) {
+    THCudaTensor_transformReduceInnermostDimIndex(state, tgt1, tgt2, src, init, binary_op);
+  } else {
+    THCudaTensor_transformReduceOuterDimIndex(state, tgt1, tgt2, src, dimension, init, binary_op);
+  }
+
+  THCudaTensor_free(state, src);
+  THCudaTensor_freeCopyTo(state, tgt1, tgt1_);
+  THCudaTensor_freeCopyTo(state, tgt2, tgt2_);
+}
+
+struct maxvalue_functor
+{
+  __host__ __device__ thrust::pair<float, float> operator()(const thrust::pair<float, float> &a,
+                                                            const thrust::pair<float, float> &b)
+  {
+    if (a.first > b.first) return a;
+    else return b;
+  }
+};
+
+void THCudaTensor_max(THCState *state, THCudaTensor *values, THCudaTensor *indices, THCudaTensor *src, long dimension)
+{
+  const float minfloat32 = -3.402823466e+38f;
+  thrust::pair<float, float> init = thrust::make_pair(minfloat32, -1);
+  return THCudaTensor_reduceDimIndex(state, values, indices, src, dimension, init,
+                                     maxvalue_functor());
+}
+
+struct minvalue_functor
+{
+  __host__ __device__ thrust::pair<float, float> operator()(const thrust::pair<float, float> &a,
+                                                            const thrust::pair<float, float> &b)
+  {
+    if (a.first < b.first) return a;
+    else return b;
+  }
+};
+
+void THCudaTensor_min(THCState *state, THCudaTensor *values, THCudaTensor *indices, THCudaTensor *src, long dimension)
+{
+  const float maxfloat32 = 3.402823466e+38f;
+  thrust::pair<float, float> init = thrust::make_pair(maxfloat32, -1);
+  return THCudaTensor_reduceDimIndex(state, values, indices, src,
dimension, init, + minvalue_functor()); +} diff --git a/rocks/cutorch-scm-1.rockspec b/rocks/cutorch-scm-1.rockspec index ef8607ba..cf1a268f 100644 --- a/rocks/cutorch-scm-1.rockspec +++ b/rocks/cutorch-scm-1.rockspec @@ -20,7 +20,7 @@ dependencies = { build = { type = "command", build_command = [[ -cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) +cmake -E make_directory build && cd build && cmake .. -DCMAKE_BUILD_TYPE=Release -DCMAKE_PREFIX_PATH="$(LUA_BINDIR)/.." -DCMAKE_INSTALL_PREFIX="$(PREFIX)" && $(MAKE) -j$(getconf _NPROCESSORS_ONLN) install ]], - install_command = "cd build && $(MAKE) install" + install_command = "cd build" } diff --git a/test/test.lua b/test/test.lua index 59e27925..578f37af 100644 --- a/test/test.lua +++ b/test/test.lua @@ -3,7 +3,7 @@ if not cutorch then require 'cutorch' runtests = true end -local tester + local test = {} local msize = 100 local minsize = 100 @@ -35,6 +35,8 @@ local function isEqual(a, b, tolerance, ...) end local function compareFloatAndCuda(x, fn, ...) + local args = {...} + args['input'] = x local x_cpu = x:float() local x_cuda = x_cpu:cuda() local res1_cpu, res2_cpu, res3_cpu, res4_cpu @@ -51,14 +53,30 @@ local function compareFloatAndCuda(x, fn, ...) error("Incorrect function type") end local tolerance = 1e-5 - tester:assert(isEqual(res1_cpu, res1_cuda, tolerance), - string.format("Divergent results between CPU and CUDA for function '%s'", tostring(fn))) - tester:assert(isEqual(res2_cpu, res2_cuda, tolerance), - string.format("Divergent results between CPU and CUDA for function '%s'", tostring(fn))) - tester:assert(isEqual(res3_cpu, res3_cuda, tolerance), - string.format("Divergent results between CPU and CUDA for function '%s'", tostring(fn))) - tester:assert(isEqual(res4_cpu, res4_cuda, tolerance), - string.format("Divergent results between CPU and CUDA for function '%s'", tostring(fn))) + if not isEqual(res1_cpu, res1_cuda, tolerance) then + print(args) + tester:assert(false, + string.format("Divergent results between CPU and CUDA" .. + " for function '%s' (return value 1)", tostring(fn))) + end + if not isEqual(res2_cpu, res2_cuda, tolerance) then + print(args) + tester:assert(false, + string.format("Divergent results between CPU and CUDA" .. + " for function '%s' (return value 2)", tostring(fn))) + end + if not isEqual(res3_cpu, res3_cuda, tolerance) then + print(args) + tester:assert(false, + string.format("Divergent results between CPU and CUDA" .. + " for function '%s' (return value 3)", tostring(fn))) + end + if not isEqual(res4_cpu, res4_cuda, tolerance) then + print(args) + tester:assert(false, + string.format("Divergent results between CPU and CUDA" .. + " for function '%s' (return value 4)", tostring(fn))) + end end local function compareFloatAndCudaTensorArgs(x, fn, ...) 
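For context, a hedged sketch of how these helpers are driven from an individual test case; the test name and sizes below are illustrative, not part of this patch:

   function test.add_scalar_example()
      local sz1 = math.floor(torch.uniform(minsize, maxsize))
      local sz2 = math.floor(torch.uniform(minsize, maxsize))
      local x = torch.randn(sz1, sz2):float()
      -- runs x:add(0.5) on the FloatTensor and on its CUDA copy, then asserts the results match
      compareFloatAndCuda(x, 'add', 0.5)
      -- same idea, but tensor arguments are converted to CUDA as well before the call
      compareFloatAndCudaTensorArgs(x, 'cmul', torch.randn(sz1, sz2):float())
   end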
@@ -1229,7 +1247,7 @@ function test.multi_gpu_random() for i = 2, device_count do cutorch.setDevice(i) local actual = torch.CudaTensor(n):uniform():float() - assert(isEqual(expected, actual), "random tensors dont seem to be equal") + tester:assert(isEqual(expected, actual), "random tensors dont seem to be equal") end cutorch.setRNGState(rs) -- cleanup after yourself end @@ -1242,14 +1260,55 @@ function test.get_device() end -- Unallocated tensors are on device 0 for i = 1,device_count do - assert(tensors[i]:getDevice() == 0, "unallocated tensor does not have deviceID 0") - -- Now allocate it - cutorch.setDevice(i) - tensors[i]:resize(1, 2, 3) - assert(tensors[i]:getDevice() == i, "tensor does not have the correct deviceID") + tester:assert(tensors[i]:getDevice() == 0, "unallocated tensor does not have deviceID 0") + -- Now allocate it + cutorch.setDevice(i) + tensors[i]:resize(1, 2, 3) + tester:assert(tensors[i]:getDevice() == i, "tensor does not have the correct deviceID") end end +function test.multi_gpu_copy_noncontig() + local srcDevice = 1 + local dstDevice = cutorch.getDeviceCount() + + local t1, t2 + for transposeSrc = 0,1 do + for transposeDst = 0,1 do + cutorch.withDevice(srcDevice, + function() t1 = torch.CudaTensor(100000, 1000):fill(1) end) + cutorch.withDevice(dstDevice, + function() t2 = torch.CudaTensor(100000, 1000):fill(2) end) + + if transposeSrc == 1 then -- maybe make t1 non-contiguous + cutorch.withDevice(srcDevice, function() t1=t1:transpose(1,2) end) + end + if transposeDst == 1 then -- maybe make t2 non-contiguous + cutorch.withDevice(dstDevice, function() t2=t2:transpose(1,2) end) + end + cutorch.synchronize() + + -- try to induce a race on t2 + cutorch.withDevice(dstDevice, function() t2:fill(3) end) + + -- perform the copy + -- CudaTensor:copy() should not depend on the current device + t2:copy(t1) + + -- try to induce a race on t1 + cutorch.withDevice(srcDevice, function() t1:fill(4) end) + + -- only synchronize with dstDevice because + -- previous line guarantees synchronization with srcDevice + cutorch.withDevice(dstDevice, function() cutorch.synchronize() end) + + local t2_max = t2:max() + tester:assert(t2_max == 1, "bad copy, transposeSrc= " .. transposeSrc .. + " transposeDst= " .. transposeDst .. ". t2:max() = " .. 
t2_max) + end + end +end + function test.reset_device() local sz = math.floor(torch.uniform(minsize,maxsize)) @@ -1268,6 +1327,179 @@ function test.reset_device() tester:assertTensorEq(tf, u:float(), 1e-6, "values not equal after restoring the RNG state") end +function test.maskedSelect() + local n_row = math.random(minsize,maxsize) + local n_col = math.random(minsize,maxsize) + + -- contiguous, no result tensor, cuda mask + local x = torch.randn(n_row, n_col):float() + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + local y = x:maskedSelect(mask) + x=x:cuda() + mask=mask:cuda() + local y_cuda = x:maskedSelect(mask) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedSelect") + + -- non-contiguous, no result tensor, cuda mask + local x = torch.randn(n_row, n_col):float() + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + local y = x:t():maskedSelect(mask) + x=x:cuda() + mask=mask:cuda() + local y_cuda = x:t():maskedSelect(mask) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedSelect non-contiguous") + + -- contiguous, with result tensor, cuda mask + local x = torch.randn(n_row, n_col):float() + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + local y = torch.FloatTensor() + y:maskedSelect(x, mask) + x=x:cuda() + mask=mask:cuda() + local y_cuda = torch.CudaTensor() + y_cuda:maskedSelect(x, mask) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedSelect (with result)") + + -- non-contiguous, with result tensor, cuda mask + local x = torch.randn(n_row, n_col):float() + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + local y = torch.FloatTensor() + y:maskedSelect(x:t(), mask) + x=x:cuda() + mask=mask:cuda() + local y_cuda = torch.CudaTensor() + y_cuda:maskedSelect(x:t(), mask) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, + "Error in maskedSelect non-contiguous (with result)") + + -- indexing maskedSelect a[a:gt(0.5)] for example + local x = torch.randn(n_row, n_col):float() + local y = x[x:gt(0.5)] + x=x:cuda() + mask=mask:cuda() + local y_cuda = x[x:gt(0.5)] + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedSelect indexing x[x:gt(y)]") + + -- indexing maskedSelect (non-contiguous) a[a:gt(0.5)] for example + local x = torch.randn(n_row, n_col):float() + local y = x:t()[x:t():gt(0.5)] + x=x:cuda() + mask=mask:cuda() + local y_cuda = x:t()[x:t():gt(0.5)] + tester:assertTensorEq(y, y_cuda:float(), 0.00001, + "Error in maskedSelect indexing (non-contiguous) x[x:gt(y)]") +end + +--[[ +waiting on clarification for: https://github.com/torch/torch7/pull/187 +function test.maskedCopy() + local n_row = math.random(minsize,maxsize) + local n_col = math.random(minsize,maxsize) + + -- contiguous, cuda mask + local x = torch.randn(n_row, n_col):float() + local y = x:clone():fill(-1) + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + y:maskedCopy(mask, x:clone()) + local y_cuda=x:cuda():fill(-1) + mask=mask:cuda() + x=x:cuda() + y_cuda:maskedCopy(mask, x) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedCopy (contiguous)") + -- non-contiguous source, cuda mask + local x = torch.randn(n_row, n_col):float() + local y = x:clone():fill(-1) + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + y:maskedCopy(mask, x:t()) + local y_cuda=x:cuda():fill(-1) + x=x:cuda() + mask=mask:cuda() + y_cuda:maskedCopy(mask, x:t()) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedCopy (non-contiguous source)") + + -- non-contiguous result, cuda mask + local x = 
torch.randn(n_row, n_col):float() + local y = x:clone():fill(-1) + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + y:t():maskedCopy(mask, x:t()) + local y_cuda=x:cuda():fill(-1) + x=x:cuda() + mask=mask:cuda() + y_cuda:t():maskedCopy(mask, x:t()) + tester:assertTensorEq(y, y_cuda:float(), 0.00001, "Error in maskedCopy (non-contiguous dest)") + + -- indexing maskedCopy a[a:gt(0.5)] for example + local gt = torch.randn(n_row, n_col):float() + local x = gt:clone() + local y = torch.randn(n_row, n_col):float() + x[x:gt(0.5)] = y + local x_cuda = gt:cuda() + y=y:cuda() + x_cuda[x_cuda:gt(0.5)] = y + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedCopy indexing x[x:gt(y)]") + + -- indexing maskedCopy non-contiguous src a[a:gt(0.5)] for example + local gt = torch.randn(n_row, n_col):float() + local x = gt:clone() + local y = torch.randn(n_row, n_col):float() + x[x:gt(0.5)] = y:t() + local x_cuda = gt:cuda() + y=y:cuda() + x_cuda[x_cuda:gt(0.5)] = y:t() + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedCopy indexing x[x:gt(y)]") + + -- indexing maskedCopy non-contiguous dst a[a:gt(0.5)] for example + local gt = torch.randn(n_row, n_col):float() + local x = gt:clone() + local y = torch.randn(n_row, n_col):float() + x:t()[x:t():gt(0.5)] = y + local x_cuda = gt:cuda() + y=y:cuda() + x_cuda:t()[x_cuda:t():gt(0.5)] = y:t() + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedCopy indexing x[x:gt(y)]") +end +]]-- + +function test.maskedFill() + local n_row = math.random(minsize,maxsize) + local n_col = math.random(minsize,maxsize) + + -- contiguous, no result tensor, cuda mask + local gt = torch.randn(n_row, n_col):float() + local x = gt:clone() + local mask = torch.ByteTensor(n_row,n_col):bernoulli() + x:maskedFill(mask, 334) + local x_cuda=gt:cuda() + mask=mask:cuda() + x_cuda:maskedFill(mask, 334) + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedFill") + + -- non-contiguous, no result tensor, cuda mask + local x = gt:clone() + mask = mask:byte() + x:t():maskedFill(mask, 334) + local x_cuda = gt:cuda() + mask=mask:cuda() + x_cuda:t():maskedFill(mask, 334) + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedFill non-contiguous") + + -- indexing maskedFill a[a:gt(0.5)] for example + local x = gt:clone() + x[x:gt(0.5)] = 334 + local x_cuda = gt:cuda() + x_cuda[x_cuda:gt(0.5)] = 334 + tester:assertTensorEq(x, x_cuda:float(), 0.00001, "Error in maskedFill indexing x[x:gt(y)]") + + -- indexing maskedFill a[a:gt(0.5)] for example + local x = gt:clone() + x:t()[x:t():gt(0.5)] = 334 + local x_cuda = gt:cuda() + x_cuda:t()[x_cuda:t():gt(0.5)] = 334 + tester:assertTensorEq(x, x_cuda:float(), 0.00001, + "Error in maskedFill non-contiguous indexing x[x:gt(y)]") + +end + function cutorch.test(tests) math.randomseed(os.time()) torch.manualSeed(os.time()) @@ -1275,12 +1507,13 @@ function cutorch.test(tests) tester = torch.Tester() tester:add(test) tester:run(tests) - print '' - for module,tm in pairs(times) do - print(module .. ': \t average speedup is ' .. (tm.cpu / (tm.gpu or 1e6))) - end + -- print '' + -- for module,tm in pairs(times) do + -- print(module .. ': \t average speedup is ' .. 
(tm.cpu / (tm.gpu or 1e6))) + -- end end if runtests then cutorch.test() end +return test diff --git a/torch/generic/Tensor.c b/torch/generic/Tensor.c index d5f70475..228cb283 100644 --- a/torch/generic/Tensor.c +++ b/torch/generic/Tensor.c @@ -137,7 +137,7 @@ static int torch_Tensor_(new)(lua_State *L) THStorage_(set)(state, THTensor_(storage)(state, tensor), si++, (real)lua_tonumber(L, -1)); lua_pop(L, 1); } - + if(size->size == 1) break; @@ -191,7 +191,7 @@ static int torch_Tensor_(new)(lua_State *L) torch_Tensor_(c_readTensorStorageSizeStride)(L, 1, 1, 1, 1, 1, &storage, &storageOffset, &size, &stride); - + tensor = THTensor_(newWithStorage)(state, storage, storageOffset, size, stride); THLongStorage_free(size); @@ -308,7 +308,7 @@ static int torch_Tensor_(sub)(lua_State *L) d1e += tensor->size[1]+1; luaL_argcheck(L, tensor->nDimension > 1, 4, "invalid dimension"); luaL_argcheck(L, d1s >= 0 && d1s < tensor->size[1], 4, "out of range"); - luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range"); + luaL_argcheck(L, d1e >= 0 && d1e < tensor->size[1], 5, "out of range"); luaL_argcheck(L, d1e >= d1s, 5, "end smaller than beginning"); if(!lua_isnone(L, 6)) @@ -321,7 +321,7 @@ static int torch_Tensor_(sub)(lua_State *L) d2e += tensor->size[2]+1; luaL_argcheck(L, tensor->nDimension > 2, 6, "invalid dimension"); luaL_argcheck(L, d2s >= 0 && d2s < tensor->size[2], 6, "out of range"); - luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range"); + luaL_argcheck(L, d2e >= 0 && d2e < tensor->size[2], 7, "out of range"); luaL_argcheck(L, d2e >= d2s, 7, "end smaller than beginning"); if(!lua_isnone(L, 8)) @@ -334,7 +334,7 @@ static int torch_Tensor_(sub)(lua_State *L) d3e += tensor->size[3]+1; luaL_argcheck(L, tensor->nDimension > 3, 8, "invalid dimension"); luaL_argcheck(L, d3s >= 0 && d3s < tensor->size[3], 8, "out of range"); - luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range"); + luaL_argcheck(L, d3e >= 0 && d3e < tensor->size[3], 9, "out of range"); luaL_argcheck(L, d3e >= d3s, 9, "end smaller than beginning"); } } @@ -568,6 +568,7 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THLongStorage *idx = NULL; THByteTensor *mask; + THCudaTensor *maskCuda; if(lua_isnumber(L, 2)) { @@ -740,11 +741,27 @@ static int torch_Tensor_(__newindex__)(lua_State *L) THTensor *vals; if (lua_isnumber(L, 3)) { - THTensor_(maskedFill)(state, tensor, mask, (real)(luaL_checknumber(L,3))); + THTensor_(maskedFillByte)(state, tensor, mask, (real)(luaL_checknumber(L,3))); } else if((vals = luaT_toudata(L, 3, torch_Tensor))) { - THTensor_(maskedCopy)(state, tensor, mask, vals); + THTensor_(maskedCopyByte)(state, tensor, mask, vals); + } + else + { + luaL_error(L,"number or tensor expected"); + } + } + else if((maskCuda = luaT_toudata(L, 2, "torch.CudaTensor"))) + { + THTensor *vals; + if (lua_isnumber(L, 3)) + { + THTensor_(maskedFill)(state, tensor, maskCuda, (real)(luaL_checknumber(L,3))); + } + else if((vals = luaT_toudata(L, 3, torch_Tensor))) + { + THTensor_(maskedCopy)(state, tensor, maskCuda, vals); } else { @@ -763,6 +780,7 @@ static int torch_Tensor_(__index__)(lua_State *L) THTensor *tensor = luaT_checkudata(L, 1, torch_Tensor); THLongStorage *idx = NULL; THByteTensor *mask; + THCudaTensor *maskCuda; if(lua_isnumber(L, 2)) { @@ -871,7 +889,15 @@ static int torch_Tensor_(__index__)(lua_State *L) else if((mask = luaT_toudata(L, 2, "torch.ByteTensor"))) { THTensor *vals = THTensor_(new)(state); - 
THTensor_(maskedSelect)(state, vals, tensor, mask); + THTensor_(maskedSelectByte)(state, vals, tensor, mask); + luaT_pushudata(L, vals, torch_Tensor); + lua_pushboolean(L, 1); + return 2; + } + else if((maskCuda = luaT_toudata(L, 2, "torch.CudaTensor"))) + { + THTensor *vals = THTensor_(new)(state); + THTensor_(maskedSelect)(state, vals, tensor, maskCuda); luaT_pushudata(L, vals, torch_Tensor); lua_pushboolean(L, 1); return 2; @@ -895,7 +921,7 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt { THLongStorage *size = NULL; THLongStorage *stride = NULL; - + if( (size = luaT_toudata(L, index, "torch.LongStorage")) ) { if(!lua_isnoneornil(L, index+1)) @@ -925,7 +951,7 @@ static void torch_Tensor_(c_readSizeStride)(lua_State *L, int index, int allowSt if(lua_isnone(L, index+2*i)) break; size->data[i] = luaL_checklong(L, index+2*i); - + if(lua_isnone(L, index+2*i+1)) break; stride->data[i] = luaL_checklong(L, index+2*i+1);
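Taken together, the tensor-math surface touched by this patch can be exercised from Lua roughly as follows. This is a sketch, assuming the standard Torch method bindings for these operations; the sizes and the 0.5 threshold are illustrative:

   local x = torch.CudaTensor(100, 100):uniform()
   local mask = x:gt(0.5)               -- comparison ops (lt/gt/le/ge/eq/ne) build a 0/1 mask on the GPU
   local picked = x:maskedSelect(mask)  -- masked selection with a CUDA mask; x[x:gt(0.5)] is the indexing form
   x:maskedFill(mask, 0)                -- masked fill with a CUDA mask
   x:add(1):mul(0.5)                    -- pointwise scalar ops go through the new apply kernels
   local cs = x:cumsum(2)               -- inclusive scan along dimension 2
   local vals, idx = x:max(2)           -- per-row maxima plus 1-based argmax indices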