diff --git a/src/operator/tensor/elemwise_binary_op_basic.cu b/src/operator/tensor/elemwise_binary_op_basic.cu
index 3f7a0a1d574f..e00319358a8c 100644
--- a/src/operator/tensor/elemwise_binary_op_basic.cu
+++ b/src/operator/tensor/elemwise_binary_op_basic.cu
@@ -219,11 +219,11 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream *s,
 }
 
 NNVM_REGISTER_OP(elemwise_add)
-.set_attr("FCompute", ComputeWithHalf2)
+.set_attr("FCompute", VectorizedCompute)
 .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx);
 
 NNVM_REGISTER_OP(_grad_add)
-.set_attr("FCompute", ComputeWithHalf2);
+.set_attr("FCompute", VectorizedCompute);
 
 NNVM_REGISTER_OP(_backward_add)
 .set_attr("FCompute",
@@ -231,7 +231,7 @@ NNVM_REGISTER_OP(_backward_add)
                                                  mshadow_op::identity>);
 
 NNVM_REGISTER_OP(elemwise_sub)
-.set_attr("FCompute", ComputeWithHalf2)
+.set_attr("FCompute", VectorizedCompute)
 .set_attr("FComputeEx", ElemwiseBinaryOp::ComputeEx);
 
 NNVM_REGISTER_OP(_backward_sub)
@@ -240,7 +240,7 @@ NNVM_REGISTER_OP(_backward_sub)
                                                  mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr("FCompute", ComputeWithHalf2)
+.set_attr("FCompute", VectorizedCompute)
 .set_attr("FComputeEx",
           ElemwiseBinaryOp::ComputeDnsLRValueEx);
 
@@ -251,7 +251,7 @@ NNVM_REGISTER_OP(_backward_mul)
 
 NNVM_REGISTER_OP(elemwise_div)
 .set_attr("FCompute",
-  ComputeWithHalf2);
+  VectorizedCompute);
 
 NNVM_REGISTER_OP(_backward_div)
 .set_attr("FCompute",
@@ -259,7 +259,7 @@ NNVM_REGISTER_OP(_backward_div)
                                                  mshadow_op::div_rgrad>);
 
 NNVM_REGISTER_OP(_mod)
-.set_attr("FCompute", ComputeWithHalf2);
+.set_attr("FCompute", VectorizedCompute);
 
 NNVM_REGISTER_OP(_backward_mod)
 .set_attr("FCompute",
diff --git a/src/operator/tensor/elemwise_op.cuh b/src/operator/tensor/elemwise_op.cuh
index 091d2c5ab091..5d83c50312ee 100644
--- a/src/operator/tensor/elemwise_op.cuh
+++ b/src/operator/tensor/elemwise_op.cuh
@@ -28,6 +28,7 @@
 #include 
 
 #include "../operator_common.h"
+#include "../../common/cuda_utils.h"
 
 #include 
 
@@ -49,6 +50,29 @@ class VectorizedStorage {
   } scratch_;
 };
 
+template 
+MSHADOW_XINLINE void ldg(LType* dst, const LType* src) {
+  *dst = *src;
+}
+
+template <>
+MSHADOW_XINLINE void ldg(double* dst, const double* src) {
+  double temp;
+  asm volatile ("ld.global.f64 %0, [%1];" :
+                "=d"(temp) :
+                "l"(src));
+  *dst = temp;
+}
+
+/*template <>*/
+/*MSHADOW_XINLINE void ldg(uint64_t* dst, const uint64_t* src) {*/
+  /*uint64_t temp;*/
+  /*asm volatile ("ld.global.u64 %0, [%1];" :*/
+                /*"=l"(temp) :*/
+                /*"l"(src));*/
+  /**dst = temp;*/
+/*}*/
+
 template 
 class VectorizedAccessor {
  public:
@@ -80,10 +104,12 @@ class VectorizedAccessor {
 
   MSHADOW_XINLINE void load(const index_t id, const index_t N) {
     if (aligned) {
-      storage_.scratch_.aligned = aligned_ptr_[id];
+      ldg::type>(&(storage_.scratch_.aligned),
+                 aligned_ptr_ + id);
     } else {
       if (id > 0 && id < n_elems_ - 1) {
-        storage_.scratch_.aligned = aligned_ptr_[id];
+        ldg::type>(&(storage_.scratch_.aligned),
+                   aligned_ptr_ + id);
       } else {
 #pragma unroll
         for (int j = 0; j < storage_.nvec; ++j) {
@@ -203,11 +229,11 @@ size_t minthree(const size_t a, const size_t b, const size_t c) {
 }
 
 }  // namespace
 
 template 
-void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
-                      const OpContext &ctx,
-                      const std::vector &inputs,
-                      const std::vector &req,
-                      const std::vector &outputs) {
+void VectorizedCompute(const nnvm::NodeAttrs &attrs,
+                       const OpContext &ctx,
+                       const std::vector &inputs,
+                       const std::vector &req,
+                       const std::vector &outputs) {
   using namespace mxnet_op;
   if (req[0] == kNullOp) return;
   Stream *s = ctx.get_stream();
@@ -226,7 +252,7 @@ void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
     });
   } else {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      using LType = double;
+      using LType = uint4;
       static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type");
       if (outputs[0].Size() != 0) {
         cudaStream_t stream = mshadow::Stream::GetStream(s);
@@ -234,7 +260,8 @@ void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
         VectorizedLoader l(outputs[0].dptr(), outputs[0].Size());
         size_t num_elements = l.num_aligned_elements();
         constexpr int threads = 512;
-        index_t blocks = (num_elements + threads - 1) / threads;
+        index_t blocks = std::min(static_cast((num_elements + threads - 1) / threads),
+                                  65535);
         auto align = CheckAlignment({outputs[0].dptr(),
                                      inputs[0].dptr(),
                                      inputs[1].dptr()});
@@ -252,6 +279,9 @@ void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                           inputs[1].dptr(),
                                           outputs[0].Size());
         } else {
+          index_t blocks = std::min(static_cast((outputs[0].Size() + threads - 1) /
+                                    threads),
+                                    65535);
           // If the pointers are aligned differently we cannot vectorize
           VectorizedElementwiseKernel
             <<>>(outputs[0].dptr(),
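
Note (illustration, not part of the patch): the sketch below shows the idea behind the two main changes, switching the load type to uint4 for 128-bit vectorized loads and capping the grid at 65535 blocks while a grid-stride loop covers the rest of the tensor. It is not MXNet's VectorizedCompute/VectorizedAccessor machinery; the kernel name vectorized_add and all parameters are hypothetical, and it assumes float data, a length divisible by four, and 16-byte-aligned pointers.

// vectorized_add_sketch.cu -- standalone illustration only.
#include <cuda_runtime.h>
#include <algorithm>
#include <cstdio>

// Each thread issues one 128-bit (uint4) load per operand, adds the four
// packed floats, and writes one 128-bit store. The grid-stride loop lets a
// capped grid (<= 65535 blocks) cover arbitrarily large arrays.
__global__ void vectorized_add(const float* __restrict__ a,
                               const float* __restrict__ b,
                               float* __restrict__ out,
                               size_t n_vec) {  // number of uint4 elements
  constexpr int nvec = sizeof(uint4) / sizeof(float);  // 4 floats per load
  for (size_t i = blockIdx.x * static_cast<size_t>(blockDim.x) + threadIdx.x;
       i < n_vec;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    uint4 va = reinterpret_cast<const uint4*>(a)[i];  // one 128-bit load
    uint4 vb = reinterpret_cast<const uint4*>(b)[i];
    uint4 vo;
    const float* fa = reinterpret_cast<const float*>(&va);
    const float* fb = reinterpret_cast<const float*>(&vb);
    float* fo = reinterpret_cast<float*>(&vo);
#pragma unroll
    for (int j = 0; j < nvec; ++j) {
      fo[j] = fa[j] + fb[j];
    }
    reinterpret_cast<uint4*>(out)[i] = vo;  // one 128-bit store
  }
}

int main() {
  const size_t n = 1 << 20;   // assumed divisible by 4; cudaMalloc pointers are 16B-aligned
  const size_t n_vec = n / 4;
  float *a, *b, *out;
  cudaMalloc(&a, n * sizeof(float));
  cudaMalloc(&b, n * sizeof(float));
  cudaMalloc(&out, n * sizeof(float));
  cudaMemset(a, 0, n * sizeof(float));
  cudaMemset(b, 0, n * sizeof(float));

  constexpr int threads = 512;
  // Cap the grid the same way the patch does; the grid-stride loop handles the rest.
  int blocks = static_cast<int>(std::min<size_t>((n_vec + threads - 1) / threads, 65535));
  vectorized_add<<<blocks, threads>>>(a, b, out, n_vec);
  cudaDeviceSynchronize();
  printf("launched %d blocks of %d threads\n", blocks, threads);

  cudaFree(a);
  cudaFree(b);
  cudaFree(out);
  return 0;
}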