From ca879aa9b9c16b724b4395980946a38eebf5ee9e Mon Sep 17 00:00:00 2001 From: moisesh Date: Fri, 6 Dec 2019 16:35:58 -0800 Subject: [PATCH 1/6] Update multi_sum_sq to avoid AtomicAdd --- src/operator/contrib/multi_sum_sq-inl.h | 10 ++- src/operator/contrib/multi_sum_sq.cc | 10 ++- src/operator/contrib/multi_sum_sq.cu | 87 ++++++++++++++++--------- 3 files changed, 71 insertions(+), 36 deletions(-) diff --git a/src/operator/contrib/multi_sum_sq-inl.h b/src/operator/contrib/multi_sum_sq-inl.h index 876155215d1c..b8609c0f217f 100644 --- a/src/operator/contrib/multi_sum_sq-inl.h +++ b/src/operator/contrib/multi_sum_sq-inl.h @@ -21,7 +21,7 @@ * Copyright (c) 2019 by Contributors * \file multi_l2_norm-inl.h * \brief vectorized L2 norm over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez */ @@ -32,6 +32,10 @@ #include #include "../operator_common.h" +namespace multi_sum_sq { +enum MultiSumSqUpdateResource {kTempSpace}; +} // namespace multi_sum_sq + namespace mxnet { namespace op { @@ -80,7 +84,7 @@ inline bool MultiSumSqType(const NodeAttrs& attrs, template void MultiSumSqRun(const std::vector &inputs, int nInputs, - float *out_ptr, mshadow::Stream *s); + float *out_ptr, const OpContext &ctx); template void MultiSumSq(const nnvm::NodeAttrs& attrs, @@ -91,7 +95,7 @@ void MultiSumSq(const nnvm::NodeAttrs& attrs, auto s = ctx.get_stream(); const auto& p = dmlc::get(attrs.parsed); float* out_ptr = outputs[0].FlatTo2D(s).dptr_; - MultiSumSqRun(inputs, p.num_arrays, out_ptr, s); + MultiSumSqRun(inputs, p.num_arrays, out_ptr, ctx); } } // namespace op diff --git a/src/operator/contrib/multi_sum_sq.cc b/src/operator/contrib/multi_sum_sq.cc index cdb5423db23f..912bc1a59197 100644 --- a/src/operator/contrib/multi_sum_sq.cc +++ b/src/operator/contrib/multi_sum_sq.cc @@ -21,7 +21,7 @@ * Copyright (c) 2019 by Contributors * \file multi_sum_sq.cc * \brief vectorized sum or squared over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez */ #include "./multi_sum_sq-inl.h" @@ -52,6 +52,10 @@ NNVM_REGISTER_OP(multi_sum_sq) return ret; }) .set_attr("FCompute", MultiSumSq) +.set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + return std::vector{ResourceRequest::kTempSpace}; + }) .add_argument("data", "NDArray-or-Symbol[]", "Arrays") .add_arguments(MultiSumSqParam::__FIELDS__()); @@ -74,9 +78,9 @@ inline void CalcSumSq(const std::vector &inputs, int nInputs, template<> void MultiSumSqRun(const std::vector &inputs, int nInputs, - float *out_ptr, mshadow::Stream *s) { + float *out_ptr, const OpContext &ctx) { MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, - CalcSumSq(inputs, nInputs, out_ptr, s); + CalcSumSq(inputs, nInputs, out_ptr, ctx.get_stream()); ) } diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu index 6f6fe56bfd81..ad81b3dd9736 100644 --- a/src/operator/contrib/multi_sum_sq.cu +++ b/src/operator/contrib/multi_sum_sq.cu @@ -21,7 +21,7 @@ * Copyright (c) 2019 by Contributors * \file multi_sum_sq.cu * \brief vectorized sums of squares norm over multiple arrays operators - * \author Clement Fuji Tsang, Andrei Ivanov + * \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez */ #include "./multi_sum_sq-inl.h" #include @@ -43,15 +43,14 @@ struct MultiSumSqKernelParam { int sizes[ARRAY_LIMIT]; unsigned char block_to_tensor[BLOCK_LIMIT]; int block_to_chunk[BLOCK_LIMIT]; + int 
max_chunks_per_tensor = -1; }; template __device__ __forceinline__ DType reduce_block_into_lanes(DType* x, - DType val, - int lanes = 1, - bool share_result = false) { - int tid = threadIdx.x + threadIdx.y * blockDim.x; - int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32. + DType val) { + int tid = threadIdx.x; + int blockSize = blockDim.x; if (blockSize >= 64) { x[tid] = val; @@ -72,27 +71,19 @@ __device__ __forceinline__ DType reduce_block_into_lanes(DType* x, final = x[tid] + x[tid+32]; else final = val; - // __SYNCWARP(); #pragma unroll - for (int i = 16; i >= lanes; i >>= 1) + for (int i = 16; i >= 1; i >>= 1) final = final + __shfl_down_sync(0xffffffff, final, i); } - - if (share_result) { - if (tid < lanes) - x[tid] = final; // EpilogueOp - // Make sure the smem result is visible to all warps. - __syncthreads(); - } - return final; } template __global__ void MultiSumSqKernel(int chunk_size, MultiSumSqKernelParam param, - float* output) { + float* block_reductions, + int start_tensor_id) { const int tensor_loc = param.block_to_tensor[blockIdx.x]; const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size; const int n = param.sizes[tensor_loc] - chunk_len; @@ -106,32 +97,65 @@ __global__ void MultiSumSqKernel(int chunk_size, i_start < iMax; i_start += blockDim.x * ILP) { int i = i_start + threadIdx.x; - // #pragma unroll +#pragma unroll for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) { const auto incoming_val = static_cast(x[i]); val += incoming_val * incoming_val; } } - const float final = reduce_block_into_lanes(vals, val); - if (threadIdx.x == 0) - atomicAdd(output + tensor_loc, final); + + if (threadIdx.x == 0){ + block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor + + param.block_to_chunk[blockIdx.x]] = final; + } +} + +template +__global__ void GlobalReductionKernel(MultiSumSqKernelParam param, + float* block_reductions, + float* output) { + __shared__ float vals[512]; + float* reductions_this_tensor = block_reductions + blockIdx.x * param.max_chunks_per_tensor; + float val = 0; + for(int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x) + val += reductions_this_tensor[i]; + + float final = reduce_block_into_lanes(vals, val); + + if(threadIdx.x == 0) + output[blockIdx.x] = final; } template<> void MultiSumSqRun(const std::vector &inputs, int nInputs, - float *out_ptr, mshadow::Stream *s) { + float *out_ptr, const OpContext &ctx) { const int chunk_size = 32768; const int block_size = 512; using namespace mxnet_op; + auto s = ctx.get_stream(); auto stream = mshadow::Stream::GetStream(s); - CUDA_CALL(cudaMemsetAsync(out_ptr, 0, nInputs * sizeof(float), stream)); MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { MultiSumSqKernelParam param; + // find max num of chunks in tensors + for (int t = 0; t < nInputs; t++) { + int chunks_this_tensor = (inputs[t].shape_.Size() + chunk_size - 1) / chunk_size; + if(chunks_this_tensor > param.max_chunks_per_tensor) + param.max_chunks_per_tensor = chunks_this_tensor; + } + // temporary storage for the reduction of each block + size_t workspace_size = nInputs * param.max_chunks_per_tensor * sizeof(float); + Tensor workspace = + ctx.requested[multi_sum_sq::kTempSpace].get_space_typed( + Shape1(workspace_size), s); + Tensor block_reductions(reinterpret_cast(&workspace[0]), + Shape1(nInputs * param.max_chunks_per_tensor), s); + CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0, nInputs * param.max_chunks_per_tensor* sizeof(float), 
stream)); + int loc_block_info = 0; // position in param.block_to_tensor and param.block_to_chunck int loc_tensor_info = 0; // position in param.sizes and param.addresses - int output_offset = 0; // array index of the first block pointed on by param.addresses + int start_tensor_id = 0; for (int t = 0; t < nInputs; t++, loc_tensor_info++) { // array index in inputs param.sizes[loc_tensor_info] = inputs[t].shape_.Size(); param.addresses[loc_tensor_info] = inputs[t].FlatTo2D(s).dptr_; @@ -142,27 +166,30 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs, loc_block_info++; const bool last_curr_chunk = chunk == chunks_this_tensor; - const bool tensors_full = last_curr_chunk && loc_tensor_info == 109; - const bool blocks_full = (loc_block_info == 320); + const bool tensors_full = last_curr_chunk && loc_tensor_info == (ARRAY_LIMIT-1); + const bool blocks_full = (loc_block_info == BLOCK_LIMIT); const bool last_chunk = last_curr_chunk && t == nInputs - 1; if (!(tensors_full || blocks_full || last_chunk)) continue; - MultiSumSqKernel<<>> - (chunk_size, param, out_ptr + output_offset); + (chunk_size, param, block_reductions.dptr_, start_tensor_id); MSHADOW_CUDA_POST_KERNEL_CHECK(MultiSumSqKernel); + loc_block_info = 0; if (last_curr_chunk) { // if you start from a new tensor loc_tensor_info = -1; - output_offset = t + 1; + start_tensor_id = t + 1; } else { // if you start from the same tensor param.sizes[0] = param.sizes[loc_tensor_info]; param.addresses[0] = param.addresses[loc_tensor_info]; loc_tensor_info = 0; - output_offset = t; + start_tensor_id = t; } } } + // Global reduction + GlobalReductionKernel<<>> + (param, block_reductions.dptr_, out_ptr); }); } From 8d60931679dedc21efb839171dbcf7fe14e7b387 Mon Sep 17 00:00:00 2001 From: moisesh Date: Fri, 6 Dec 2019 16:40:44 -0800 Subject: [PATCH 2/6] Add specific test for multi_sum_sq --- tests/python/gpu/test_operator_gpu.py | 29 +++++++++++++++++++++++++++ 1 file changed, 29 insertions(+) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index efa55d2c1cde..50dc34d348ae 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -310,6 +310,35 @@ def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2): ref_new_lrs[i] = lrs[i] assert_almost_equal(ref_new_lrs.asnumpy(), mx_new_lrs.asnumpy(), atol=tol2, rtol=tol2) +def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2): + values_arr = [np.random.rand(*shape).astype(dtype) * 10. 
for shape in shapes]
+
+    mx_vals = _make_ndarrays(values_arr, ctx=ctx)
+    sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
+
+    ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr],
+                              dtype='float32', ctx=ctx)
+
+    assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1)
+
+@with_seed()
+def test_multi_sum_sq():
+    min_nparam = 390
+    max_nparam = 400
+    mindim = 50000
+    maxdim = 3200000
+    maxndim = 1
+
+    dtypes = ['float16','float32', 'float64']
+    for ctx in [mx.gpu(0)]:
+        for dtype in dtypes:
+            nparam = np.random.randint(min_nparam + 1, max_nparam + 1)
+            shapes = [np.random.randint(mindim, maxdim + 1, size=maxndim) for i in range(nparam)]
+            lowTol = ctx == mx.cpu(0) and ('float16'in [dtype])
+            tol1 = 1e-3 if lowTol else 1e-5
+            tol2 = 1e-6 if lowTol else 1e-7
+            check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2)
+
 @with_seed()
 def test_fast_lars():
     min_nparam = 50

