This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Multi_sum_sq review, AtomicAdd removal #17002

Merged
merged 7 commits on Dec 14, 2019
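
For context (an illustrative sketch, not part of the diff): `multi_sum_sq` takes a list of arrays and returns one float per array, the sum of its squared elements. The previous GPU implementation accumulated per-block partial sums into the output with `atomicAdd`, so the floating-point addition order could vary between runs; this PR replaces that with a deterministic two-stage reduction through a temp-space workspace. A minimal NumPy sketch of the operator's semantics (the function name here is hypothetical, mirroring the reference computation used in the test added below):

```python
import numpy as np

def multi_sum_sq_reference(arrays):
    """One float per array: the sum of its squared elements, accumulated in float32."""
    return np.array([np.sum(a.astype(np.float32) ** 2) for a in arrays],
                    dtype=np.float32)

# Example: three arrays of different shapes and dtypes.
arrays = [np.random.rand(5, 3).astype(np.float16),
          np.random.rand(7).astype(np.float32),
          np.random.rand(2, 2, 2).astype(np.float64)]
print(multi_sum_sq_reference(arrays))  # -> three sums of squares
```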
10 changes: 7 additions & 3 deletions src/operator/contrib/multi_sum_sq-inl.h
@@ -21,7 +21,7 @@
* Copyright (c) 2019 by Contributors
* \file multi_l2_norm-inl.h
* \brief vectorized L2 norm over multiple arrays operators
* \author Clement Fuji Tsang, Andrei Ivanov
* \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
*/


@@ -32,6 +32,10 @@
#include <vector>
#include "../operator_common.h"

namespace multi_sum_sq {
enum MultiSumSqUpdateResource {kTempSpace};
} // namespace multi_sum_sq

namespace mxnet {
namespace op {

@@ -80,7 +84,7 @@ inline bool MultiSumSqType(const NodeAttrs& attrs,

template<typename xpu>
void MultiSumSqRun(const std::vector<TBlob> &inputs, int nInputs,
float *out_ptr, mshadow::Stream<xpu> *s);
float *out_ptr, const OpContext &ctx);

template<typename xpu>
void MultiSumSq(const nnvm::NodeAttrs& attrs,
@@ -91,7 +95,7 @@ void MultiSumSq(const nnvm::NodeAttrs& attrs,
auto s = ctx.get_stream<xpu>();
const auto& p = dmlc::get<MultiSumSqParam>(attrs.parsed);
float* out_ptr = outputs[0].FlatTo2D<xpu, float>(s).dptr_;
MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, s);
MultiSumSqRun<xpu>(inputs, p.num_arrays, out_ptr, ctx);
}

} // namespace op
20 changes: 12 additions & 8 deletions src/operator/contrib/multi_sum_sq.cc
@@ -21,7 +21,7 @@
* Copyright (c) 2019 by Contributors
* \file multi_sum_sq.cc
* \brief vectorized sum or squared over multiple arrays operators
* \author Clement Fuji Tsang, Andrei Ivanov
* \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
*/

#include "./multi_sum_sq-inl.h"
@@ -52,31 +52,35 @@ NNVM_REGISTER_OP(multi_sum_sq)
return ret;
})
.set_attr<FCompute>("FCompute<cpu>", MultiSumSq<cpu>)
.set_attr<FResourceRequest>("FResourceRequest",
[](const NodeAttrs& attrs) {
return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.add_argument("data", "NDArray-or-Symbol[]", "Arrays")
.add_arguments(MultiSumSqParam::__FIELDS__());

template<typename DType>
inline void CalcSumSq(const std::vector<TBlob> &inputs, int nInputs,
inline void CalcSumSq(const std::vector<TBlob> &inputs, int n_inputs,
float *out_ptr, mshadow::Stream<cpu> *s) {
int i;
size_t j;
#pragma omp parallel for private(i, j)
for (i = 0; i < nInputs; ++i) { // array index in inputs
for (i = 0; i < n_inputs; ++i) { // array index in inputs
float sum = 0;
const auto address = inputs[i].FlatTo2D<cpu, DType>(s).dptr_;
const auto jMax = inputs[i].shape_.Size();
for (j = 0; j < jMax; ++j)
const auto j_max = inputs[i].shape_.Size();
for (j = 0; j < j_max; ++j)
sum += address[j] * address[j];

out_ptr[i] = sum;
}
}

template<>
void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int nInputs,
float *out_ptr, mshadow::Stream<cpu> *s) {
void MultiSumSqRun<cpu>(const std::vector<TBlob> &inputs, int n_inputs,
float *out_ptr, const OpContext &ctx) {
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType,
CalcSumSq<DType>(inputs, nInputs, out_ptr, s);
CalcSumSq<DType>(inputs, n_inputs, out_ptr, ctx.get_stream<cpu>());
)
}

110 changes: 69 additions & 41 deletions src/operator/contrib/multi_sum_sq.cu
@@ -21,7 +21,7 @@
* Copyright (c) 2019 by Contributors
* \file multi_sum_sq.cu
* \brief vectorized sums of squares norm over multiple arrays operators
* \author Clement Fuji Tsang, Andrei Ivanov
* \author Clement Fuji Tsang, Andrei Ivanov, Moises Hernandez
*/
#include "./multi_sum_sq-inl.h"
#include <cub/cub.cuh>
@@ -43,96 +43,121 @@ struct MultiSumSqKernelParam {
int sizes[ARRAY_LIMIT];
unsigned char block_to_tensor[BLOCK_LIMIT];
int block_to_chunk[BLOCK_LIMIT];
int max_chunks_per_tensor = -1;
};

template<typename DType>
__device__ __forceinline__ DType reduce_block_into_lanes(DType* x,
DType val,
int lanes = 1,
bool share_result = false) {
int tid = threadIdx.x + threadIdx.y * blockDim.x;
int blockSize = blockDim.x * blockDim.y; // blockSize is intended to be a multiple of 32.

if (blockSize >= 64) {
__device__ __forceinline__ DType ReduceBlockIntoLanes(DType* x,
DType val) {
int tid = threadIdx.x;
int block_size = blockDim.x;

if (block_size >= 64) {
x[tid] = val;
__syncthreads();
}

#pragma unroll
for (int i = (blockSize >> 1); i >= 64; i >>= 1) {
for (int i = (block_size >> 1); i >= 64; i >>= 1) {
if (tid < i)
x[tid] = x[tid] + x[tid+i];
__syncthreads();
}

DType final;

if (tid < 32) {
if (blockSize >= 64)
if (block_size >= 64)
final = x[tid] + x[tid+32];
else
final = val;
// __SYNCWARP();

#pragma unroll
for (int i = 16; i >= lanes; i >>= 1)
for (int i = 16; i >= 1; i >>= 1)
final = final + __shfl_down_sync(0xffffffff, final, i);
}

if (share_result) {
if (tid < lanes)
x[tid] = final; // EpilogueOp
// Make sure the smem result is visible to all warps.
__syncthreads();
}

return final;
}

template<typename DType>
__global__ void MultiSumSqKernel(int chunk_size,
MultiSumSqKernelParam<DType> param,
float* output) {
float* block_reductions,
int start_tensor_id) {
const int tensor_loc = param.block_to_tensor[blockIdx.x];
const int chunk_len = param.block_to_chunk[blockIdx.x] * chunk_size;
const int n = param.sizes[tensor_loc] - chunk_len;
const DType* x = param.addresses[tensor_loc] + chunk_len;
const auto iMax = n <= chunk_size? n : chunk_size;
const auto i_max = n <= chunk_size ? n : chunk_size;
__shared__ float vals[512];

// Non-divergent exit condition for __syncthreads, not necessary here
float val = 0;
for (int i_start = 0;
i_start < iMax;
i_start < i_max;
i_start += blockDim.x * ILP) {
int i = i_start + threadIdx.x;
// #pragma unroll
for (int ii = 0; ii < ILP && i < iMax; ++ii, i += blockDim.x) {
#pragma unroll
for (int ii = 0; ii < ILP && i < i_max; ++ii, i += blockDim.x) {
const auto incoming_val = static_cast<float>(x[i]);
val += incoming_val * incoming_val;
}
}
const float final = ReduceBlockIntoLanes(vals, val);

if (threadIdx.x == 0) {
block_reductions[(start_tensor_id + tensor_loc) * param.max_chunks_per_tensor +
param.block_to_chunk[blockIdx.x]] = final;
[Contributor review comment] Maybe we should change the variable name here? `final` specifies that a virtual function cannot be overridden in a derived class.

}
}

template<typename DType>
__global__ void GlobalReductionKernel(MultiSumSqKernelParam<DType> param,
float* block_reductions,
float* output) {
__shared__ float vals[512];
float* reductions_this_tensor = block_reductions + blockIdx.x * param.max_chunks_per_tensor;
float val = 0;
for (int i = threadIdx.x; i < param.max_chunks_per_tensor; i += blockDim.x)
val += reductions_this_tensor[i];

float final = ReduceBlockIntoLanes(vals, val);

const float final = reduce_block_into_lanes(vals, val);
if (threadIdx.x == 0)
atomicAdd(output + tensor_loc, final);
output[blockIdx.x] = final;
}

template<>
void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int nInputs,
float *out_ptr, mshadow::Stream<gpu> *s) {
void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int n_inputs,
float *out_ptr, const OpContext &ctx) {
const int chunk_size = 32768;
const int block_size = 512;
using namespace mxnet_op;
auto s = ctx.get_stream<gpu>();
auto stream = mshadow::Stream<gpu>::GetStream(s);
CUDA_CALL(cudaMemsetAsync(out_ptr, 0, nInputs * sizeof(float), stream));

MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
MultiSumSqKernelParam<DType> param;
// find max num of chunks in tensors
for (int t = 0; t < n_inputs; t++) {
int chunks_this_tensor = (inputs[t].shape_.Size() + chunk_size - 1) / chunk_size;
if (chunks_this_tensor > param.max_chunks_per_tensor)
param.max_chunks_per_tensor = chunks_this_tensor;
}
// temporary storage for the reduction of each block
size_t workspace_size = n_inputs * param.max_chunks_per_tensor * sizeof(float);
Tensor<gpu, 1, char> workspace =
ctx.requested[multi_sum_sq::kTempSpace].get_space_typed<gpu, 1, char>(
Shape1(workspace_size), s);
Tensor<gpu, 1, float> block_reductions(reinterpret_cast<float*>(&workspace[0]),
Shape1(n_inputs * param.max_chunks_per_tensor), s);
CUDA_CALL(cudaMemsetAsync(block_reductions.dptr_, 0,
n_inputs * param.max_chunks_per_tensor * sizeof(float),
stream));

int loc_block_info = 0; // position in param.block_to_tensor and param.block_to_chunck
int loc_tensor_info = 0; // position in param.sizes and param.addresses
int output_offset = 0; // array index of the first block pointed on by param.addresses
for (int t = 0; t < nInputs; t++, loc_tensor_info++) { // array index in inputs
int start_tensor_id = 0;
for (int t = 0; t < n_inputs; t++, loc_tensor_info++) { // array index in inputs
param.sizes[loc_tensor_info] = inputs[t].shape_.Size();
param.addresses[loc_tensor_info] = inputs[t].FlatTo2D<gpu, DType>(s).dptr_;
const int chunks_this_tensor = (inputs[t].shape_.Size() - 1) / chunk_size;
@@ -142,27 +167,30 @@ void MultiSumSqRun<gpu>(const std::vector<TBlob> &inputs, int nInputs,
loc_block_info++;

const bool last_curr_chunk = chunk == chunks_this_tensor;
const bool tensors_full = last_curr_chunk && loc_tensor_info == 109;
const bool blocks_full = (loc_block_info == 320);
const bool last_chunk = last_curr_chunk && t == nInputs - 1;
const bool tensors_full = last_curr_chunk && loc_tensor_info == (ARRAY_LIMIT-1);
const bool blocks_full = (loc_block_info == BLOCK_LIMIT);
const bool last_chunk = last_curr_chunk && t == n_inputs - 1;
if (!(tensors_full || blocks_full || last_chunk))
continue;

MultiSumSqKernel<<<loc_block_info, block_size, 0, stream>>>
(chunk_size, param, out_ptr + output_offset);
(chunk_size, param, block_reductions.dptr_, start_tensor_id);
MSHADOW_CUDA_POST_KERNEL_CHECK(MultiSumSqKernel);

loc_block_info = 0;
if (last_curr_chunk) { // if you start from a new tensor
loc_tensor_info = -1;
output_offset = t + 1;
start_tensor_id = t + 1;
} else { // if you start from the same tensor
param.sizes[0] = param.sizes[loc_tensor_info];
param.addresses[0] = param.addresses[loc_tensor_info];
loc_tensor_info = 0;
output_offset = t;
start_tensor_id = t;
}
}
}
// Global reduction
GlobalReductionKernel<<<n_inputs, block_size, 0, stream>>>
(param, block_reductions.dptr_, out_ptr);
});
}

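A rough NumPy sketch of the reduction structure the two kernels above implement (function and variable names here are illustrative, not the kernel's API): stage one writes one partial sum per (tensor, chunk) into a zero-initialized workspace of shape `[n_inputs, max_chunks_per_tensor]`, which is what `MultiSumSqKernel` stores in `block_reductions`; stage two sums each tensor's row, as `GlobalReductionKernel` does with one block per tensor. Because every partial lands in a fixed slot and the final sum follows a fixed order, the result no longer depends on block scheduling the way the old `atomicAdd` accumulation did.

```python
import numpy as np

def multi_sum_sq_two_stage(arrays, chunk_size=32768):
    """Illustrative two-stage reduction: per-chunk partial sums, then per-tensor totals."""
    n_inputs = len(arrays)
    max_chunks = max((a.size + chunk_size - 1) // chunk_size for a in arrays)

    # Stage 1: one partial sum per (tensor, chunk) in a zero-initialized workspace.
    block_reductions = np.zeros((n_inputs, max_chunks), dtype=np.float32)
    for t, a in enumerate(arrays):
        flat = a.reshape(-1).astype(np.float32)
        n_chunks = (a.size + chunk_size - 1) // chunk_size
        for c in range(n_chunks):
            chunk = flat[c * chunk_size:(c + 1) * chunk_size]
            block_reductions[t, c] = np.sum(chunk * chunk)

    # Stage 2: reduce each tensor's row of partials in a fixed order.
    return block_reductions.sum(axis=1)
```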
30 changes: 30 additions & 0 deletions tests/python/gpu/test_operator_gpu.py
@@ -271,6 +271,36 @@ def test_fft():
def _make_ndarrays(input_list, ctx=mx.gpu(0)):
return [mx.nd.array(arr, dtype=arr.dtype, ctx=ctx) for arr in input_list]

def check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2):
values_arr = [np.random.rand(*shape).astype(dtype) * 10. for shape in shapes]
mx_vals = _make_ndarrays(values_arr, ctx=ctx)
sum_sq = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
sum_sq2 = mx.nd.multi_sum_sq(*mx_vals, num_arrays=len(shapes))
# checks that operator is deterministic
assert np.array_equal(sum_sq.asnumpy(), sum_sq2.asnumpy())

ref_sum_sq = mx.nd.array([(v.astype('float32') ** 2).sum() for v in values_arr],
dtype='float32', ctx=ctx)
assert_almost_equal(ref_sum_sq.asnumpy(), sum_sq.asnumpy(), atol=tol1, rtol=tol1)

@with_seed()
def test_multi_sum_sq():
min_nparam = 100
max_nparam = 120
min_dim = 50000
max_dim = 100000
max_ndim = 1

dtypes = ['float16','float32', 'float64']
for ctx in [mx.gpu(0)]:
for dtype in dtypes:
nparam = np.random.randint(min_nparam + 1, max_nparam + 1)
shapes = [np.random.randint(min_dim, max_dim + 1, size=max_ndim) for i in range(nparam)]
low_tol = ctx == mx.cpu(0) and ('float16' in [dtype])
tol1 = 1e-3 if low_tol else 1e-5
tol2 = 1e-6 if low_tol else 1e-7
check_multi_sum_sq(dtype, shapes, ctx, tol1, tol2)

def check_fast_lars(w_dtype, g_dtype, shapes, ctx, tol1, tol2):
weights_arr = [np.random.rand(*shape).astype(w_dtype) * 10. for shape in shapes]
grads_arr = [np.random.rand(*shape).astype(g_dtype) for shape in shapes]
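
For reference, a minimal usage example of the operator exercised by `test_multi_sum_sq` above. This is a sketch assuming a GPU build of MXNet in which the contrib operator `multi_sum_sq` is registered, as in the test:

```python
import mxnet as mx
import numpy as np

# Two arrays on the GPU; multi_sum_sq returns one sum of squares per input array.
arrays = [mx.nd.array(np.random.rand(1000), ctx=mx.gpu(0)),
          mx.nd.array(np.random.rand(50, 20), ctx=mx.gpu(0))]
out = mx.nd.multi_sum_sq(*arrays, num_arrays=len(arrays))
print(out.asnumpy())  # e.g. [sum(x0**2), sum(x1**2)]
```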