This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Support projection feature for LSTM on CPU (Only Inference) #17702

Merged (2 commits, Mar 2, 2020)
8 changes: 0 additions & 8 deletions src/operator/nn/mkldnn/mkldnn_base-inl.h
@@ -158,14 +158,6 @@ static inline bool SupportMKLDNN(int dtype, const mxnet::TShape &shape) {
(ndim == 1 || ndim == 2 || ndim == 4);
}

static inline bool SupportMKLDNNRnn(const NDArray &input) {
if (input.dtype() == mshadow::kFloat32 && input.shape().ndim() == 3
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

static inline bool SupportMKLDNNQuantize(int dtype) {
return dtype == mshadow::kFloat32 || dtype == mshadow::kInt8 ||
dtype == mshadow::kUint8 || dtype == mshadow::kBfloat16;
12 changes: 12 additions & 0 deletions src/operator/nn/mkldnn/mkldnn_rnn-inl.h
@@ -441,6 +441,18 @@ class MKLDNNRnnOp {
const std::vector<NDArray> &outputs);
};

inline bool SupportMKLDNNRnn(const int input_dtype) {
if (input_dtype == mshadow::kFloat32 && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
return true;
}
return false;
}

inline bool SupportMKLDNNRnn(const RNNParam &param, const int input_dtype) {
if (param.projection_size.has_value()) return false;
return SupportMKLDNNRnn(input_dtype);
}

} // namespace op
} // namespace mxnet

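A usage note on the helpers above (an editorial sketch, not part of the diff): the MKL-DNN RNN path is gated by the MXNET_USE_MKLDNN_RNN environment variable and is skipped automatically when projection_size is set, falling back to the native CPU kernels this PR extends. Forcing that fallback explicitly might look like this:

```python
# Hedged sketch: disable the MKL-DNN RNN primitive path so the native CPU
# RNN kernels are used. The diff reads this variable via dmlc::GetEnv at
# dispatch time; setting it before importing mxnet keeps behavior unambiguous.
import os
os.environ["MXNET_USE_MKLDNN_RNN"] = "0"

import mxnet as mx  # imported after the environment variable is set
```
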
2 changes: 0 additions & 2 deletions src/operator/nn/mkldnn/mkldnn_rnn.cc
@@ -339,7 +339,6 @@ FUNC(MKLDNN_ARG_DIFF_##NAME, ARGS.at(MKLDNN_ARG_##NAME).get_desc(), HANDLE)
void MKLDNNRnnForward::SetNewDataMem(void* x, void* hx, void* cx,
void* y, void* hy, void* cy,
const int dtype) {
using dims = mkldnn::memory::dims;
using desc = mkldnn::memory::desc;
using format_tag = mkldnn::memory::format_tag;
auto& cpu_engine = CpuEngine::Get()->get_engine();
@@ -632,7 +631,6 @@ void MKLDNNRnnOp::Init(const OpContext &ctx,
const std::vector<NDArray> &inputs,
const std::vector<OpReqType> &req,
const std::vector<NDArray> &outputs) {
using memory = mkldnn::memory;
using format_tag = mkldnn::memory::format_tag;

// In the `autograd.record()` context, RNNOp is required to run into
33 changes: 24 additions & 9 deletions src/operator/rnn-inl.h
@@ -185,6 +185,7 @@ inline int GetRnnBiasSize(int num_layer,
inline size_t GetRNNWorkspaceSize(int seq_length,
int batch_size,
int hidden_size,
int projection_size,
int direction,
int mode) {
size_t size = 0;
@@ -324,6 +325,7 @@ void RNNForwardInference(DType* ws,
const int batch_size,
const int input_size,
const int state_size,
const int projection_size,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
@@ -336,8 +338,8 @@
switch (mode) {
case rnn_enum::kLstm:
LstmForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
batch_size, input_size, state_size, x_ptr, hx_ptr, cx_ptr,
w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
batch_size, input_size, state_size, projection_size,
x_ptr, hx_ptr, cx_ptr, w_ptr, b_ptr, y_ptr, hy_ptr, cy_ptr);
break;
case rnn_enum::kGru:
GruForwardInference<DType>(ws, state_outputs, num_layers, direction, seq_length,
@@ -511,10 +513,7 @@ class RNNOp {
this->temp_init_space_ = false;
this->reserve_cpu_space_size_ = 0;
this->temp_cpu_space_size_ = 0;
if (param_.projection_size.has_value()) {
LOG(FATAL) <<
"hidden layer projection is only supported for GPU with CuDNN later than 7.1.1";
}

if (param_.lstm_state_clip_min.has_value()
|| param_.lstm_state_clip_max.has_value()) {
LOG(FATAL) << "LSTM state clipping is only supported for GPU with CuDNN later than 7.2.1";
@@ -843,9 +842,14 @@
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
projection_size = param_.projection_size.value();
}

// allocate temp space
const size_t work_cpu_space_size = GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
param_.state_size, projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ < work_cpu_space_size) {
temp_cpu_space_size_ = work_cpu_space_size;
temp_cpu_space_ = NDArray(TShape({static_cast<dim_t>(temp_cpu_space_size_)}), ctx_,
@@ -856,6 +860,9 @@

if (ctx.is_train || ctx.need_grad) {
// allocate reserve space
if (param_.projection_size.has_value()) {
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

const size_t r_size = GetRNNReserveSpaceSize(param_.num_layers, direction,
param_.seq_length_, param_.batch_size_,
@@ -896,6 +903,7 @@
param_.batch_size_,
param_.input_size_,
param_.state_size,
projection_size,
x.dptr_,
hx.dptr_,
cx_ptr,
@@ -1096,10 +1104,17 @@
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

if (ctx_.dev_type == kCPU) {
int projection_size = 0;
if (param_.projection_size.has_value()) {
// TODO(zixuanweeei): Add training support for LSTM with projection on CPU.
// projection_size = param_.projection_size.value();
LOG(FATAL) << "No training support for LSTM with projection on CPU currently.";
}

// allocate temp space
const size_t work_cpu_space_size =
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_,
param_.state_size, direction, param_.mode);
GetRNNWorkspaceSize(param_.seq_length_, param_.batch_size_, param_.state_size,
projection_size, direction, param_.mode);
if (!temp_init_space_ || temp_cpu_space_size_ != work_cpu_space_size) {
LOG(FATAL) << "Check temp init error";
}
44 changes: 29 additions & 15 deletions src/operator/rnn.cc
@@ -190,20 +190,19 @@ static std::vector<ResourceRequest> RNNResourceEx(const NodeAttrs& attrs, const
return request;
}

#if MXNET_USE_MKLDNN == 1
inline static bool RNNStorageType(const nnvm::NodeAttrs& attrs,
const int dev_mask,
DispatchMode* dispatch_mode,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
DispatchMode wanted_mode = DispatchMode::kFCompute;

#if MXNET_USE_MKLDNN == 1
wanted_mode = DispatchMode::kFComputeEx;
#endif // MXNET_USE_MKLDNN == 1

return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
dispatch_mode, wanted_mode);
const RNNParam& param = nnvm::get<RNNParam>(attrs.parsed);
const bool support_mkldnn_rnn =
!param.projection_size.has_value() && dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1);
return MKLDNNStorageType(attrs, dev_mask, support_mkldnn_rnn,
dispatch_mode, in_attrs, out_attrs);
}
#endif // MXNET_USE_MKLDNN == 1

struct RNNGrad {
const char *op_name;
@@ -246,9 +245,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs &attrs,
}

#if MXNET_USE_MKLDNN == 1
if ((in_types[0] == mshadow::kFloat32 || in_types[0] == mshadow::kFloat16)
&& in_shapes[0].ndim() == 3 && ctx.dev_type == kCPU
&& dmlc::GetEnv("MXNET_USE_MKLDNN_RNN", 1)) {
if (ctx.dev_type == kCPU && SupportMKLDNNRnn(param, in_types[rnn_enum::kData])) {
const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData];
state = OpStatePtr::Create<MKLDNNRnnOp>(param, data_shape[0],
data_shape[1], data_shape[2]);
@@ -274,7 +271,7 @@ static void RNNStatefulComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Forward(ctx, inputs, req, outputs);
} else {
@@ -287,7 +284,7 @@ static void RNNStatefulGradComputeExCPU(const OpStatePtr& state_ptr,
const std::vector<NDArray>& inputs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
if (SupportMKLDNNRnn(inputs[0])) {
if (SupportMKLDNNRnn(inputs[rnn_enum::kData].dtype())) {
MKLDNNRnnOp& op = state_ptr.get_state<MKLDNNRnnOp>();
op.Backward(ctx, inputs, req, outputs);
} else {
@@ -338,6 +335,23 @@ Long Short-Term Memory - Hochreiter, 1997. http://www.bioinf.jku.at/publications
h_t = o_t * \tanh(c_t)
\end{array}

With projection_size set, LSTM applies a projection layer to the recurrent state, reducing the
parameter count and giving a speedup without a significant loss of accuracy.

Long Short-Term Memory Based Recurrent Neural Network Architectures for Large Vocabulary Speech
Recognition - Sak et al. 2014. https://arxiv.org/abs/1402.1128

.. math::
\begin{array}{ll}
i_t = \mathrm{sigmoid}(W_{ii} x_t + b_{ii} + W_{ri} r_{(t-1)} + b_{ri}) \\
f_t = \mathrm{sigmoid}(W_{if} x_t + b_{if} + W_{rf} r_{(t-1)} + b_{rf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{rg} r_{(t-1)} + b_{rg}) \\
o_t = \mathrm{sigmoid}(W_{io} x_t + b_{io} + W_{ro} r_{(t-1)} + b_{ro}) \\
c_t = f_t * c_{(t-1)} + i_t * g_t \\
h_t = o_t * \tanh(c_t) \\
r_t = W_{hr} h_t
\end{array}
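
For illustration (an editorial sketch, not part of this diff), the projection feature can be exercised from Gluon, assuming the fused mxnet.gluon.rnn.LSTM layer forwards its projection_size argument to this operator (as in MXNet >= 1.6):

```python
# Hedged sketch: CPU inference with a projected LSTM (LSTMP).
# Per this PR, only the inference path supports projection on CPU;
# training with projection_size set is expected to raise an error.
import mxnet as mx
from mxnet.gluon import rnn

T, N, C = 30, 8, 128  # seq_len, batch, input size
lstm = rnn.LSTM(hidden_size=512, num_layers=2, projection_size=256, layout='TNC')
lstm.initialize(ctx=mx.cpu())

x = mx.nd.random.uniform(shape=(T, N, C), ctx=mx.cpu())
y = lstm(x)           # runs the fused RNN op on CPU
print(y.shape)        # (30, 8, 256) -- the last dim is projection_size, not hidden_size
```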

**GRU**

Gated Recurrent Unit - Cho et al. 2014. http://arxiv.org/abs/1406.1078
@@ -385,10 +399,10 @@ The definition of GRU here is slightly different from paper but compatible with
})
Review comment (Member):
I don't think the projection support is clear in the documentation. Could you update the documentation with LSTMP support when projection_size is set? You can refer to https://github.com/apache/incubator-mxnet/blob/62a85f365b819829fedb60116f803e0c0a3c554c/python/mxnet/gluon/contrib/rnn/rnn_cell.py#L197
Thanks!

Reply (Contributor, author):
Sure. Thanks for pointing that out.

Reply (Contributor, author):
Done. Please take another look. Thanks.
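
For context on the reviewer's pointer (again an editorial sketch, not part of the diff), the per-step projected cell in gluon.contrib shows the shapes involved, assuming the LSTMPCell(hidden_size, projection_size) signature from the linked rnn_cell.py:

```python
# Hedged sketch: one step of gluon.contrib's LSTMPCell, illustrating that the
# recurrent (projected) state r_t has size projection_size while the cell
# state c_t keeps size hidden_size.
import mxnet as mx
from mxnet.gluon.contrib.rnn import LSTMPCell

cell = LSTMPCell(hidden_size=512, projection_size=256)
cell.initialize(ctx=mx.cpu())

x = mx.nd.random.uniform(shape=(8, 128), ctx=mx.cpu())   # (batch, input_size)
states = cell.begin_state(batch_size=8, ctx=mx.cpu())    # [r, c]
out, new_states = cell(x, states)
print(out.shape, new_states[0].shape, new_states[1].shape)
# expected: (8, 256) (8, 256) (8, 512)
```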

.set_attr<mxnet::FInferShape>("FInferShape", RNNShape)
.set_attr<nnvm::FInferType>("FInferType", RNNType)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FCreateOpState>("FCreateOpState", CreateRNNState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulComputeExCPU)
#endif
@@ -427,9 +441,9 @@ NNVM_REGISTER_OP(_backward_RNN)
.set_attr_parser(ParamParser<RNNParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", RNNStatefulGradCompute<cpu>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FInferStorageType>("FInferStorageType", RNNStorageType)
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", RNNStatefulGradComputeExCPU)
#endif
48 changes: 35 additions & 13 deletions src/operator/rnn_impl.h
@@ -209,6 +209,7 @@ void LstmForwardInferenceSingleLayer(DType* ws,
const int N,
const int I,
const int H,
const int P,
const Tensor<cpu, 2, DType> &x,
const Tensor<cpu, 2, DType> &hx,
const Tensor<cpu, 2, DType> &cx,
@@ -219,7 +220,9 @@
DType* cy_ptr) {
using namespace mshadow;
const Tensor<cpu, 2, DType> wx(w_ptr, Shape2(H * 4, I));
const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, H));
const Tensor<cpu, 2, DType> wh(w_ptr + I * H * 4, Shape2(H * 4, (P ? P : H)));
Tensor<cpu, 2, DType> whr(w_ptr, Shape2(1, 1));
if (P > 0) whr = Tensor<cpu, 2, DType>(wh.dptr_ + P * 4 * H, Shape2(P, H));
const Tensor<cpu, 2, DType> bx(b_ptr, Shape2(4, H));
const Tensor<cpu, 2, DType> bh(b_ptr + H * 4, Shape2(4, H));
Tensor<cpu, 2, DType> yx_flat(ws, Shape2(T * N, H * 4));
@@ -228,7 +231,10 @@
const Tensor<cpu, 3, DType> yh(yh_flat.dptr_, Shape3(N, 4, H));
Tensor<cpu, 2, DType> h(yh_flat.dptr_ + N * H * 4, Shape2(N, H));
Tensor<cpu, 2, DType> c(h.dptr_ + N * H, Shape2(N, H));
Tensor<cpu, 2, DType> r(hy_ptr, Shape2(1, 1));
if (P > 0) r = Tensor<cpu, 2, DType>(hy_ptr, Shape2(N, P));
const int offset = bid ? H : 0;
const int proj_offset = bid ? P : 0;
const DType alpha = 1.0;
const DType beta = 0.0;
const int cell_size = N * H;
@@ -237,7 +243,11 @@
const int omp_threads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount();
for (int i = 0; i < T; ++i) {
int t = bid ? T - 1 - i : i;
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
if (P > 0) {
linalg_gemm(i ? r : hx, wh, yh_flat, alpha, beta, false, true);
} else {
linalg_gemm(i ? h : hx, wh, yh_flat, alpha, beta, false, true);
}
#pragma omp parallel for num_threads(omp_threads)
for (int jk = 0; jk < cell_size; ++jk) {
int j = jk / H;
@@ -248,14 +258,21 @@
DType ot = sigmoid<DType>(yx[t][j][3][k] + yh[j][3][k] + bx[3][k] + bh[3][k]);
DType ct = (i ? c[j][k] : cx[j][k]) * ft + it * gt;
DType ht = ot * tanh(ct);
y[t][j][k + offset] = ht;
if (P == 0) y[t][j][k + offset] = ht;
if (i == T - 1 && state_outputs) {
hy_ptr[jk] = ht;
if (P == 0) hy_ptr[jk] = ht;
cy_ptr[jk] = ct;
} else {
h[j][k] = ht;
c[j][k] = ct;
}
h[j][k] = ht;
}
if (P > 0) {
linalg_gemm(h, whr, r, alpha, beta, false, true);
#pragma omp parallel for num_threads(omp_threads)
for (int j = 0; j < N; ++j) {
std::memcpy(y[t][j].dptr_ + proj_offset, r[j].dptr_, P * sizeof(DType));
}
}
}
}
@@ -269,6 +286,7 @@
const int N,
const int I,
const int H,
const int P,
DType* x_ptr,
DType* hx_ptr,
DType* cx_ptr,
@@ -278,36 +296,40 @@
DType* hy_ptr,
DType* cy_ptr) {
const int total_layers = D * L;
Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, H));
Tensor<cpu, 3, DType> hx(hx_ptr, Shape3(total_layers, N, P ? P : H));
Tensor<cpu, 3, DType> cx(cx_ptr, Shape3(total_layers, N, H));
const int b_size = 2 * H * 4;
const int cell_size = N * H;
const int projection_size = (P ? P : H) * N;
DType* y_tmp_ptr = ws + (T + 1) * cell_size * 4 + cell_size * 2;
DType* y_cur_ptr = y_ptr;
int idx = 0; // state & cell state's idx;
bool flag = L % 2 ? false : true;
for (int i = 0; i < L; ++i) {
const int input_size = i ? H * D : I;
const int w_size = (input_size + H) * H * 4;
const int input_size = i ? (P ? P : H) * D : I;
int w_size = (input_size + (P ? P : H)) * H * 4;
if (P > 0) {
w_size += P * H;
}
// If bidirectional, need space to save current layer output y.
if (D == 2) {
y_cur_ptr = flag ? y_tmp_ptr : y_ptr;
flag = !flag;
}
Tensor<cpu, 2, DType> x(x_ptr, Shape2(T * N, input_size));
Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, H * D));
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H,
Tensor<cpu, 3, DType> y(y_cur_ptr, Shape3(T, N, (P ? P : H) * D));
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, false, T, N, input_size, H, P,
x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
// If bidirectional, then calculate the reverse direction's forward result.
if (D == 2) {
w_ptr += w_size;
b_ptr += b_size;
++idx;
if (state_outputs) {
hy_ptr += cell_size;
hy_ptr += projection_size;
cy_ptr += cell_size;
}
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H,
LstmForwardInferenceSingleLayer<DType>(ws, state_outputs, true, T, N, input_size, H, P,
x, hx[idx], cx[idx], y, w_ptr, b_ptr, hy_ptr, cy_ptr);
}
// Don't need to move pointer in the last layer.
@@ -317,7 +339,7 @@
x_ptr = y_cur_ptr;
++idx;
if (state_outputs) {
hy_ptr += cell_size;
hy_ptr += projection_size;
cy_ptr += cell_size;
}
}