From 06ad5d6ad80bbe296e82f9f46f761d6f44afa7e9 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Fri, 6 Dec 2019 17:28:34 -0800
Subject: [PATCH 3/6] Add a determinism test and fix lint issues

---
 src/operator/contrib/multi_sum_sq.cu  | 12 +++++++-----
 tests/python/gpu/test_operator_gpu.py |  5 ++++-
 2 files changed, 11 insertions(+), 6 deletions(-)

diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu
index ad81b3dd9736..fbbec4561c90 100644
--- a/src/operator/contrib/multi_sum_sq.cu
+++ b/src/operator/contrib/multi_sum_sq.cu
@@ -105,7 +105,7 @@ __global__ void MultiSumSqKernel(int chunk_size,
   }
   const float final = reduce_block_into_lanes(vals, val);
 
-  if (threadIdx.x == 0){
+  if (threadIdx.x == 0) {
     block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor +
                      param.block_to_chunk[blockIdx.x]] = final;
   }
@@ -118,12 +118,12 @@ __global__ void GlobalReductionKernel(MultiSumSqKernelParam param,
   __shared__ float vals[512];
   float* reductions_this_tensor = block_reductions + blockIdx.x * param.max_chunks_per_tensor;
   float val = 0;
-  for(int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x)
+  for (int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x)
     val += reductions_this_tensor[i];
 
   float final = reduce_block_into_lanes(vals, val);
 
-  if(threadIdx.x == 0)
+  if (threadIdx.x == 0)
     output[blockIdx.x] = final;
 }
 
