Accelerate batched_gemm on GPU using the cuBLAS batched version
Fix lint

Adding workspace

Fix build error

Fix build error on GPU

Fix comment

Update comment
sxjscience committed Jul 23, 2016
1 parent 76c12ef commit 335b9a0
Showing 4 changed files with 156 additions and 80 deletions.
21 changes: 21 additions & 0 deletions mshadow/cuda/tensor_gpu-inl.cuh
@@ -194,6 +194,27 @@ inline void MapReduceKeepDim1(expr::Plan<DstExp, DType> dst,
<<<dimGrid, dimBlock, 0, stream>>>(dst, plan, scale, pshape);
}

template<int x_bits, typename DType>
__global__ void GetBatchedViewKernel(DType **dst, DType *src, int num, int stride) {
const int x_size = 1 << x_bits;
const int start = threadIdx.x;
// Store the start address of each batch in dst: dst[i] = src + i * stride
for (int i = start; i < num; i += x_size) {
dst[i] = src + i * stride;
}
}

template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<gpu> *stream) {
cudaStream_t stream_ = Stream<gpu>::GetStream(stream);
dim3 dimBlock(kBaseThreadNum);
dim3 dimGrid(1);
CheckLaunchParam(dimGrid, dimBlock, "GetBatchedView");
GetBatchedViewKernel<kBaseThreadBits, DType>
<<<dimGrid, dimBlock, 0, stream_>>> (dst, src, num, stride);
}

template<int x_bits, typename DType, typename DstPlan, typename SrcPlan1, typename SrcPlan2>
__global__ void SoftmaxGradKernel(DstPlan dst, SrcPlan1 src, SrcPlan2 label, index_t xmax) {
const unsigned x_size = 1 << x_bits;
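For orientation, a minimal driver sketch for the new cuda::GetBatchedView helper; the buffer names, the sizes, and the Stream<gpu>* handle called `stream` are illustrative assumptions, not part of this commit.

// Sketch only: build a device-side array of per-batch pointers.
// Assumes `stream` is a valid mshadow::Stream<gpu>* and omits error checking.
const int num = 8, stride = 16;
float *data = NULL;         // num * stride contiguous elements on the device
float **batch_ptrs = NULL;  // receives one pointer per batch
cudaMalloc(&data, num * stride * sizeof(float));
cudaMalloc(&batch_ptrs, num * sizeof(float*));
mshadow::cuda::GetBatchedView(batch_ptrs, data, num, stride, stream);
// After the kernel completes, batch_ptrs[i] == data + i * stride for every i.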
152 changes: 74 additions & 78 deletions mshadow/dot_engine-inl.h
@@ -10,7 +10,37 @@
#include "./base.h"
#include "./extension/implicit_gemm.h"

#ifdef __CUDACC__
#include "./cuda/tensor_gpu-inl.cuh"
#endif // #ifdef __CUDACC__

namespace mshadow {
/*!
* \brief CPU/GPU: Get a batched view of the src array. dst[i] = src + i * stride
* \param dst 2D pointer
* \param src 1D pointer
* \param num number of batches
* \param stride size of each batch
* \param stream stream of execution (used by the GPU implementation)
*/
template<typename Device, typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<Device> *stream);
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<cpu> *stream) {
for (int i = 0; i < num; i++) {
dst[i] = src + i * stride;
}
}
#ifdef __CUDACC__
template<typename DType>
inline void GetBatchedView(DType **dst, DType *src, int num, int stride,
Stream<gpu> *stream) {
cuda::GetBatchedView(dst, src, num, stride, stream);
}
#endif // #ifdef __CUDACC__

namespace expr {
//---------------------------------------------------------------------
// Matrix Multiplications, depends on BLAS Engine
@@ -42,7 +72,8 @@ struct BLASEngine {
bool transa, bool transb,
int m, int n, int k, DType alpha,
const DType *A, int lda, const DType *B, int ldb,
DType beta, DType *C, int ldc, int batch_count) {
DType beta, DType *C, int ldc, int batch_count,
DType **workspace) {
LOG(FATAL) << "Not implemented!";
}
inline static void gemv(Stream<Device> *stream,
@@ -116,7 +147,8 @@ struct BLASEngine<cpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -193,7 +225,8 @@ struct BLASEngine<cpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -255,7 +288,8 @@ struct BLASEngine<cpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -324,7 +358,8 @@ struct BLASEngine<cpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -424,7 +459,8 @@ struct BLASEngine<gpu, half::half_t> {
bool transa, bool transb,
int m, int n, int k, half::half_t alpha,
const half::half_t *A, int lda, const half::half_t *B, int ldb,
half::half_t beta, half::half_t *C, int ldc, int batch_count) {
half::half_t beta, half::half_t *C, int ldc, int batch_count,
half::half_t **workspace) {
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
@@ -491,12 +527,27 @@ struct BLASEngine<gpu, float> {
bool transa, bool transb,
int m, int n, int k, float alpha,
const float *A, int lda, const float *B, int ldb,
float beta, float *C, int ldc, int batch_count) {
float beta, float *C, int ldc, int batch_count,
float **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
GetBatchedView(workspace, const_cast<float*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<float*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasSgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const float**)workspace, lda,
(const float**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: SgemmBatched fail";
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, float alpha,
@@ -575,12 +626,27 @@ struct BLASEngine<gpu, double> {
bool transa, bool transb,
int m, int n, int k, double alpha,
const double *A, int lda, const double *B, int ldb,
double beta, double *C, int ldc, int batch_count) {
double beta, double *C, int ldc, int batch_count,
double **workspace) {
#if defined(__CUDACC__) && CUDA_VERSION >= 4010
// Cast DType* to DType** using workspace as a buffer
GetBatchedView(workspace, const_cast<double*>(A), batch_count, m * k, stream);
GetBatchedView(workspace + batch_count,
const_cast<double*>(B), batch_count, k * n, stream);
GetBatchedView(workspace + 2 * batch_count, C, batch_count, m * n, stream);
cublasStatus_t err = cublasDgemmBatched(Stream<gpu>::GetBlasHandle(stream),
GetT(transa), GetT(transb), m, n, k, &alpha,
(const double**)workspace, lda,
(const double**)(workspace + batch_count), ldb,
&beta, workspace + 2 * batch_count, ldc, batch_count);
CHECK_EQ(err, CUBLAS_STATUS_SUCCESS) << "Cublas: DgemmBatched fail";
#else
for (int i = 0; i < batch_count; ++i) {
gemm(stream, transa, transb, m, n, k, alpha,
A + i * m * k, lda, B + i * k * n, ldb,
beta, C + i * m * n, ldc);
}
#endif // defined(__CUDACC__) && CUDA_VERSION >= 4010
}
inline static void gemv(Stream<gpu> *stream,
bool trans, int m, int n, double alpha,
@@ -638,9 +704,6 @@ struct BLASEngine<gpu, double> {
inline static Shape<2> GetShape(const Shape<2> &shape, bool transpose) {
return transpose ? Shape2(shape[1], shape[0]) : shape;
}
inline static Shape<3> GetBatchedShape(const Shape<3> &shape, bool transpose) {
return transpose ? Shape3(shape[0], shape[2], shape[1]) : shape;
}
// dst = dot(lhs[.T], rhs[.T])
template<typename SV, typename xpu,
bool transpose_left, bool transpose_right, typename DType>
@@ -732,73 +795,6 @@ struct DotEngine<SV, xpu, 2, 1, 1, true, false, DType> {
}
}
};
// dst = batched_dot(lhs[.T], rhs[.T])
template<typename SV, typename xpu,
bool transpose_left, bool transpose_right, typename DType>
struct DotEngine<SV, xpu, 3, 3, 3, transpose_left, transpose_right, DType> {
inline static void Eval(Tensor<xpu, 3, DType> *p_dst,
const Tensor<xpu, 3, DType> &lhs,
const Tensor<xpu, 3, DType> &rhs,
DType scale) {
Tensor<xpu, 3, DType> &dst = *p_dst;
// set kernel stream
// if there is no stream, crash
BLASEngine<xpu, DType>::SetStream(dst.stream_);
Shape<3> sleft = GetBatchedShape(lhs.shape_, transpose_left);
Shape<3> sright = GetBatchedShape(rhs.shape_, transpose_right);
CHECK(dst.size(0) == sleft[0] && dst.size(0) == sright[0])
<< "batch_dot-gemm: batchsize must be equal."
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
<< "batch_dot-gemm: matrix shape mismatch"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
// use column major arguments to be compatible with most BLAS implementations
if (sleft[1] == 1) {
// For (batch, 1, K) gemm (batch, K, N), we can use (batch, N, K) gemv (batch, K)
BLASEngine<xpu, DType>::batched_gemv
(dst.stream_,
transpose_right,
rhs.size(2), rhs.size(1), scale * SV::AlphaBLAS(),
rhs.dptr_, rhs.stride_,
lhs.dptr_, 1, SV::BetaBLAS(),
dst.dptr_, 1, dst.size(0));
} else if (sleft[2] == 1 && (SV::BetaBLAS() == 0.0f || SV::BetaBLAS() == 1.0f)) {
// For (batch, M, 1) gemm (batch, 1, N) + Beta = 0, we can use (batch, M) ger (batch, N)
if (SV::BetaBLAS() == 0.0f) {
dst = DType(0);
}
BLASEngine<xpu, DType>::batched_ger
(dst.stream_, sright[2], sleft[1], scale * SV::AlphaBLAS(),
rhs.dptr_, 1, lhs.dptr_, 1, dst.dptr_, dst.stride_, dst.size(0));
} else if (sright[2] == 1) {
// For (batch, M, K) gemm (batch, K, 1), we can use (batch, M, K) gemv (batch, K)
BLASEngine<xpu, DType>::batched_gemv
(dst.stream_,
!transpose_left,
lhs.size(2), lhs.size(1), scale * SV::AlphaBLAS(),
lhs.dptr_, lhs.stride_,
rhs.dptr_, 1, SV::BetaBLAS(),
dst.dptr_, 1, dst.size(0));
} else {
// For general case, use gemm
BLASEngine<xpu, DType>::batched_gemm
(dst.stream_,
transpose_right, transpose_left,
transpose_right ? rhs.size(1) : rhs.size(2),
transpose_left ? lhs.size(2) : lhs.size(1),
transpose_right ? rhs.size(2) : rhs.size(1),
DType(scale * SV::AlphaBLAS()),
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
DType(SV::BetaBLAS()),
dst.dptr_, dst.stride_, dst.size(0));
}
}
};
} // namespace expr
} // namespace mshadow
#endif // MSHADOW_DOT_ENGINE_INL_H_
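Editor's note, inferred from the GetBatchedView calls above rather than stated in the commit: the GPU overloads of batched_gemm expect the workspace to hold 3 * batch_count device pointers, laid out as sketched below.

// Device workspace layout consumed by cublasSgemmBatched / cublasDgemmBatched:
//   workspace[0               .. batch_count - 1    ]  -> A + i * m * k  (per-batch A)
//   workspace[batch_count     .. 2 * batch_count - 1]  -> B + i * k * n  (per-batch B)
//   workspace[2 * batch_count .. 3 * batch_count - 1]  -> C + i * m * n  (per-batch C)
// This is why the BatchGEMM wrapper declared in tensor.h below requires
// workspace.size(0) >= 3 * batch_size.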
19 changes: 17 additions & 2 deletions mshadow/tensor.h
@@ -855,17 +855,32 @@ template<typename Saver, typename Reducer, int dimkeep,
inline void MapReduceKeepHighDim(TRValue<R, gpu, 1, DType> *dst,
const expr::Exp<E, DType, etype> &exp,
DType scale = 1);

/*!
* \brief CPU/GPU: 1-dimensional vector dot product
* \param dst Length 1 vector, used to hold the result.
* \param lhs Left operand vector
* \param rhs right operand vector
* \param rhs Right operand vector
*/
template<typename Device, typename DType>
inline void VectorDot(Tensor<Device, 1, DType> dst,
const Tensor<Device, 1, DType> &lhs,
const Tensor<Device, 1, DType> &rhs);
/*!
* \brief CPU/GPU: dst = alpha * op(lhs) * op(rhs) + beta * dst
* \param dst 3-dimensional tensor that holds the result
* \param lhs Left operand tensor
* \param rhs Right operand tensor
* \param alpha Scale applied to op(lhs) * op(rhs)
* \param beta Scale applied to the existing contents of dst
* \param workspace Workspace for casting DType* to DType** (batched-view), must have size >= 3 * batch_size
*/
template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace);
} // namespace mshadow
// include headers
#include "./stream_gpu-inl.h"
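A hedged usage sketch for the new BatchGEMM entry point; the concrete shapes, the float element type, and the stream named `stream` are illustrative assumptions rather than code from this commit.

// Sketch: per-batch dst = 1.0f * lhs * rhs + 0.0f * dst for `batch` independent
// (m x k) * (k x n) products. Assumes a valid mshadow::Stream<gpu> *stream.
const int batch = 4, m = 2, k = 3, n = 5;
Tensor<gpu, 3, float> lhs(Shape3(batch, m, k));
Tensor<gpu, 3, float> rhs(Shape3(batch, k, n));
Tensor<gpu, 3, float> dst(Shape3(batch, m, n));
Tensor<gpu, 1, float*> workspace(Shape1(3 * batch));  // >= 3 * batch pointers
// Allocate without padding so CheckContiguous() holds inside BatchGEMM.
AllocSpace(&lhs, false); AllocSpace(&rhs, false);
AllocSpace(&dst, false); AllocSpace(&workspace, false);
lhs.set_stream(stream); rhs.set_stream(stream);
dst.set_stream(stream); workspace.set_stream(stream);
// ... fill lhs and rhs on the device ...
BatchGEMM<false, false>(dst, lhs, rhs, 1.0f, 0.0f, workspace);
// FreeSpace(&lhs); FreeSpace(&rhs); FreeSpace(&dst); FreeSpace(&workspace);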
44 changes: 44 additions & 0 deletions mshadow/tensor_cpu-inl.h
@@ -449,5 +449,49 @@ inline void VectorDot(Tensor<Device, 1, DType> dst,
mshadow::expr::BLASEngine<Device, DType>::dot(
lhs.stream_, lhs.size(0), lhs.dptr_, 1, rhs.dptr_, 1, dst.dptr_);
}

template<bool transpose_left, bool transpose_right, typename Device, typename DType>
inline void BatchGEMM(Tensor<Device, 3, DType> dst,
const Tensor<Device, 3, DType> &lhs,
const Tensor<Device, 3, DType> &rhs,
DType alpha,
DType beta,
Tensor<Device, 1, DType*> workspace) {
int batch_size = dst.shape_[0];
expr::BLASEngine<Device, DType>::SetStream(dst.stream_);
Shape<3> sleft = transpose_left ? Shape3(lhs.shape_[0], lhs.shape_[2], lhs.shape_[1])
: lhs.shape_;
Shape<3> sright = transpose_right ? Shape3(rhs.shape_[0], rhs.shape_[2], rhs.shape_[1])
: rhs.shape_;
CHECK_EQ(dst.CheckContiguous(), true);
CHECK_EQ(lhs.CheckContiguous(), true);
CHECK_EQ(rhs.CheckContiguous(), true);
CHECK(sleft[0] == batch_size && sright[0] == batch_size)
<< "BatchGEMM: batchsize must be equal."
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(dst.size(1) == sleft[1] && dst.size(2) == sright[2] && sleft[2] == sright[1])
<< "BatchGEMM: matrix shape mismatch"
<< "dst: " << dst.shape_ << "\n"
<< "lhs: " << sleft << "\n"
<< "rhs: " << sright << "\n";
CHECK(workspace.size(0) >= 3 * batch_size)
<< "Workspace Size must be bigger than " << 3 * batch_size;
CHECK_EQ(workspace.CheckContiguous(), true);
// use column major arguments to be compatible with most BLAS implementations
expr::BLASEngine<Device, DType>::batched_gemm
(dst.stream_,
transpose_right, transpose_left,
transpose_right ? rhs.size(1) : rhs.size(2),
transpose_left ? lhs.size(2) : lhs.size(1),
transpose_right ? rhs.size(2) : rhs.size(1),
alpha,
rhs.dptr_, rhs.stride_,
lhs.dptr_, lhs.stride_,
beta,
dst.dptr_, dst.stride_, batch_size,
workspace.dptr_);
}
} // namespace mshadow
#endif // MSHADOW_TENSOR_CPU_INL_H_
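A note on the argument order in the batched_gemm call above (an editor's clarification consistent with the existing column-major comment): BLAS routines are column-major while mshadow tensors are row-major, so each row-major product dst = op(lhs) * op(rhs) is evaluated as the column-major product dst^T = op(rhs)^T * op(lhs)^T. That is why rhs is passed before lhs and why the first two size arguments are taken from rhs and lhs respectively. For example, with no transposes, lhs of shape (batch, 4, 3) and rhs of shape (batch, 3, 5), the call passes 5, 4, 3 for the three size arguments; cuBLAS then writes a 5x4 column-major result per batch, whose memory layout is exactly the desired 4x5 row-major dst slice.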
