Commit 2fe0eaf: Vectorized loads for binary elemwise kernel

ptrendx committed Mar 5, 2020 (1 parent: 64d8e0f)
Showing 3 changed files with 280 additions and 31 deletions.
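The change replaces the generic per-element GPU kernel for binary elementwise ops with one that moves a wide machine word per memory transaction and processes the several narrow elements packed inside it. A minimal standalone sketch of the technique (hypothetical kernel, not part of this commit; assumes 16-byte-aligned pointers and a length divisible by 4):

// Each thread performs one 16-byte load and store carrying four floats.
__global__ void add_one_vectorized(float* out, const float* in, int n_wide) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n_wide) {
    float4 v = reinterpret_cast<const float4*>(in)[tid];  // one wide load
    v.x += 1.f; v.y += 1.f; v.z += 1.f; v.w += 1.f;       // per-element op
    reinterpret_cast<float4*>(out)[tid] = v;              // one wide store
  }
}
// Usage sketch: add_one_vectorized<<<(N / 4 + 511) / 512, 512>>>(out, in, N / 4);

The new elemwise_op.cuh below generalizes this idea with a union-based accessor that tolerates arbitrary alignment and sizes.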
24 changes: 0 additions & 24 deletions src/operator/tensor/elemwise_binary_op.h
@@ -574,30 +574,6 @@ class ElemwiseBinaryOp : public OpBase {
    });
  }

-  template<typename xpu, typename OP>
-  static void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
-                               const OpContext &ctx,
-                               const std::vector<TBlob> &inputs,
-                               const std::vector<OpReqType> &req,
-                               const std::vector<TBlob> &outputs) {
-    using namespace mxnet_op;
-    if (req[0] == kNullOp) return;
-    Stream<xpu> *s = ctx.get_stream<xpu>();
-    CHECK_EQ(inputs.size(), 2U);
-    CHECK_EQ(outputs.size(), 1U);
-    MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-      MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
-        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
-                             + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
-        if (size != 0) {
-          Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(s, size,
-                                                              outputs[0].dptr<DType>(),
-                                                              inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
-        }
-      });
-    });
-  }

  template<typename xpu, typename OP>
  static void ComputeEx(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
14 changes: 7 additions & 7 deletions src/operator/tensor/elemwise_binary_op_basic.cu
@@ -26,6 +26,7 @@
#include "./elemwise_binary_op.h"
#include "./elemwise_binary_op-inl.h"
#include "./indexing_op.h"
#include "./elemwise_op.cuh"

namespace mxnet {
namespace op {
@@ -218,20 +219,19 @@ void ElemwiseBinaryOp::DnsCsrDnsOp(mshadow::Stream<gpu> *s,
}

NNVM_REGISTER_OP(elemwise_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>)
+.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::plus>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::plus>);

NNVM_REGISTER_OP(_grad_add)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::plus>);
+.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::plus>);

NNVM_REGISTER_OP(_backward_add)
.set_attr<FCompute>("FCompute<gpu>",
                    ElemwiseBinaryOp::BackwardUseNoneWithHalf2<gpu, mshadow_op::identity,
                                                               mshadow_op::identity>);

NNVM_REGISTER_OP(elemwise_sub)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<
-                    gpu, op::mshadow_op::minus>)
+.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::minus>)
.set_attr<FComputeEx>("FComputeEx<gpu>", ElemwiseBinaryOp::ComputeEx<gpu, op::mshadow_op::minus>);

NNVM_REGISTER_OP(_backward_sub)
@@ -240,7 +240,7 @@ NNVM_REGISTER_OP(_backward_sub)
                                                             mshadow_op::negation>);

NNVM_REGISTER_OP(elemwise_mul)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::mul>)
+.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<op::mshadow_op::mul>)
.set_attr<FComputeEx>("FComputeEx<gpu>",
                      ElemwiseBinaryOp::ComputeDnsLRValueEx<gpu, op::mshadow_op::mul, true, true>);

Expand All @@ -251,15 +251,15 @@ NNVM_REGISTER_OP(_backward_mul)

NNVM_REGISTER_OP(elemwise_div)
.set_attr<FCompute>("FCompute<gpu>",
-                    ElemwiseBinaryOp::ElemwiseBinaryOp::ComputeWithHalf2<gpu, op::mshadow_op::div>);
+                    ComputeWithHalf2<op::mshadow_op::div>);

NNVM_REGISTER_OP(_backward_div)
.set_attr<FCompute>("FCompute<gpu>",
                    ElemwiseBinaryOp::BackwardUseInWithHalf2<gpu, mshadow_op::div_grad,
                                                             mshadow_op::div_rgrad>);

NNVM_REGISTER_OP(_mod)
-.set_attr<FCompute>("FCompute<gpu>", ElemwiseBinaryOp::ComputeWithHalf2<gpu, mshadow_op::mod>);
+.set_attr<FCompute>("FCompute<gpu>", ComputeWithHalf2<mshadow_op::mod>);

NNVM_REGISTER_OP(_backward_mod)
.set_attr<FCompute>("FCompute<gpu>",
273 changes: 273 additions & 0 deletions src/operator/tensor/elemwise_op.cuh
@@ -0,0 +1,273 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2020 by Contributors
* \file elemwise_op.cuh
* \brief GPU helpers for elementwise operators
*/

#ifndef MXNET_OPERATOR_TENSOR_ELEMWISE_OP_CUH_
#define MXNET_OPERATOR_TENSOR_ELEMWISE_OP_CUH_

#include <cuda_runtime.h>
#include "../operator_common.h"

#include <vector>

#if MXNET_USE_CUDA

namespace mxnet {
namespace op {

template <typename DType, typename LType>
class VectorizedStorage {
 public:
  constexpr static int nvec = sizeof(LType) / sizeof(DType);
  union vectorized_storage {
    LType aligned;
    DType separate[nvec];  // NOLINT(*)

    MSHADOW_XINLINE vectorized_storage() {}
    MSHADOW_XINLINE ~vectorized_storage() {}
  } scratch_;
};
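For the types this kernel targets, e.g. DType = half with LType = double, nvec is 4: one 8-byte word overlaps four 2-byte elements in the union. A hypothetical instantiation just to pin down the numbers (assumes mshadow's 2-byte half_t):

// scratch_.aligned and scratch_.separate[0..3] occupy the same 8 bytes.
static_assert(VectorizedStorage<mshadow::half::half_t, double>::nvec == 4,
              "one double-sized transaction carries four half values");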

template <typename DType, typename LType, bool aligned = false>
class VectorizedAccessor {
 public:
  VectorizedStorage<typename std::remove_const<DType>::type,
                    typename std::remove_const<LType>::type> storage_;

  LType* aligned_ptr_;
  DType* unaligned_ptr_;
  int alignment_;
  index_t n_elems_;

  MSHADOW_XINLINE VectorizedAccessor(DType* ptr, const index_t N) {
    unaligned_ptr_ = ptr;
    if (aligned) {
      alignment_ = 0;
      aligned_ptr_ = reinterpret_cast<LType*>(ptr);
      n_elems_ = (N + storage_.nvec - 1) / storage_.nvec;
    } else {
      size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
      alignment_ = (ptr_as_number % sizeof(LType)) / sizeof(DType);
      aligned_ptr_ = reinterpret_cast<LType*>(ptr - alignment_);
      n_elems_ = (N + alignment_ + storage_.nvec - 1) / storage_.nvec;
    }
  }

  MSHADOW_XINLINE index_t num_aligned_elements() const {
    return n_elems_;
  }

  MSHADOW_XINLINE void load(const index_t id, const index_t N) {
    if (aligned) {
      storage_.scratch_.aligned = aligned_ptr_[id];
    } else {
      if (id > 0 && id < n_elems_ - 1) {
        storage_.scratch_.aligned = aligned_ptr_[id];
      } else {
#pragma unroll
        for (int j = 0; j < storage_.nvec; ++j) {
          DType* ptr = reinterpret_cast<DType*>(&(aligned_ptr_[id])) + j;
          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(unaligned_ptr_) &&
              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(unaligned_ptr_ + N)) {
            storage_.scratch_.separate[j] = *ptr;
          }
        }
      }
    }
  }
};
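A worked example of the constructor arithmetic above, with hypothetical numbers: DType = half (2 bytes), LType = double (8 bytes, nvec = 4), a pointer 4 bytes past an 8-byte boundary, and N = 10 elements:

// alignment_   = (4 % 8) / 2          = 2  -> two elements past the wide word
// aligned_ptr_ = ptr - 2                   -> backed up onto the aligned word
// n_elems_     = (10 + 2 + 4 - 1) / 4 = 3  -> three wide words cover the range
// Only words 0 and 2 can spill outside [ptr, ptr + N), which is exactly why
// load() bounds-checks per element at the ends (id == 0, id == n_elems_ - 1)
// and issues single wide transactions for every word in between.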

template <typename DType, typename LType, bool aligned = false>
class VectorizedLoader : public VectorizedAccessor<const DType, const LType, aligned> {
 public:
  MSHADOW_XINLINE VectorizedLoader(const DType* ptr, const index_t N) :
    VectorizedAccessor<const DType, const LType, aligned>(ptr, N) {
  }
};

template <typename DType, typename LType, bool aligned = false>
class VectorizedStorer : public VectorizedAccessor<DType, LType, aligned> {
 public:
  MSHADOW_XINLINE VectorizedStorer(DType* ptr, const index_t N) :
    VectorizedAccessor<DType, LType, aligned>(ptr, N) {
  }

  MSHADOW_XINLINE void store(const index_t id, const index_t N) {
    if (aligned) {
      this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
    } else {
      if (id > 0 && id < this->n_elems_ - 1) {
        this->aligned_ptr_[id] = this->storage_.scratch_.aligned;
      } else {
#pragma unroll
        for (int j = 0; j < this->storage_.nvec; ++j) {
          DType* ptr = reinterpret_cast<DType*>(&(this->aligned_ptr_[id])) + j;
          if (reinterpret_cast<size_t>(ptr) >= reinterpret_cast<size_t>(this->unaligned_ptr_) &&
              reinterpret_cast<size_t>(ptr) < reinterpret_cast<size_t>(this->unaligned_ptr_ + N)) {
            *ptr = this->storage_.scratch_.separate[j];
          }
        }
      }
    }
  }
};

namespace {

template <bool aligned, typename DType, typename LType, typename OP, int req>
__global__ void VectorizedElementwiseKernel(DType* output, const DType* input0,
                                            const DType* input1, index_t N) {
  VectorizedLoader<DType, LType, aligned> loader0(input0, N);
  VectorizedLoader<DType, LType, aligned> loader1(input1, N);
  VectorizedStorer<DType, LType, aligned> storer(output, N);

  const index_t M = loader0.num_aligned_elements();

  for (index_t tid = blockIdx.x * blockDim.x + threadIdx.x;
       tid < M;
       tid += gridDim.x * blockDim.x) {
    loader0.load(tid, N);
    loader1.load(tid, N);
    if (req == kAddTo) {
      storer.load(tid, N);
    }
#pragma unroll
    for (int i = 0; i < loader0.storage_.nvec; ++i) {
      DType temp = OP::Map(loader0.storage_.scratch_.separate[i],
                           loader1.storage_.scratch_.separate[i]);

      if (req == kAddTo) {
        storer.storage_.scratch_.separate[i] += temp;
      } else {
        storer.storage_.scratch_.separate[i] = temp;
      }
    }
    storer.store(tid, N);
  }
}

enum class Alignment {
  SAME_ALIGNED,
  SAME_UNALIGNED,
  DIFFERENT
};

template <typename LType, typename DType>
int CalcAlignment(const DType* ptr) {
  size_t ptr_as_number = reinterpret_cast<size_t>(ptr);
  return ptr_as_number % sizeof(LType);
}

template <typename LType, typename DType>
Alignment CheckAlignment(const std::vector<DType*>& pointers) {
  int align = -1;
  for (const DType* ptr : pointers) {
    int new_align = CalcAlignment<LType>(ptr);
    if (align == -1) {
      align = new_align;
    } else {
      if (align != new_align) {
        return Alignment::DIFFERENT;
      }
    }
  }

  return align == 0 ? Alignment::SAME_ALIGNED
                    : Alignment::SAME_UNALIGNED;
}

size_t minthree(const size_t a, const size_t b, const size_t c) {
  return a < b ? (a < c ? a : c) : (b < c ? b : c);
}
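To make the classification concrete, a hypothetical host-side illustration of CheckAlignment with an 8-byte load type over float pointers (the function name and values are illustrative only; cudaMalloc returns memory aligned far beyond 8 bytes, so p itself is aligned):

inline void alignment_examples() {
  float* p = nullptr;
  cudaMalloc(&p, 1024 * sizeof(float));
  CheckAlignment<double, float>({p, p + 2});      // byte offsets 0 and 8  -> SAME_ALIGNED
  CheckAlignment<double, float>({p + 1, p + 3});  // byte offsets 4 and 12 -> SAME_UNALIGNED
  CheckAlignment<double, float>({p, p + 1});      // byte offsets 0 and 4  -> DIFFERENT
  cudaFree(p);
}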

} // namespace

template<typename OP>
void ComputeWithHalf2(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
  using namespace mxnet_op;
  if (req[0] == kNullOp) return;
  Stream<gpu> *s = ctx.get_stream<gpu>();
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
    if (dmlc::GetEnv("DEBUG_VECTOR", false)) {
      MSHADOW_TYPE_SWITCH_WITH_HALF2(outputs[0].type_flag_, DType, {
        const size_t size = (minthree(outputs[0].Size(), inputs[0].Size(), inputs[1].Size())
                             + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
        if (size != 0) {
          Kernel<mxnet_op::op_with_req<OP, Req>, gpu>::Launch(s, size,
                                                              outputs[0].dptr<DType>(),
                                                              inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
        }
      });
    } else {
      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
        using LType = double;
        static_assert(sizeof(LType) >= sizeof(DType), "Load type is smaller than operand type");
        if (outputs[0].Size() != 0) {
          cudaStream_t stream = mshadow::Stream<gpu>::GetStream(s);
          constexpr int nvec = sizeof(LType) / sizeof(DType);
          VectorizedLoader<DType, LType> l(outputs[0].dptr<DType>(), outputs[0].Size());
          size_t num_elements = l.num_aligned_elements();
          constexpr int threads = 512;
          index_t blocks = (num_elements + threads - 1) / threads;
          auto align = CheckAlignment<LType, DType>({outputs[0].dptr<DType>(),
                                                     inputs[0].dptr<DType>(),
                                                     inputs[1].dptr<DType>()});
          if (align == Alignment::SAME_ALIGNED && (outputs[0].Size() % nvec == 0)) {
            VectorizedElementwiseKernel<true, DType, LType, OP, Req>
              <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
                                               inputs[0].dptr<DType>(),
                                               inputs[1].dptr<DType>(),
                                               outputs[0].Size());
          } else {
            if (align != Alignment::DIFFERENT) {
              VectorizedElementwiseKernel<false, DType, LType, OP, Req>
                <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
                                                 inputs[0].dptr<DType>(),
                                                 inputs[1].dptr<DType>(),
                                                 outputs[0].Size());
            } else {
              // If the pointers are aligned differently we cannot vectorize
              VectorizedElementwiseKernel<true, DType, DType, OP, Req>
                <<<blocks, threads, 0, stream>>>(outputs[0].dptr<DType>(),
                                                 inputs[0].dptr<DType>(),
                                                 inputs[1].dptr<DType>(),
                                                 outputs[0].Size());
            }
          }
        }
      });
    }
  });
}
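// Dispatch summary for ComputeWithHalf2 above (LType = double; with half data
// nvec = 4): zero shared misalignment plus Size() % nvec == 0 selects the
// fully vectorized kernel; the same nonzero misalignment on all three pointers
// selects the boundary-checked variant; mismatched misalignments fall back to
// the instantiation with LType = DType (nvec = 1), since no single back-up
// offset can align all three pointers at once.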

} // namespace op
} // namespace mxnet

#endif // MXNET_USE_CUDA
#endif // MXNET_OPERATOR_TENSOR_ELEMWISE_OP_CUH_