@@ -141,7 +141,7 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs,
     // find max num of chunks in tensors
     for (int t = 0; t < nInputs; t++) {
       int chunks_this_tensor = (inputs[t].shape_.Size() + chunk_size - 1) / chunk_size;
-      if(chunks_this_tensor > param.max_chunks_per_tensor)
+      if (chunks_this_tensor > param.max_chunks_per_tensor)
         param.max_chunks_per_tensor = chunks_this_tensor;
     }
     // temporary storage for the reduction of each block
@@ -151,7 +151,9 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs,
           Shape1(workspace_size), s);
     Tensor block_reductions(reinterpret_cast(&workspace[0]),
                             Shape1(nInputs * param.max_chunks_per_tensor), s);
-    CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0, nInputs * param.max_chunks_per_tensor* sizeof(float), stream));
+    CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0,
+                              nInputs * param.max_chunks_per_tensor* sizeof(float),
+                              stream));
 
     int loc_block_info = 0;  // position in param.block_to_tensor and param.block_to_chunck
     int loc_tensor_info = 0;  // position in param.sizes and param.addresses
diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index 50dc34d348ae..b22fe14f1df1 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -315,9 +315,12 @@ def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2):
 
     mx_vals = _make_ndarrays(values_arr, ctx=ctx)
     sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
