This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[BACKPORT][FEATURE] Support for fp16 in SpM x DnsM on GPU (#18930) #19074

Merged · 1 commit · Sep 3, 2020
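In user-facing terms, this backport lets dot(csr, dns) on GPU accept float16 operands while accumulating in float32 internally. A minimal sketch of the call it enables, assuming a CUDA-enabled MXNet 1.x build (the shapes, values, and gpu(0) context are illustrative, not taken from the PR):

import mxnet as mx

ctx = mx.gpu(0)
# 4x5 CSR lhs with one nonzero, cast to float16, the dtype this PR adds support for.
lhs = mx.nd.sparse.csr_matrix(([2.0], [3], [0, 0, 0, 0, 1]),
                              shape=(4, 5), ctx=ctx).astype('float16')
rhs = mx.nd.random.normal(shape=(5, 4), ctx=ctx).astype('float16')
# Sparse-times-dense product on the new fp16 path; accumulation runs in fp32.
out = mx.nd.sparse.dot(lhs, rhs)
print(out.asnumpy())

The test added to tests/python/gpu/test_operator_gpu.py below exercises the same path with a single-nonzero CSR input.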
src/operator/tensor/dot-inl.cuh (26 changes: 13 additions & 13 deletions)
@@ -40,7 +40,7 @@ namespace op {
  * \brief GPU scalar kernel of dot(csr, dns1) = dns2
  * Parallelization by output matrix elements: 1 thread/element
  */
-template<int req>
+template<int req, typename AType>
 struct DotCsrDnsDnsScalarKernel {
   /*!
    * \brief This function represents performing an inner product between a row of lhs
@@ -63,20 +63,20 @@ struct DotCsrDnsDnsScalarKernel {
                                              const nnvm::dim_t num_cols_r) {
     const nnvm::dim_t irow = tid / num_cols_r;  // row id of the lhs
     const nnvm::dim_t icol = tid % num_cols_r;  // col id of the rhs
-    DType sum = 0;
+    AType sum = 0;
     for (IType j = indptr_l[irow]; j < indptr_l[irow+1]; ++j) {
       const CType cur_col = col_idx_l[j];  // corresponding row id of the rhs
       sum += data_l[j] * data_r[cur_col*num_cols_r+icol];
     }
-    KERNEL_ASSIGN(out[tid], req, sum);
+    KERNEL_ASSIGN(out[tid], req, static_cast<DType>(sum));
   }
 };
 
 /*!
  * \brief GPU vector kernel of dot(csr, dns1) = dns2
  * Parallelization by output matrix elements: 1 warp/element
  */
-template<int req>
+template<int req, typename AType>
 struct DotCsrDnsDnsVectorKernel {
   /*!
    * \brief see DotCsrDnsDnsScalarKernel Map for documentation.
@@ -90,7 +90,7 @@ struct DotCsrDnsDnsVectorKernel {
                                              const DType* data_r,
                                              const nnvm::dim_t num_cols_r) {
     using nnvm::dim_t;
-    __shared__ volatile DType vals[mshadow::cuda::kBaseThreadNum];
+    __shared__ volatile AType vals[mshadow::cuda::kBaseThreadNum];
     const dim_t warp_id = tid / 32;           // global warp id
     const dim_t lane = tid & (32-1);          // local thread id within warp
     const dim_t irow = warp_id / num_cols_r;  // lhs row that this warp computes
@@ -101,9 +101,9 @@
     const dim_t high = static_cast<dim_t>(indptr_l[irow+1]);
 
     // Compute running sum per thread
-    DType sum = 0;
+    AType sum = 0;
     for (dim_t j = low+lane; j < high; j+=32) {
-      sum += data_l[j] * data_r[col_idx_l[j]*num_cols_r + kcol];
+      sum += static_cast<AType>(data_l[j]) * static_cast<AType>(data_r[col_idx_l[j]*num_cols_r + kcol]);
     }
     vals[threadIdx.x] = sum; __syncwarp();
 
@@ -115,7 +115,7 @@
     if (lane < 1) {vals[threadIdx.x] += vals[threadIdx.x+ 1];} __syncwarp();
 
     if (lane == 0) {
-      KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, vals[threadIdx.x]);
+      KERNEL_ASSIGN(out[irow*num_cols_r+kcol], req, static_cast<DType>(vals[threadIdx.x]));
     }
   }
 };
@@ -418,7 +418,7 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
   const TBlob& data_r = rhs;
   const TBlob data_out = *ret;
 
-  MSHADOW_SGL_DBL_TYPE_SWITCH(data_l.type_flag_, DType, {  // data type
+  MXNET_REAL_ACC_TYPE_SWITCH(data_l.type_flag_, DType, AType, {  // data type and accumulator type
     MSHADOW_IDX_TYPE_SWITCH(indptr_l.type_flag_, IType, {  // indptr type
       MSHADOW_IDX_TYPE_SWITCH(col_idx_l.type_flag_, CType, {  // col idx type
         if (kWriteTo == req) {
@@ -513,14 +513,14 @@ inline void DotCsrDnsDnsImpl(const OpContext& ctx,
           if (num_cols_r > 4) {
             num_threads = data_out.Size();
             MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-              Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+              Kernel<DotCsrDnsDnsScalarKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
                   data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
                   csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
             });
           } else {
             num_threads = threads_per_warp * num_rows_l * num_cols_r;
             MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-              Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
+              Kernel<DotCsrDnsDnsVectorKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
                   data_out.dptr<DType>(), csc_data_ptr, csc_indptr_ptr,
                   csc_indices_ptr, data_r.dptr<DType>(), num_cols_r);
             });
@@ ... @@
           if (num_cols_r > 4) {
             num_threads = data_out.Size();
             MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-              Kernel<DotCsrDnsDnsScalarKernel<ReqType>, gpu>::Launch(s, num_threads,
+              Kernel<DotCsrDnsDnsScalarKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
                   data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                   col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
             });
           } else {
             num_threads = threads_per_warp * num_rows_l * num_cols_r;
             MXNET_ASSIGN_REQ_SWITCH(req, ReqType, {
-              Kernel<DotCsrDnsDnsVectorKernel<ReqType>, gpu>::Launch(s, num_threads,
+              Kernel<DotCsrDnsDnsVectorKernel<ReqType, AType>, gpu>::Launch(s, num_threads,
                   data_out.dptr<DType>(), data_l.dptr<DType>(), indptr_l.dptr<IType>(),
                   col_idx_l.dptr<CType>(), data_r.dptr<DType>(), num_cols_r);
             });
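The kernel changes above all serve one idea: the accumulator type AType is decoupled from the storage type DType. MXNET_REAL_ACC_TYPE_SWITCH pairs float16 data with a float32 accumulator, every product is summed in fp32, and a single cast back to DType happens at the final KERNEL_ASSIGN. A NumPy sketch of the scalar kernel's semantics under that pairing (illustrative pseudocode, not MXNet internals; the function name is made up here):

import numpy as np

def csr_dot_dns_fp16(data, indices, indptr, rhs):
    # Mirrors DotCsrDnsDnsScalarKernel with DType=float16, AType=float32:
    # one (irow, icol) output element per CUDA thread.
    num_rows, num_cols_r = len(indptr) - 1, rhs.shape[1]
    out = np.zeros((num_rows, num_cols_r), dtype=np.float16)
    for irow in range(num_rows):
        for icol in range(num_cols_r):
            acc = np.float32(0)  # AType sum = 0;
            for j in range(indptr[irow], indptr[irow + 1]):
                acc += np.float32(data[j]) * np.float32(rhs[indices[j], icol])
            out[irow, icol] = np.float16(acc)  # static_cast<DType>(sum)
    return out

Accumulating in fp16 instead would drop low-order bits on long rows and can overflow past fp16's maximum of 65504; an fp32 accumulator avoids both while keeping fp16 storage and bandwidth.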
tests/python/gpu/test_operator_gpu.py (14 changes: 14 additions & 0 deletions)
@@ -24,6 +24,8 @@
 import numpy as np
 import unittest
 from nose.tools import assert_raises
+import scipy.sparse as sps
+import mxnet.ndarray.sparse as mxsps
 from mxnet.test_utils import check_consistency, set_default_context, assert_almost_equal, assert_allclose
 from mxnet.base import MXNetError
 from mxnet import autograd
@@ -2548,6 +2550,18 @@ def test_arange_like_dtype():
         for v in out:
             assert v.dtype == t
 
+
+def test_fp16_spmm():
+    inp = mxsps.csr_matrix(sps.coo_matrix(([2.0], ([150], [100000]))).tocsr())
+    inp = inp.astype('float16', copy=False)
+    weight = mx.nd.random.randn(100001, 151)
+    weight = weight.astype('float16', copy=False)
+    out = mxsps.dot(inp, weight)
+    out_np = mx.nd.dot(inp, weight)
+    assert_almost_equal(out.asnumpy(), out_np, rtol=1e-3, atol=1e-5)
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
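In the vector kernel, one warp owns one output element: lane k walks the row's nonzeros at offsets low+k, low+k+32, ..., and the per-lane fp32 partials are folded by a halving tree whose final step (offset 1) is visible in the -115 hunk above. A Python sketch of that pattern, assuming a 32-lane warp and taking rhs_col as the dense column the warp is dotting against (both names are made up for illustration):

import numpy as np

def warp_row_dot(data_l, col_idx_l, low, high, rhs_col, warp_size=32):
    # Per-lane strided partial sums in fp32 (AType), as in DotCsrDnsDnsVectorKernel.
    vals = np.zeros(warp_size, dtype=np.float32)
    for lane in range(warp_size):
        for j in range(low + lane, high, warp_size):
            vals[lane] += np.float32(data_l[j]) * np.float32(rhs_col[col_idx_l[j]])
    # Tree reduction across the warp: offsets 16, 8, 4, 2, 1.
    offset = warp_size // 2
    while offset > 0:
        for lane in range(offset):
            vals[lane] += vals[lane + offset]
        offset //= 2
    return np.float16(vals[0])  # lane 0 assigns static_cast<DType>(vals[threadIdx.x])

In the real kernel the vals array lives in shared memory and each reduction step is followed by __syncwarp(), but the arithmetic, fp32 partials cast once to fp16 at the end, is exactly what the diff adds.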