
Commit

More generalization
ptrendx committed Mar 5, 2020
1 parent 2fe0eaf commit 6b89506
Showing 2 changed files with 45 additions and 15 deletions.
12 changes: 6 additions & 6 deletions src/operator/tensor/elemwise_binary_op_basic.cu
@@ -219,19 +219,19 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<gpu> *s,
 }
 
 NNVM_REGISTER_OP(elemwise_add)
-.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::plus>)
+.set_attr<FCompute>("FCompute<gpu>", VectorizedCompute<op::mshadow_op::plus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_grad_add)
-.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", VectorizedCompute<op::mshadow_op::plus>);
 
 NNVM_REGISTER_OP(_backward_add)
 .set_attr<FCompute>("FCompute<gpu>",
                     ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, mshadow_op::identity,
                                                                mshadow_op::identity>);
 
 NNVM_REGISTER_OP(elemwise_sub)
-.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::minus>)
+.set_attr<FCompute>("FCompute<gpu>", VectorizedCompute<op::mshadow_op::minus>)
 .set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::minus>);
 
 NNVM_REGISTER_OP(_backward_sub)
@@ -240,7 +240,7 @@ NNVM_REGISTER_OP(_backward_sub)
                                                   mshadow_op::negation>);
 
 NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::mul>)
+.set_attr<FCompute>("FCompute<gpu>", VectorizedCompute<op::mshadow_op::mul>)
 .set_attr<FComputeEx>("FComputeEx<gpu>",
                       ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);
 
@@ -251,15 +251,15 @@ NNVM_REGISTER_OP(_backward_mul)
 
 NNVM_REGISTER_OP(elemwise_div)
 .set_attr<FCompute>("FCompute<gpu>",
-                    ComputeWithHalf2<op::mshadow_op::div>);
+                    VectorizedCompute<op::mshadow_op::div>);
 
 NNVM_REGISTER_OP(_backward_div)
 .set_attr<FCompute>("FCompute<gpu>",
                     ElemwiseBinaryOp::BackwardUseInWithHalf2<gpu, mshadow_op::div_grad,
                                                              mshadow_op::div_rgrad>);
 
 NNVM_REGISTER_OP(_mod)
-.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<mshadow_op::mod>);
+.set_attr<FCompute>("FCompute<gpu>", VectorizedCompute<mshadow_op::mod>);
 
 NNVM_REGISTER_OP(_backward_mod)
 .set_attr<FCompute>("FCompute<gpu>",
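For context: these registrations swap the GPU FCompute implementation from ComputeWithHalf2 to the more general VectorizedCompute defined in elemwise_op.cuh (second file below), which reads and writes through a wider load type so each thread moves several operands per memory transaction. The sketch below only illustrates that idea, with made-up names; it assumes LType-aligned pointers and a size divisible by nvec, whereas the real VectorizedElementwiseKernel also handles misaligned and partial vectors.

// Illustrative sketch (not the commit's kernel): a binary elementwise kernel
// vectorized by loading through a wide type LType.
#include <cuda_runtime.h>
#include <cstddef>

template <typename DType, typename LType, typename OP>
__global__ void vectorized_binary_kernel(DType* out, const DType* in0,
                                         const DType* in1, size_t N) {
  constexpr int nvec = sizeof(LType) / sizeof(DType);  // elements per wide load
  const size_t num_vec = N / nvec;                     // assumes N % nvec == 0
  for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < num_vec;
       i += static_cast<size_t>(gridDim.x) * blockDim.x) {
    // One wide load per input instead of nvec scalar loads.
    LType a = reinterpret_cast<const LType*>(in0)[i];
    LType b = reinterpret_cast<const LType*>(in1)[i];
    LType c;
    DType* pa = reinterpret_cast<DType*>(&a);
    DType* pb = reinterpret_cast<DType*>(&b);
    DType* pc = reinterpret_cast<DType*>(&c);
#pragma unroll
    for (int j = 0; j < nvec; ++j) {
      pc[j] = OP::Map(pa[j], pb[j]);  // e.g. op::mshadow_op::plus::Map
    }
    reinterpret_cast<LType*>(out)[i] = c;  // one wide store
  }
}

With DType = half (2 bytes) and LType = uint4 (16 bytes, the choice made in elemwise_op.cuh below), nvec is 16 / 2 = 8 elements per load.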
48 changes: 39 additions & 9 deletions src/operator/tensor/elemwise_op.cuh
@@ -28,6 +28,7 @@
 
 #include <cuda_runtime.h>
 #include "../operator_common.h"
+#include "../../common/cuda_utils.h"
 
 #include <vector>
 
@@ -49,6 +50,29 @@ class VectorizedStorage {
   } scratch_;
 };
 
+template <typename LType>
+MSHADOW_XINLINE void ldg(LType* dst, const LType* src) {
+  *dst = *src;
+}
+
+template <>
+MSHADOW_XINLINE void ldg(double* dst, const double* src) {
+  double temp;
+  asm volatile ("ld.global.f64 %0, [%1];" :
+                "=d"(temp) :
+                "l"(src));
+  *dst = temp;
+}
+
+/*template <>*/
+/*MSHADOW_XINLINE void ldg(uint64_t* dst, const uint64_t* src) {*/
+/*uint64_t temp;*/
+/*asm volatile ("ld.global.u64 %0, [%1];" :*/
+/*"=l"(temp) :*/
+/*"l"(src));*/
+/**dst = temp;*/
+/*}*/
+
 template <typename DType, typename LType, bool aligned = false>
 class VectorizedAccessor {
  public:
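The new ldg helper added in the hunk above defaults to a plain dereference, while the double specialization issues the load as an explicit ld.global.f64 PTX instruction. The same pattern would extend to other scalar types; for example, a hypothetical float counterpart (not part of this commit) could read:

// Hypothetical float analogue of the ldg specialization above (not in this
// commit); "=f" is the inline-PTX constraint for a 32-bit float register.
template <>
MSHADOW_XINLINE void ldg(float* dst, const float* src) {
  float temp;
  asm volatile ("ld.global.f32 %0, [%1];" :
                "=f"(temp) :
                "l"(src));
  *dst = temp;
}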
@@ -80,10 +104,12 @@ class VectorizedAccessor
 
   MSHADOW_XINLINE void load(const index_t id, const index_t N) {
     if (aligned) {
-      storage_.scratch_.aligned = aligned_ptr_[id];
+      ldg<typename std::remove_const<LType>::type>(&(storage_.scratch_.aligned),
+                                                   aligned_ptr_ + id);
     } else {
       if (id > 0 && id < n_elems_ - 1) {
-        storage_.scratch_.aligned = aligned_ptr_[id];
+        ldg<typename std::remove_const<LType>::type>(&(storage_.scratch_.aligned),
+                                                     aligned_ptr_ + id);
       } else {
 #pragma unroll
         for (int j = 0; j < storage_.nvec; ++j) {
Expand Down Expand Up @@ -203,11 +229,11 @@ size_t minthree(const size_t a, const size_t b, const size_t c) {
} // namespace

template<typename OP>
void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
void VectorizedCompute(const nnvm::NodeAttrs &attrs,
const OpContext &ctx,
const std::vector<TBlob> &inputs,
const std::vector<OpReqType> &req,
const std::vector<TBlob> &outputs) {
using namespace mxnet_op;
if (req[0] == kNullOp) return;
Stream<gpu> *s = ctx.get_stream<gpu>();
@@ -226,15 +252,16 @@ void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
       });
   } else {
     MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
-      using LType = double;
+      using LType = uint4;
+      static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type");
       if (outputs[0].Size() != 0) {
         cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
         constexpr int nvec = sizeof(LType) / sizeof(DType);
         VectorizedLoader<DType, LType> l(outputs[0].dptr<DType>(), outputs[0].Size());
         size_t num_elements = l.num_aligned_elements();
         constexpr int threads = 512;
-        index_t blocks = (num_elements + threads - 1) / threads;
+        index_t blocks = std::min(static_cast<int>((num_elements + threads - 1) / threads),
+                                  65535);
         auto align = CheckAlignment<LType, DType>({outputs[0].dptr<DType>(),
                                                    inputs[0].dptr<DType>(),
                                                    inputs[1].dptr<DType>()});
@@ -252,6 +279,9 @@ void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                                              inputs[1].dptr<DType>(),
                                              outputs[0].Size());
         } else {
+          index_t blocks = std::min(static_cast<int>((outputs[0].Size() + threads - 1) /
+                                                      threads),
+                                    65535);
           // If the pointers are aligned differently we cannot vectorize
           VectorizedElementwiseKernel<true, DType, DType, OP, Req>
             <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
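Both new std::min(..., 65535) expressions cap the grid size of the launch. A small worked example of that arithmetic (values are hypothetical, and this assumes each thread then loops over the remaining work, e.g. with a grid-stride loop, since VectorizedElementwiseKernel's body is not shown in this diff):

// Illustrative launch-size arithmetic for the vectorized path above.
#include <algorithm>
#include <cstddef>
#include <cstdio>

int main() {
  constexpr int size_dtype = 2;                       // e.g. DType = half
  constexpr int size_ltype = 16;                      // LType = uint4
  constexpr int nvec = size_ltype / size_dtype;       // 8 elements per wide load
  constexpr int threads = 512;

  const size_t tensor_size = size_t{1} << 30;                        // 2^30 elements
  const size_t num_elements = tensor_size / nvec;                    // wide loads needed
  const size_t uncapped = (num_elements + threads - 1) / threads;    // 262144 blocks
  const size_t blocks = std::min<size_t>(uncapped, 65535);           // capped grid size

  // 262144 / 65535 is roughly 4, so each thread covers about four wide loads.
  std::printf("nvec=%d uncapped=%zu blocks=%zu\n", nvec, uncapped, blocks);
  return 0;
}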