+    sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
+    # checks that operator is deterministic
+    assert_almost_equal(sum_sq.asnumpy(), sum_sq2.asnumpy(), atol=1e-9, rtol=1e-9)
 
     ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr],
-                              dtype='float32', ctx=ctx)
+                             dtype='float32', ctx=ctx)
 
     assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1)

From c1780f915936ff80337ebda1647b86f4e0add0b1 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Sat, 7 Dec 2019 11:53:50 -0800
Subject: [PATCH 4/6] Better test for checking op is deterministic

---
 tests/python/gpu/test_operator_gpu.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py
index b22fe14f1df1..839727fbbfd7 100644
--- a/tests/python/gpu/test_operator_gpu.py
+++ b/tests/python/gpu/test_operator_gpu.py
@@ -312,16 +312,14 @@ def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2):
 
 def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2):
     values_arr = [np.random.rand(*shape).astype(dtype) * 10. for shape in shapes]
-
     mx_vals = _make_ndarrays(values_arr, ctx=ctx)
     sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
     sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
     # checks that operator is deterministic
-    assert_almost_equal(sum_sq.asnumpy(), sum_sq2.asnumpy(), atol=1e-9, rtol=1e-9)
+    assert np.array_equal(sum_sq.asnumpy(), sum_sq2.asnumpy())
 
     ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr],
                              dtype='float32', ctx=ctx)
-
     assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1)
 
 @with_seed()

