diff --git a/SoftMax.cu b/SoftMax.cu
index 61499bf0..364b6fb7 100644
--- a/SoftMax.cu
+++ b/SoftMax.cu
@@ -3,12 +3,12 @@
 #define MINUS_LOG_THRESHOLD -18.42
 #define SOFTMAX_THREADS 128
 
-__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, int nframe, int dim)
+__global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input,
+                                                 int nframe, int dim, int stride)
 {
   __shared__ float buffer[SOFTMAX_THREADS+1];
-  int k = blockIdx.x;
-  float *input_k = input + k*dim;
-  float *output_k = output + k*dim;
+  float *input_k = input + blockIdx.x*dim*stride + blockIdx.y;
+  float *output_k = output + blockIdx.x*dim*stride + blockIdx.y;
 
   int i_start = threadIdx.x;
   int i_end = dim;
@@ -18,7 +18,7 @@ __global__ void cunn_SoftMax_updateOutput_kernel(float *output, float *input, int nframe, int dim)
   buffer[threadIdx.x] = -FLT_MAX;
   for (int i=i_start; i<i_end; i+=i_step)
   {
-    float z = input_k[i];
+    float z = input_k[i*stride];
     if(buffer[threadIdx.x] < z)
       buffer[threadIdx.x] = z;
   }
@@ ... @@ static int cunn_SoftMax_updateOutput(lua_State *L)
+  long batchSize, dim, stride;
   if(input->nDimension == 1)
   {
-    dim3 blocks(1);
-    dim3 threads(SOFTMAX_THREADS);
-    cunn_SoftMax_updateOutput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, output),
-                                                          THCudaTensor_data(state, input),
-                                                          1, input->size[0]);
+    batchSize = 1;
+    dim = input->size[0];
+    stride = 1;
   }
   else if(input->nDimension == 2)
   {
-    dim3 blocks(input->size[0]);
-    dim3 threads(SOFTMAX_THREADS);
-    cunn_SoftMax_updateOutput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, output),
-                                                          THCudaTensor_data(state, input),
-                                                          input->size[0], input->size[1]);
+    batchSize = input->size[0];
+    dim = input->size[1];
+    stride = 1;
+  }
+  else if(input->nDimension == 3)
+  {
+    batchSize = 1;
+    dim = input->size[0];
+    stride = input->size[1]*input->size[2];
+  }
+  else if(input->nDimension == 4)
+  {
+    batchSize = input->size[0];
+    dim = input->size[1];
+    stride = input->size[2]*input->size[3];
   }
   else
-    THError("vector or matrix expected");
+    THError("1D, 2D, 3D or 4D tensor expected");
+
+  dim3 blocks(batchSize, stride);
+  dim3 threads(SOFTMAX_THREADS);
+  cunn_SoftMax_updateOutput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, output),
+                                                        THCudaTensor_data(state, input),
+                                                        batchSize, dim, stride);
 
   cudaError errcode = cudaGetLastError();
   if(errcode != cudaSuccess)
@@ -142,18 +156,6 @@ static int cunn_SoftMax_updateOutput(lua_State *L)
   return 1;
 }
 
-struct softmaxupdateGradInput_functor
-{
-  float value;
-
-  softmaxupdateGradInput_functor(float value_) : value(value_) {}
-
-  __host__ __device__ float operator()(const float& output, const float& gradOutput) const
-  {
-    return gradOutput - exp(output)*value;
-  }
-};
-
 static int cunn_SoftMax_updateGradInput(lua_State *L)
 {
   THCState *state = getCutorchState(L);
@@ -166,31 +168,42 @@ static int cunn_SoftMax_updateGradInput(lua_State *L)
   gradOutput = THCudaTensor_newContiguous(state, gradOutput);
   THCudaTensor_resizeAs(state, gradInput, output);
 
+  long batchSize, dim, stride;
   if(gradInput->nDimension == 1)
   {
-    dim3 blocks(1);
-    dim3 threads(SOFTMAX_THREADS);
-
-    cunn_SoftMax_updateGradInput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, gradInput),
-                                                             THCudaTensor_data(state, output),
-                                                             THCudaTensor_data(state, gradOutput),
-                                                             1, gradInput->size[0]);
+    batchSize = 1;
+    dim = gradInput->size[0];
+    stride = 1;
   }
   else if(gradInput->nDimension == 2)
   {
-    dim3 blocks(gradInput->size[0]);
-    dim3 threads(SOFTMAX_THREADS);
-
-    cunn_SoftMax_updateGradInput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, gradInput),
-                                                             THCudaTensor_data(state, output),
-                                                             THCudaTensor_data(state, gradOutput),
-                                                             gradInput->size[0], gradInput->size[1]);
+    batchSize = gradInput->size[0];
+    dim = gradInput->size[1];
+    stride = 1;
+  }
+  else if(gradInput->nDimension == 3)
+  {
+    batchSize = 1;
+    dim = gradInput->size[0];
+    stride = gradInput->size[1]*gradInput->size[2];
+  }
+  else if(gradInput->nDimension == 4)
+  {
+    batchSize = gradInput->size[0];
+    dim = gradInput->size[1];
+    stride = gradInput->size[2]*gradInput->size[3];
   }
   else
-    THError("vector or matrix expected");
+    THError("1D, 2D, 3D or 4D tensor expected");
+
+  dim3 blocks(batchSize, stride);
+  dim3 threads(SOFTMAX_THREADS);
+  cunn_SoftMax_updateGradInput_kernel<<<blocks, threads>>>(THCudaTensor_data(state, gradInput),
+                                                           THCudaTensor_data(state, output),
+                                                           THCudaTensor_data(state, gradOutput),
+                                                           batchSize, dim, stride);
 
   cudaError errcode = cudaGetLastError();
   if(errcode != cudaSuccess)
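
For reference, below is a minimal standalone sketch (not part of the patch; the helper name and tensor layout are assumptions) of the index arithmetic the patched kernels rely on. Treating a contiguous 4D input as (batchSize, dim, H, W) with stride = H*W, the softmax runs over the dim axis: blockIdx.x selects the batch element, blockIdx.y selects the spatial position, consecutive elements along the softmax axis sit stride floats apart, and so each block offsets its pointer by blockIdx.x*dim*stride + blockIdx.y and steps through elements with i*stride.

/* Sketch only: reproduces the offset math assumed by the patched kernels for a
   contiguous (batchSize, dim, H, W) tensor; softmax_elem_offset is illustrative. */
#include <assert.h>
#include <stddef.h>

static size_t softmax_elem_offset(size_t b, size_t i, size_t spatial,
                                  size_t dim, size_t stride)
{
  /* per-block base:              b*dim*stride + spatial   (blockIdx.x, blockIdx.y)
     i-th element of the softmax: base + i*stride                                  */
  return b*dim*stride + spatial + i*stride;
}

int main(void)
{
  size_t dim = 3, H = 4, W = 5;
  size_t stride = H*W;
  /* element (b=1, i=2, h=3, w=1) of a contiguous (2, 3, 4, 5) tensor */
  size_t b = 1, i = 2, spatial = 3*W + 1;
  size_t rowmajor = ((b*dim + i)*H + 3)*W + 1;  /* standard row-major offset */
  assert(softmax_elem_offset(b, i, spatial, dim, stride) == rowmajor);
  return 0;
}

The 1D and 2D cases in the patch simply set stride = 1, so the same kernels cover all four supported layouts without a separate code path.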