Single-dimension std and var. #96

Merged
1 commit merged on Jan 6, 2015
7 changes: 6 additions & 1 deletion TensorMath.lua
@@ -585,7 +585,12 @@ for _,name in ipairs({"var", "std"}) do
wrap(name,
cname(name .. "all"),
{{name=Tensor},
- {name=real, creturned=true}})
+ {name=real, creturned=true}},
+ cname(name),
+ {{name=Tensor, default=true, returned=true},
+ {name=Tensor},
+ {name="index"},
+ {name="boolean", default=false}})
end

wrap("norm",
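With these cwrap entries, `var` and `std` gain a per-dimension form from Lua, mirroring the CPU API: an optional result tensor, the source tensor, the dimension index, and a boolean selecting the biased estimator (default false). A minimal usage sketch, assuming cutorch is installed (sizes are arbitrary):

local x = torch.randn(4, 5):cuda()

local v  = x:var(1)        -- unbiased variance along dim 1; result size 1x5
local vb = x:var(1, true)  -- biased variance (normalize by n instead of n - 1)
local s  = x:std(2)        -- unbiased standard deviation along dim 2; result size 4x1

As with CPU tensors, the reduced dimension is kept with size 1; the C entry points below resize the result that way explicitly.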
208 changes: 208 additions & 0 deletions lib/THC/THCTensorMath.cu
@@ -1190,6 +1190,214 @@ float THCudaTensor_stdall(THCudaTensor *self)
return sqrt(THCudaTensor_varall(self));
}

// Given the sum of values and the sum of squares, compute the variance or standard deviation.
template<bool flag, bool apply_sqrt>
__forceinline__ __device__ float THCudaTensor_computeVar(float sum, float sum2, unsigned row_size) {
if (flag) {
sum /= row_size;
sum2 /= row_size;
sum2 -= sum * sum;
sum2 = (sum2 < 0 ? 0 : sum2);
}
else {
sum /= row_size;
sum2 /= row_size - 1;
sum2 -= ((float)row_size) / ((float)(row_size - 1)) * sum * sum;
sum2 = (sum2 < 0 ? 0 : sum2);
}
if (apply_sqrt)
return sqrt(sum2);
else
return sum2;
}
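`THCudaTensor_computeVar` turns a single pass's running sum and sum of squares into a variance: with `flag` set it uses the biased estimator E[x^2] - E[x]^2, otherwise the unbiased (sum2 - n*mean^2)/(n - 1), clamping at zero to absorb floating-point cancellation. A quick CPU illustration of the same two formulas in Lua (data chosen arbitrarily):

local t = torch.FloatTensor({2, 4, 4, 4, 5, 5, 7, 9})
local n = t:nElement()
local sum  = t:sum()
local sum2 = torch.cmul(t, t):sum()

-- flag == true: normalize by n (biased)
local biased   = sum2 / n - (sum / n) ^ 2                       -- 4.0
-- flag == false: normalize by n - 1 (unbiased)
local unbiased = sum2 / (n - 1) - n / (n - 1) * (sum / n) ^ 2   -- ~4.571, matches t:var()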

/* Compute the variance (or standard deviation) along an outer dimension of a tensor.
*
* - num_orows is the size of the flattened outer dimensions;
* - num_irows is the size of the flattened inner dimensions;
* - row_size is the size of the dimension along which to compute the variance;
* - if flag is set, normalize by `row_size` instead of `row_size - 1`
* - if apply_sqrt is set, compute the standard deviation instead of variance
*
* The dimensions to the outside and inside of the specified dimension are considered as flattened.
* Thread blocks with the same blockIdx.y process an "outer row" (i.e. an element of the flattened
* outer dimensions, which contains several "inner rows").
* Each thread processes a single inner row at a time.
*/
template<bool flag, bool apply_sqrt>
__global__ void THCudaTensor_kernel_varOuterDim(float *tgt, float *src_, unsigned num_orows, unsigned num_irows, unsigned row_size)
{
for (unsigned orow = blockIdx.x; orow < num_orows; orow += gridDim.x) {
for (unsigned irow = blockIdx.y * blockDim.x + threadIdx.x; irow < num_irows; irow += gridDim.y * blockDim.x) {
float *src = src_ + orow * row_size * num_irows + irow;
float sum = 0, sum2 = 0;

for (unsigned col = 0; col < row_size; ++col) {
float val = *src;
sum += val;
sum2 += val * val;

src += num_irows;
}

tgt[orow * num_irows + irow] = THCudaTensor_computeVar<flag, apply_sqrt>(sum, sum2, row_size);
}
}
}

template<bool apply_sqrt>
__host__ void THCudaTensor_varOuterDim(THCudaTensor *tgt, THCudaTensor *src, long dimension, int flag)
{
unsigned ndim = THCudaTensor_nDimension(src);
// Treat all outer dimensions (i.e. dim < dimension) as one.
unsigned num_orows = 1;
for (unsigned dim = 0; dim < dimension; dim++) {
num_orows *= THCudaTensor_size(src, dim);
}
unsigned row_size = THCudaTensor_size(src, dimension);
// Treat all inner dimensions (i.e. dim > dimension) as one.
unsigned num_irows = 1;
for (unsigned dim = dimension + 1; dim < ndim; dim++) {
num_irows *= THCudaTensor_size(src, dim);
}

dim3 threads(min(512, num_irows));
unsigned maxGridDim = 1024;
dim3 grid(min(maxGridDim, num_orows), min(maxGridDim, DIVUP(num_irows, threads.x)));

if (flag) {
THCudaTensor_kernel_varOuterDim<true, apply_sqrt><<<grid, threads>>>(
THCudaTensor_data(tgt), THCudaTensor_data(src), num_orows, num_irows, row_size);
} else {
THCudaTensor_kernel_varOuterDim<false, apply_sqrt><<<grid, threads>>>(
THCudaTensor_data(tgt), THCudaTensor_data(src), num_orows, num_irows, row_size);
}
cudaError errcode = cudaGetLastError();
if (errcode != cudaSuccess) {
THError(cudaGetErrorString(errcode));
}
}
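For the outer-dimension path, a contiguous row-major source puts the element at (outer row o, position c along the reduced dimension, inner row i) at flat offset o * row_size * num_irows + c * num_irows + i, which is why each thread can walk its column with a stride of num_irows. A CPU sketch of that indexing in Lua for a 3D tensor reduced along its middle dimension, checked against the existing CPU var (sizes arbitrary, illustration only):

local x = torch.FloatTensor(3, 4, 5):uniform()
local dimension = 2
local num_orows = x:size(1)                  -- flattened outer dims
local row_size  = x:size(dimension)
local num_irows = x:size(3)                  -- flattened inner dims

local src = x:storage()                      -- contiguous, row-major, 1-indexed
local out = torch.FloatTensor(num_orows, num_irows)
for o = 0, num_orows - 1 do
  for i = 0, num_irows - 1 do
    local sum, sum2 = 0, 0
    for c = 0, row_size - 1 do
      local v = src[o * row_size * num_irows + c * num_irows + i + 1]  -- +1: Lua indexing
      sum, sum2 = sum + v, sum2 + v * v
    end
    out[o + 1][i + 1] = sum2 / (row_size - 1) - row_size / (row_size - 1) * (sum / row_size) ^ 2
  end
end
-- out should match x:var(2):squeeze() up to float rounding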


/* Compute the variance (or standard deviation) of the innermost dimension of a tensor.
*
* - num_rows is the size of the flattened outer dimensions;
* - row_size is the size of the innermost dimension;
* - if flag is set, normalize by `row_size` instead of `row_size - 1`
* - if apply_sqrt is set, compute the standard deviation instead of variance
*
* The outer dimensions of the tensor are considered as a single dimension, i.e. the tensor is
* considered as having 'num_rows' rows of size 'row_size'.
* Each thread block processes one or more sets of contiguous rows (processing multiple rows
* per thread block is quicker than processing a single row, especially for short rows).
*/
template<bool flag, bool apply_sqrt>
__global__ void THCudaTensor_kernel_varInnermostDim(float *tgt, float *src_, unsigned num_rows, unsigned row_size)
{
__shared__ float ssum[32][16];
__shared__ float ssum2[32][16];

for (unsigned block_row = blockIdx.x * blockDim.y; block_row < num_rows; block_row += blockDim.y * gridDim.x) {
unsigned row = block_row + threadIdx.y;
float sum = 0, sum2 = 0;
if (row < num_rows) {
float *src = src_ + row * row_size;
// Sequential reduction within a thread.
for (unsigned col = threadIdx.x; col < row_size; col += blockDim.x) {
float val = src[col];
sum += val;
sum2 += val * val;
}
}
ssum[threadIdx.y][threadIdx.x] = sum;
ssum2[threadIdx.y][threadIdx.x] = sum2;
__syncthreads();

// Reduce intermediate values to single value.
for (unsigned s = 8; s > 1; s >>= 1) {
if (row < num_rows && threadIdx.x < s) {
ssum[threadIdx.y][threadIdx.x] += ssum[threadIdx.y][threadIdx.x + s];
ssum2[threadIdx.y][threadIdx.x] += ssum2[threadIdx.y][threadIdx.x + s];
}
__syncthreads();
}

if (row < num_rows && threadIdx.x == 0) {
sum = ssum[threadIdx.y][0] + ssum[threadIdx.y][1];
sum2 = ssum2[threadIdx.y][0] + ssum2[threadIdx.y][1];
tgt[row] = THCudaTensor_computeVar<flag, apply_sqrt>(sum, sum2, row_size);
}
__syncthreads();
}
}

template<bool apply_sqrt>
__host__ void THCudaTensor_varInnermostDim(THCudaTensor *tgt, THCudaTensor *src, int flag)
{
unsigned ndim = THCudaTensor_nDimension(src);
// Treat all outer dimensions as a single dimension.
unsigned num_rows = 1;
for (unsigned dim = 0; dim < ndim - 1; dim++) {
num_rows *= THCudaTensor_size(src, dim);
}
unsigned row_size = THCudaTensor_size(src, ndim - 1);

// From limited testing, 16x32 seemed a good compromise for handling both long and short dimensions.
dim3 threads(16, 32);
dim3 grid(min(1024, DIVUP(num_rows, threads.y)));

if (flag) {
THCudaTensor_kernel_varInnermostDim<true, apply_sqrt><<<grid, threads>>>(
THCudaTensor_data(tgt), THCudaTensor_data(src), num_rows, row_size);
} else {
THCudaTensor_kernel_varInnermostDim<false, apply_sqrt><<<grid, threads>>>(
THCudaTensor_data(tgt), THCudaTensor_data(src), num_rows, row_size);
}
cudaError errcode = cudaGetLastError();
if (errcode != cudaSuccess) {
THError(cudaGetErrorString(errcode));
}
}
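The innermost-dimension kernel gives each row to 16 threads (blockDim.x): every thread accumulates a strided partial sum and sum of squares, and the 16 partials are then combined by a shared-memory tree reduction (s = 8, 4, 2) plus one final pairwise add. A Lua sketch of that reduction schedule for a single row, as an illustration only (the real work happens in shared memory on the GPU):

local row = torch.FloatTensor(1000):uniform()
local lanes = 16                               -- mirrors blockDim.x
local psum, psum2 = torch.zeros(lanes), torch.zeros(lanes)

-- Stage 1: strided per-lane accumulation (lane - 1 plays the role of threadIdx.x).
for lane = 1, lanes do
  for col = lane, row:size(1), lanes do
    local v = row[col]
    psum[lane]  = psum[lane]  + v
    psum2[lane] = psum2[lane] + v * v
  end
end

-- Stage 2: tree reduction over the lanes, then the final pair.
local s = lanes / 2
while s > 1 do
  for lane = 1, s do
    psum[lane]  = psum[lane]  + psum[lane + s]
    psum2[lane] = psum2[lane] + psum2[lane + s]
  end
  s = s / 2
end
local sum, sum2 = psum[1] + psum[2], psum2[1] + psum2[2]

local n = row:size(1)
local var = sum2 / (n - 1) - n / (n - 1) * (sum / n) ^ 2   -- should match row:var()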

void THCudaTensor_var(THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
{
THLongStorage *dim = THCudaTensor_newSizeOf(src);
THLongStorage_set(dim, dimension, 1);
THCudaTensor_resize(self_, dim, NULL);
THLongStorage_free(dim);

THCudaTensor *self = THCudaTensor_newContiguous(self_);
src = THCudaTensor_newContiguous(src);

if (dimension == THCudaTensor_nDimension(src) - 1) {
THCudaTensor_varInnermostDim<false>(self, src, flag);
} else {
THCudaTensor_varOuterDim<false>(self, src, dimension, flag);
}

THCudaTensor_free(src);
THCudaTensor_freeCopyTo(self, self_);
}

void THCudaTensor_std(THCudaTensor *self_, THCudaTensor *src, long dimension, int flag)
{
THLongStorage *dim = THCudaTensor_newSizeOf(src);
THLongStorage_set(dim, dimension, 1);
THCudaTensor_resize(self_, dim, NULL);
THLongStorage_free(dim);

THCudaTensor *self = THCudaTensor_newContiguous(self_);
src = THCudaTensor_newContiguous(src);

if (dimension == THCudaTensor_nDimension(src) - 1) {
THCudaTensor_varInnermostDim<true>(self, src, flag);
} else {
THCudaTensor_varOuterDim<true>(self, src, dimension, flag);
}

THCudaTensor_free(src);
THCudaTensor_freeCopyTo(self, self_);
}


template<class Op>
2 changes: 2 additions & 0 deletions lib/THC/THCTensorMath.h
@@ -79,7 +79,9 @@ THC_API void THCudaTensor_neTensor(THCudaTensor *self_, THCudaTensor *src1, THCu
THC_API float THCudaTensor_meanall(THCudaTensor *self);
THC_API void THCudaTensor_mean(THCudaTensor *self, THCudaTensor *src, long dim);
THC_API float THCudaTensor_varall(THCudaTensor *self);
+ THC_API void THCudaTensor_var(THCudaTensor *self, THCudaTensor *src, long dim, int flag);
THC_API float THCudaTensor_stdall(THCudaTensor *self);
+ THC_API void THCudaTensor_std(THCudaTensor *self, THCudaTensor *src, long dim, int flag);
THC_API float THCudaTensor_normall(THCudaTensor *self, float value);
THC_API void THCudaTensor_norm(THCudaTensor* self, THCudaTensor* src, float value, long dimension);
THC_API void THCudaTensor_renorm(THCudaTensor* self, THCudaTensor* src, float value, long dimension, float max_norm);
18 changes: 8 additions & 10 deletions test/test.lua
@@ -589,23 +589,21 @@ function test.var()
local sz2 = math.floor(torch.uniform(minsize,maxsize))
local x = torch.FloatTensor():rand(sz1, sz2)
compareFloatAndCuda(x, 'var')
- -- multi-dim var is not implemented
- -- compareFloatAndCuda(x, 'var', 1, true)
- -- compareFloatAndCuda(x, 'var', 1, false)
- -- compareFloatAndCuda(x, 'var', 2, true)
- -- compareFloatAndCuda(x, 'var', 2, false)
+ compareFloatAndCuda(x, 'var', 1, true)
+ compareFloatAndCuda(x, 'var', 1, false)
+ compareFloatAndCuda(x, 'var', 2, true)
+ compareFloatAndCuda(x, 'var', 2, false)
end

function test.std()
local sz1 = math.floor(torch.uniform(minsize,maxsize))
local sz2 = math.floor(torch.uniform(minsize,maxsize))
local x = torch.FloatTensor():rand(sz1, sz2)
compareFloatAndCuda(x, 'std')
- -- multi-dim std is not implemented
- -- compareFloatAndCuda(x, 'std', 1, true)
- -- compareFloatAndCuda(x, 'std', 1, false)
- -- compareFloatAndCuda(x, 'std', 2, true)
- -- compareFloatAndCuda(x, 'std', 2, false)
+ compareFloatAndCuda(x, 'std', 1, true)
+ compareFloatAndCuda(x, 'std', 1, false)
+ compareFloatAndCuda(x, 'std', 2, true)
+ compareFloatAndCuda(x, 'std', 2, false)
end

-- Test element-wise unary operators with both one and two arguments.
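These tests lean on the `compareFloatAndCuda` helper defined earlier in test.lua, which (roughly) runs the named method on a FloatTensor and on its CUDA copy and asserts that the results agree. A hypothetical stand-in, only to show the shape of what the new test lines exercise (the real helper's name, tolerance, and error reporting differ):

-- Hypothetical stand-in for the test helper: apply method `fn` with `...` to a
-- float tensor and to a CUDA copy of it, then compare the results.
local function checkFloatVsCuda(x, fn, ...)
  local cpuResult = x[fn](x, ...)
  local gpu = x:cuda()
  local gpuResult = gpu[fn](gpu, ...)
  if torch.type(cpuResult) == 'number' then
    assert(math.abs(cpuResult - gpuResult) < 1e-4, fn .. ' mismatch')
  else
    assert((cpuResult - gpuResult:float()):abs():max() < 1e-4, fn .. ' mismatch')
  end
end

-- e.g. checkFloatVsCuda(torch.FloatTensor(7, 11):uniform(), 'std', 1, true)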