From eb85d8c905e4536c4c3b6aa53d4cd1668bd9e2b5 Mon Sep 17 00:00:00 2001
From: moisesh
Date: Mon, 9 Dec 2019 11:41:36 -0800
Subject: [PATCH 5/6] Follow MXNet letter case format

---
 src/operator/contrib/multi_sum_sq.cc  | 12 +++---
 src/operator/contrib/multi_sum_sq.cu  | 43 ++++++++++---------
 tests/python/gpu/test_operator_gpu.py | 60 +++++++++++++--------------
 3 files changed, 57 insertions(+), 58 deletions(-)

diff --git a/src/operator/contrib/multi_sum_sq.cc b/src/operator/contrib/multi_sum_sq.cc
index 912bc1a59197..16c99d1c9699 100644
--- a/src/operator/contrib/multi_sum_sq.cc
+++ b/src/operator/contrib/multi_sum_sq.cc
@@ -60,16 +60,16 @@ NNVM_REGISTER_OP(multi_sum_sq)
 .add_arguments(MultiSumSqParam::__FIELDS__());
 
 template
-inline void CalcSumSq(const std::vector &inputs, int nInputs,
+inline void CalcSumSq(const std::vector &inputs, int n_inputs,
                       float *out_ptr, mshadow::Stream *s) {
   int i;
   size_t j;
 #pragma omp parallel for private(i, j)
-  for (i = 0; i < nInputs; ++i) {  // array index in inputs
+  for (i = 0; i < n_inputs; ++i) {  // array index in inputs
     float sum = 0;
     const auto address = inputs[i].FlatTo2D(s).dptr_;
-    const auto jMax = inputs[i].shape_.Size();
-    for (j = 0; j < jMax; ++j)
+    const auto j_max = inputs[i].shape_.Size();
+    for (j = 0; j < j_max; ++j)
       sum += address[j] * address[j];
 
     out_ptr[i] = sum;
@@ -77,10 +77,10 @@ inline void CalcSumSq(const std::vector &inputs, int nInputs,
 }
 
 template<>
-void MultiSumSqRun(const std::vector &inputs, int nInputs,
+void MultiSumSqRun(const std::vector &inputs, int n_inputs,
                    float *out_ptr, const OpContext &ctx)
{ MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, - CalcSumSq(inputs, nInputs, out_ptr, ctx.get_stream()); + CalcSumSq(inputs, n_inputs, out_ptr, ctx.get_stream()); ) } diff --git a/src/operator/contrib/multi_sum_sq.cu b/src/operator/contrib/multi_sum_sq.cu index fbbec4561c90..620c9ca8a073 100644 --- a/src/operator/contrib/multi_sum_sq.cu +++ b/src/operator/contrib/multi_sum_sq.cu @@ -47,27 +47,26 @@ struct MultiSumSqKernelParam { }; template -__device__ __forceinline__ DType reduce_block_into_lanes(DType* x, - DType val) { +__device__ __forceinline__ DType ReduceBlockIntoLanes(DType* x, + DType val) { int tid = threadIdx.x; - int blockSize = blockDim.x; + int block_size = blockDim.x; - if (blockSize >= 64) { + if (block_size >= 64) { x[tid] = val; __syncthreads(); } #pragma unroll - for (int i = (blockSize >> 1); i >= 64; i >>= 1) { + for (int i = (block_size >> 1); i >= 64; i >>= 1) { if (tid < i) x[tid] = x[tid] + x[tid+i]; __syncthreads(); } DType final; - if (tid < 32) { - if (blockSize >= 64) + if (block_size >= 64) final = x[tid] + x[tid+32]; else final = val; @@ -88,22 +87,22 @@ __global__ void MultiSumSqKernel(int chunk_size, const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size; const int n = param.sizes[tensor_loc] - chunk_len; const DType* x = param.addresses[tensor_loc] + chunk_len; - const auto iMax = n <= chunk_size? n : chunk_size; + const auto i_max = n <= chunk_size ? n : chunk_size; __shared__ float vals[512]; // Non-divergent exit condition for __syncthreads, not necessary here float val = 0; for (int i_start = 0; - i_start < iMax; + i_start < i_max; i_start += blockDim.x * ILP) { int i = i_start + threadIdx.x; #pragma unroll - for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) { + for (int ii = 0; ii < ILP && i < i_max; ++ii, i += blockDim.x) { const auto incoming_val = static_cast(x[i]); val += incoming_val * incoming_val; } } - const float final = reduce_block_into_lanes(vals, val); + const float final = ReduceBlockIntoLanes(vals, val); if (threadIdx.x == 0) { block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor + @@ -121,14 +120,14 @@ __global__ void GlobalReductionKernel(MultiSumSqKernelParam param, for (int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x) val += reductions_this_tensor[i]; - float final = reduce_block_into_lanes(vals, val); + float final = ReduceBlockIntoLanes(vals, val); if (threadIdx.x == 0) output[blockIdx.x] = final; } template<> -void MultiSumSqRun(const std::vector &inputs, int nInputs, +void MultiSumSqRun(const std::vector &inputs, int n_inputs, float *out_ptr, const OpContext &ctx) { const int chunk_size = 32768; const int block_size = 512; @@ -139,26 +138,26 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs, MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, { MultiSumSqKernelParam param; // find max num of chunks in tensors - for (int t = 0; t < nInputs; t++) { + for (int t = 0; t < n_inputs; t++) { int chunks_this_tensor = (inputs[t].shape_.Size() + chunk_size - 1) / chunk_size; if (chunks_this_tensor > param.max_chunks_per_tensor) param.max_chunks_per_tensor = chunks_this_tensor; } // temporary storage for the reduction of each block - size_t workspace_size = nInputs * param.max_chunks_per_tensor * sizeof(float); + size_t workspace_size = n_inputs * param.max_chunks_per_tensor * sizeof(float); Tensor workspace = - ctx.requested[multi_sum_sq::kTempSpace].get_space_typed( - Shape1(workspace_size), s); + 
ctx.requested[multi_sum_sq::kTempSpace].get_space_typed( + Shape1(workspace_size), s); Tensor block_reductions(reinterpret_cast(&workspace[0]), - Shape1(nInputs * param.max_chunks_per_tensor), s); + Shape1(n_inputs * param.max_chunks_per_tensor), s); CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0, - nInputs * param.max_chunks_per_tensor* sizeof(float), + n_inputs * param.max_chunks_per_tensor* sizeof(float), stream)); int loc_block_info = 0; // position in param.block_to_tensor and param.block_to_chunck int loc_tensor_info = 0; // position in param.sizes and param.addresses int start_tensor_id = 0; - for (int t = 0; t < nInputs; t++, loc_tensor_info++) { // array index in inputs + for (int t = 0; t < n_inputs; t++, loc_tensor_info++) { // array index in inputs param.sizes[loc_tensor_info] = inputs[t].shape_.Size(); param.addresses[loc_tensor_info] = inputs[t].FlatTo2D(s).dptr_; const int chunks_this_tensor = (inputs[t].shape_.Size() - 1) / chunk_size; @@ -170,7 +169,7 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs, const bool last_curr_chunk = chunk == chunks_this_tensor; const bool tensors_full = last_curr_chunk && loc_tensor_info == (ARRAY_LIMIT-1); const bool blocks_full = (loc_block_info == BLOCK_LIMIT); - const bool last_chunk = last_curr_chunk && t == nInputs - 1; + const bool last_chunk = last_curr_chunk && t == n_inputs - 1; if (!(tensors_full || blocks_full || last_chunk)) continue; MultiSumSqKernel<<>> @@ -190,7 +189,7 @@ void MultiSumSqRun(const std::vector &inputs, int nInputs, } } // Global reduction - GlobalReductionKernel<<>> + GlobalReductionKernel<<>> (param, block_reductions.dptr_, out_ptr); }); } diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 839727fbbfd7..38ef34ccd6da 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -271,6 +271,36 @@ def test_fft(): def _make_ndarrays(input_list, ctx=mx.gpu(0)): return [mx.nd.array(arr, dtype=arr.dtype, ctx=ctx) for arr in input_list] +def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2): + values_arr = [np.random.rand(*shape).astype(dtype) * 10. for shape in shapes] + mx_vals = _make_ndarrays(values_arr, ctx=ctx) + sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes)) + sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes)) + # checks that operator is deterministic + assert np.array_equal(sum_sq.asnumpy(), sum_sq2.asnumpy()) + + ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr], + dtype='float32', ctx=ctx) + assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1) + +@with_seed() +def test_multi_sum_sq(): + min_nparam = 390 + max_nparam = 400 + min_dim = 50000 + max_dim = 3200000 + max_ndim = 1 + + dtypes = ['float16','float32', 'float64'] + for ctx in [mx.gpu(0)]: + for dtype in dtypes: + nparam = np.random.randint(min_nparam + 1, max_nparam + 1) + shapes = [np.random.randint(min_dim, max_dim + 1, size=max_ndim) for i in range(nparam)] + low_tol = ctx == mx.cpu(0) and ('float16'in [dtype]) + tol1 = 1e-3 if low_tol else 1e-5 + tol2 = 1e-6 if low_tol else 1e-7 + check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2) + def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2): weights_arr = [np.random.rand(*shape).astype(w_dtype) * 10. 
for shape in shapes] grads_arr = [np.random.rand(*shape).astype(g_dtype) for shape in shapes] @@ -310,36 +340,6 @@ def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2): ref_new_lrs[i] = lrs[i] assert_almost_equal(ref_new_lrs.asnumpy(), mx_new_lrs.asnumpy(), atol=tol2, rtol=tol2) -def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2): - values_arr = [np.random.rand(*shape).astype(dtype) * 10. for shape in shapes] - mx_vals = _make_ndarrays(values_arr, ctx=ctx) - sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes)) - sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes)) - # checks that operator is deterministic - assert np.array_equal(sum_sq.asnumpy(), sum_sq2.asnumpy()) - - ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr], - dtype='float32', ctx=ctx) - assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1) - -@with_seed() -def test_multi_sum_sq(): - min_nparam = 390 - max_nparam = 400 - mindim = 50000 - maxdim = 3200000 - maxndim = 1 - - dtypes = ['float16','float32', 'float64'] - for ctx in [mx.gpu(0)]: - for dtype in dtypes: - nparam = np.random.randint(min_nparam + 1, max_nparam + 1) - shapes = [np.random.randint(mindim, maxdim + 1, size=maxndim) for i in range(nparam)] - lowTol = ctx == mx.cpu(0) and ('float16'in [dtype]) - tol1 = 1e-3 if lowTol else 1e-5 - tol2 = 1e-6 if lowTol else 1e-7 - check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2) - @with_seed() def test_fast_lars(): min_nparam = 50 From 4c917d36038f483b7cc7804dd6744582de304ed5 Mon Sep 17 00:00:00 2001 From: moisesh Date: Thu, 12 Dec 2019 13:07:39 -0800 Subject: [PATCH 6/6] Reduce dimensions of tensors in the test --- tests/python/gpu/test_operator_gpu.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python/gpu/test_operator_gpu.py b/tests/python/gpu/test_operator_gpu.py index 591ba88b8592..7d23c2ca0aaf 100644 --- a/tests/python/gpu/test_operator_gpu.py +++ b/tests/python/gpu/test_operator_gpu.py @@ -285,10 +285,10 @@ def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2): @with_seed() def test_multi_sum_sq(): - min_nparam = 390 - max_nparam = 400 + min_nparam = 100 + max_nparam = 120 min_dim = 50000 - max_dim = 3200000 + max_dim = 100000 max_ndim = 1 dtypes = ['float16','float32', 'float64']