From adad7201e1cfe1c7975e20b310f1a4f344ff559f Mon Sep 17 00:00:00 2001
From: stu1130
Date: Fri, 14 Jun 2019 12:45:28 -0700
Subject: [PATCH] refactor & pass the tensor instead of tuple to kernel

---
 src/operator/numpy/random/np_multinomial_op.h | 42 ++++++-------------
 1 file changed, 12 insertions(+), 30 deletions(-)

diff --git a/src/operator/numpy/random/np_multinomial_op.h b/src/operator/numpy/random/np_multinomial_op.h
index 800cf8971062..2e60b7abaf46 100644
--- a/src/operator/numpy/random/np_multinomial_op.h
+++ b/src/operator/numpy/random/np_multinomial_op.h
@@ -99,32 +99,7 @@ inline bool NumpyMultinomialOpType(const nnvm::NodeAttrs& attrs,
   return true;
 }
 
-struct multinomial_kernel_from_tuple {
-  MSHADOW_XINLINE static void Map(int i,
-                                  const int num_exp,
-                                  const mxnet::Tuple<double>& pvals,
-                                  float* uniform,
-                                  int64_t* out) {
-    for (int j = 0; j < num_exp; ++j) {
-      double loc = static_cast<double>(uniform[i * num_exp + j]);
-      double acc = 0.0;
-      bool found = false;
-      for (int k = 0; k < pvals.ndim(); ++k) {
-        acc += pvals[k];
-        if (acc > loc) {
-          found = true;
-          out[i * pvals.ndim() + k] += 1;
-          break;
-        }
-      }
-      if (!found) {
-        out[i * pvals.ndim() + (pvals.ndim() - 1)] += 1;
-      }
-    }
-  }
-};
-
-struct multinomial_kernel_from_input {
+struct multinomial_kernel {
   template <typename DType>
   MSHADOW_XINLINE static void Map(int i,
                                   const int num_exp,
@@ -179,15 +154,22 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
   Kernel<set_zero, xpu>::Launch(s, outputs[0].Size(), outputs[0].dptr<int64_t>());
 
   if (param.pvals.has_value()) {
-    // check if sum of input(pvals) > 1.0
+    // create a tensor to copy the param.pvals tuple to avoid
+    // error: calling a __host__ function from a __host__ __device__ function is not allowed
+    Tensor<xpu, 1, double> pvals =
+      ctx.requested[1].get_space_typed<xpu, 1, double>(Shape1(prob_length), s);
+    double* pvals_ = pvals.dptr_;
+    // check if sum of input(pvals) > 1.0
     double sum = 0.0;
     for (int i = 0; i < prob_length; ++i) {
       sum += param.pvals.value()[i];
+      // copy the tuple to data for later kernel usage
+      pvals_[i] = param.pvals.value()[i];
       CHECK_LE(sum, 1.0) << "sum(pvals[:-1]) > 1.0";
     }
 
-    Kernel<multinomial_kernel_from_tuple, xpu>::Launch(
-      s, num_output, num_exp, param.pvals.value(), uniform.dptr_, outputs[0].dptr<int64_t>());
+    Kernel<multinomial_kernel, xpu>::Launch(
+      s, num_output, num_exp, prob_length, pvals_, uniform.dptr_, outputs[0].dptr<int64_t>());
   } else {
     MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
       // check if sum of input(pvals) > 1.0
@@ -198,7 +180,7 @@ void NumpyMultinomialForward(const nnvm::NodeAttrs& attrs,
         CHECK_LE(sum, 1.0) << "sum(pvals[:-1]) > 1.0";
       }
 
-      Kernel<multinomial_kernel_from_input, xpu>::Launch(
+      Kernel<multinomial_kernel, xpu>::Launch(
         s, num_output, num_exp, prob_length, inputs[0].dptr<DType>(), uniform.dptr_,
         outputs[0].dptr<int64_t>());
     });
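
Note (not part of the patch): the kernel deleted above and the one kept under the new name
multinomial_kernel implement the same inverse-CDF multinomial draw; the substantive change is
that the probabilities now reach the kernel as a raw double pointer plus a length, copied out of
the mxnet::Tuple on the host, because the Tuple accessors are host-only and cannot be called
from a __host__ __device__ Map body on the GPU. Below is a minimal, CPU-only C++ sketch of that
pattern for illustration: the Map signature and sampling loop mirror the patch, while main(),
the sample probabilities, and the plain std::vector buffers are invented for the example.

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    // Stand-in for multinomial_kernel::Map (same logic as the deleted
    // multinomial_kernel_from_tuple, but taking pvals as pointer + length):
    // output row i draws num_exp samples; each uniform draw is located in the
    // cumulative distribution over pvals[0..prob_length-1].
    static void Map(int i, int num_exp, int prob_length,
                    const double* pvals, const float* uniform, int64_t* out) {
      for (int j = 0; j < num_exp; ++j) {
        double loc = static_cast<double>(uniform[i * num_exp + j]);
        double acc = 0.0;
        bool found = false;
        for (int k = 0; k < prob_length; ++k) {
          acc += pvals[k];
          if (acc > loc) {  // the draw falls into bucket k
            found = true;
            out[i * prob_length + k] += 1;
            break;
          }
        }
        if (!found) {  // guard against round-off: credit the last bucket
          out[i * prob_length + (prob_length - 1)] += 1;
        }
      }
    }

    int main() {
      // Host-side probabilities; in the patch these are copied from the
      // param.pvals tuple into a requested workspace tensor before launch.
      const std::vector<double> pvals = {0.2, 0.3, 0.5};
      // Pre-drawn U(0,1) values: num_output rows of num_exp draws each.
      const std::vector<float> uniform = {0.1f, 0.25f, 0.9f, 0.6f};
      const int num_output = 2, num_exp = 2;
      const int prob_length = static_cast<int>(pvals.size());
      // Zero-initialized counts, as the set_zero kernel does in the patch.
      std::vector<int64_t> out(num_output * prob_length, 0);
      for (int i = 0; i < num_output; ++i)
        Map(i, num_exp, prob_length, pvals.data(), uniform.data(), out.data());
      for (int64_t c : out) std::printf("%lld ", static_cast<long long>(c));
      std::printf("\n");  // expected: 1 1 0 0 0 2
      return 0;
    }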