From 03453c4bd303b49d760dde6cf290615e0427d6c2 Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Sat, 12 Jan 2019 00:06:05 -0800
Subject: [PATCH] cudnn dropout

---
 src/operator/cudnn_rnn-inl.h  |   4 +-
 src/operator/nn/dropout-inl.h | 204 ++++++++++++++++++++++++++++------
 src/operator/nn/dropout.cc    |   8 +-
 src/operator/nn/dropout.cu    |   4 +-
 4 files changed, 181 insertions(+), 39 deletions(-)

diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h
index 7c450b77c7ec..cc8e4db404da 100644
--- a/src/operator/cudnn_rnn-inl.h
+++ b/src/operator/cudnn_rnn-inl.h
@@ -699,7 +699,7 @@ class CuDNNRNNOp : public Operator {
     if (param_.p > 0) {
       CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
       dropout_size_ = dropout_byte_ / sizeof(DType);
-      dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
+      dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
     } else {
       dropout_states_ = {};
       dropout_byte_ = 0;
@@ -764,7 +764,7 @@ class CuDNNRNNOp : public Operator {
                                            &reserve_space_byte_));
     workspace_size_ = workspace_byte_ / sizeof(DType);
     // Allocate the reserve space
-    reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU());
+    reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));
 
     // Check that number of params are correct
     size_t cudnn_param_size;
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 9668c227b309..c5625bfc44d6 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -78,7 +78,7 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
 
 template<typename xpu, typename DType>
 class DropoutOp {
-#if defined(USE_MKL) && defined(_OPENMP)
+#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
   static void BernoulliGenerate(common::random::RandGenerator<cpu, DType> gen,
                                 int n, double p, int* r) {
     typename RandGenerator<cpu, DType>::Impl genImpl(&gen, 1);
@@ -150,23 +150,7 @@ class DropoutOp {
     return false;
   }
 
-#ifdef __CUDACC__
-  // GPU never uses MKL
-  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<gpu> *s, RandGenerator<gpu, DType> *pgen,
-                                         const double pkeep,
-                                         const std::vector<TBlob> &in_data,
-                                         const std::vector<TBlob> &out_data) {
-    return false;
-  }
-  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<gpu> *s, const double pkeep,
-                                          const std::vector<TBlob> &in_grad,
-                                          const std::vector<TBlob> &out_data,
-                                          const std::vector<TBlob> &out_grad) {
-    return false;
-  }
-#endif  // __CUDACC__
-
-#else  // #if defined(USE_MKL) && defined(_OPENMP)
+#else  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
   static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<xpu> *s, RandGenerator<xpu, DType> *pgen,
                                          const double pkeep,
                                          const std::vector<TBlob> &in_data,
@@ -179,7 +163,7 @@ class DropoutOp {
                                          const std::vector<TBlob> &out_grad) {
     return false;
   }
-#endif  // #if defined(USE_MKL) && defined(_OPENMP)
+#endif  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
 
  public:
   /*!
@@ -227,12 +211,136 @@ class DropoutOp {
     }
   };
 
-  void Init(const DropoutParam &param) {
+  explicit DropoutOp(const DropoutParam &param, Context ctx) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
     this->axes_ = param.axes;
+#if MXNET_USE_CUDNN == 1
+    this->ctx_ = ctx;
+    if (ctx.dev_type == kGPU) {
+      init_cudnn_ = false;
+      dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_));
+      CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+    }
+#endif  // MXNET_USE_CUDNN == 1
+  }
+
+  ~DropoutOp() {
+#if MXNET_USE_CUDNN == 1
+    if (this->ctx_.dev_type == kGPU) {
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_));
+      CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
+      if (init_cudnn_) {
+        Storage::Get()->Free(dropout_states_);
+        Storage::Get()->Free(reserve_space_);
+      }
+    }
+#endif  // MXNET_USE_CUDNN == 1
   }
 
+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+  inline void CuDNNForward(const OpContext &ctx,
+                           const TBlob &in,
+                           const TBlob &out) {
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+
+    // set dropout state; note cudnnSetDropoutDescriptor expects the
+    // drop probability, i.e. 1 - pkeep.
+    // TODO(szha): expensive call, should be cached and reused across operators.
+    if (!init_cudnn_) {
+      CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_state_byte_));
+      dropout_states_ = Storage::Get()->Alloc(dropout_state_byte_, Context::GPU(s->dev_id));
+      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
+                                           1.0f - this->pkeep_,
+                                           dropout_states_.dptr, dropout_state_byte_,
+                                           seed_));
+    }
+
+    // describe input/output tensor
+    int dim[4], stride[4];
+    dim[0] = 1;
+    dim[1] = 1;
+    dim[2] = 1;
+    dim[3] = out.Size();
+    stride[0] = out.Size();
+    stride[1] = out.Size();
+    stride[2] = out.Size();
+    stride[3] = 1;
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(x_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(y_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+
+    // perform dropout with cudnn
+    CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, &dropout_reserve_byte_));
+    if (init_cudnn_ && dropout_reserve_byte_ > reserve_space_.size) {
+      Storage::Get()->Free(reserve_space_);
+      init_cudnn_ = false;
+    }
+    if (!init_cudnn_) {
+      reserve_space_ = Storage::Get()->Alloc(dropout_reserve_byte_, Context::GPU(s->dev_id));
+      init_cudnn_ = true;
+    }
+    CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_,
+                                   dropout_desc_,
+                                   x_desc_,
+                                   in.dptr<DType>(),
+                                   y_desc_,
+                                   out.dptr<DType>(),
+                                   reserve_space_.dptr,
+                                   dropout_reserve_byte_));
+  }
+
+  inline void CuDNNBackward(const OpContext &ctx,
+                            const TBlob &out_grad,
+                            const TBlob &in_grad) {
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+
+    // describe input/output tensor
+    int dim[4], stride[4];
+    dim[0] = 1;
+    dim[1] = 1;
+    dim[2] = 1;
+    dim[3] = in_grad.Size();
+    stride[0] = in_grad.Size();
+    stride[1] = in_grad.Size();
+    stride[2] = in_grad.Size();
+    stride[3] = 1;
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+
+    // perform dropout with cudnn
+    CUDNN_CALL(cudnnDropoutBackward(s->dnn_handle_,
+                                    dropout_desc_,
+                                    dy_desc_,
+                                    out_grad.dptr<DType>(),
+                                    dx_desc_,
+                                    in_grad.dptr<DType>(),
+                                    reserve_space_.dptr,
+                                    dropout_reserve_byte_));
+  }
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+
   void Forward(const OpContext &ctx,
                const std::vector<TBlob> &in_data,
                const std::vector<OpReqType> &req,
                const std::vector<TBlob> &out_data) {
@@ -252,11 +360,15 @@ class DropoutOp {
       CHECK(req[dropout::kOut] != kAddTo);
       if (this->axes_.ndim() == 0) {
         // standard case for dropout
+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+        CuDNNForward(ctx, in_data[dropout::kData], out);
+#else
         LaunchRNG<dropout_kernel, xpu>(s, pgen, out.Size(),
-                                      out.dptr<DType>(),
-                                      mask.dptr<DType>(),
-                                      in_data[dropout::kData].dptr<DType>(),
-                                      this->pkeep_);
+                                       out.dptr<DType>(),
+                                       mask.dptr<DType>(),
+                                       in_data[dropout::kData].dptr<DType>(),
+                                       this->pkeep_);
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
         return;
       }
 
@@ -319,10 +431,14 @@ class DropoutOp {
       if (this->axes_.ndim() == 0) {
         // standard case for dropout
         CHECK_EQ(grad.Size(), mask.Size());
+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+        CuDNNBackward(ctx, grad, gdata);
+#else
         MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
           mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
             s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
         });
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
         return;
       }
       // broadcast mul
@@ -367,29 +483,54 @@ class DropoutOp {
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
   TShape axes_;
+#if MXNET_USE_CUDNN == 1
+  Context ctx_;
+  cudnnDataType_t dtype_;
+  cudnnDropoutDescriptor_t dropout_desc_;
+  bool init_cudnn_;
+  uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  size_t dropout_state_byte_, dropout_reserve_byte_;
+  Storage::Handle dropout_states_, reserve_space_;
+  cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_;
+#endif  // MXNET_USE_CUDNN == 1
 };  // class DropoutOp
 
+static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs,
+                                     const Context ctx,
+                                     const std::vector<TShape> &in_shapes,
+                                     const std::vector<int> &in_types) {
+  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+  OpStatePtr state;
+  MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, {
+    if (ctx.dev_type == kGPU) {
+      state = OpStatePtr::Create<DropoutOp<gpu, DType>>(param, ctx);
+    } else {
+      state = OpStatePtr::Create<DropoutOp<cpu, DType>>(param, ctx);
+    }
+    return state;
+  });
+  LOG(FATAL) << "should never reach here";
+  return OpStatePtr();  // should never reach here
+}
+
 template<typename xpu>
-void DropoutCompute(const nnvm::NodeAttrs& attrs,
+void DropoutCompute(const OpStatePtr& state,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
-  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    DropoutOp<xpu, DType> op;
-    op.Init(param);
+    DropoutOp<xpu, DType>& op = state.get_state<DropoutOp<xpu, DType>>();
     op.Forward(ctx, inputs, req, outputs);
   });
 }
 
 template<typename xpu>
-void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
+void DropoutGradCompute(const OpStatePtr& state,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
-  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1);
   CHECK_EQ(req.size(), 1);
@@ -399,8 +540,7 @@ void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
   out_data[dropout::kMask] = inputs[1];
 
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    DropoutOp<xpu, DType> op;
-    op.Init(param);
+    DropoutOp<xpu, DType>& op = state.get_state<DropoutOp<xpu, DType>>();
     op.Backward(ctx, out_grads, out_data, req, outputs);
   });
 }
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 3021e0105b4f..27cecee906a0 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -119,25 +119,27 @@ Example::
   for (size_t i = 0; i < nout; ++i)
     out_type->push_back(dtype);
   return true;
 })
-.set_attr<FCompute>("FCompute<cpu>", DropoutCompute<cpu>)
+.set_attr<FCreateOpState>("FCreateOpState", CreateDropoutState)
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", DropoutCompute<cpu>)
 .set_attr<nnvm::FGradient>("FGradient", DropoutGrad{"_backward_Dropout"})
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
 .set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
-  return std::vector<ResourceRequest>{ ResourceRequest::kParallelRandom };
+  return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom};
 })
 .add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.")
 .add_arguments(DropoutParam::__FIELDS__());
 
 NNVM_REGISTER_OP(_backward_Dropout)
 .set_num_outputs(1)
+.set_attr<bool>("TIsLayerOpBackward", true)
 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
 .set_attr_parser(ParamParser<DropoutParam>)
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
-.set_attr<FCompute>("FCompute<cpu>", DropoutGradCompute<cpu>);
+.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", DropoutGradCompute<cpu>);
 
 }  // namespace op
 }  // namespace mxnet
diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu
index 832490b08f1f..20c5714dd904 100644
--- a/src/operator/nn/dropout.cu
+++ b/src/operator/nn/dropout.cu
@@ -30,10 +30,10 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(Dropout)
-.set_attr<FCompute>("FCompute<gpu>", DropoutCompute<gpu>);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);
 
 NNVM_REGISTER_OP(_backward_Dropout)
-.set_attr<FCompute>("FCompute<gpu>", DropoutGradCompute<gpu>);
+.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutGradCompute<gpu>);
 
 }  // namespace op
 }  // namespace mxnet
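--
Reference sketch (not part of the patch): the cuDNN call sequence that CuDNNForward
drives through MXNet's Storage manager can be exercised standalone. The program below
is a minimal illustration under stated assumptions: raw cudaMalloc stands in for
Storage::Get()->Alloc, the seed is fixed instead of randomized, the data type is float,
and the tensor is described as a flat 1x1x1xN blob, as the patch's descriptors do. Note
that cudnnSetDropoutDescriptor takes the probability of dropping an element (param.p,
i.e. 1 - pkeep_), not the keep probability.

#include <cudnn.h>
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>

// Minimal error check, standing in for MXNet's CUDNN_CALL.
#define CHECK_CUDNN(expr)                                              \
  do {                                                                 \
    cudnnStatus_t e = (expr);                                          \
    if (e != CUDNN_STATUS_SUCCESS) {                                   \
      std::fprintf(stderr, "cudnn error %s at line %d\n",              \
                   cudnnGetErrorString(e), __LINE__);                  \
      std::exit(1);                                                    \
    }                                                                  \
  } while (0)

int main() {
  const int n = 1024;    // element count, mirrors out.Size()
  const float p = 0.5f;  // drop probability, i.e. 1 - pkeep

  cudnnHandle_t handle;
  CHECK_CUDNN(cudnnCreate(&handle));

  // Dropout descriptor: query the RNG state size once, allocate the state
  // buffer, then bind buffer + seed to the descriptor. This is the expensive
  // step the patch's TODO wants to cache across operators.
  cudnnDropoutDescriptor_t dropout_desc;
  CHECK_CUDNN(cudnnCreateDropoutDescriptor(&dropout_desc));
  size_t state_bytes = 0;
  CHECK_CUDNN(cudnnDropoutGetStatesSize(handle, &state_bytes));
  void *states = nullptr;
  cudaMalloc(&states, state_bytes);
  CHECK_CUDNN(cudnnSetDropoutDescriptor(dropout_desc, handle, p,
                                        states, state_bytes,
                                        /*seed=*/17ULL));

  // Tensor descriptor: like the patch, flatten the input to 1x1x1xN.
  cudnnTensorDescriptor_t x_desc;
  CHECK_CUDNN(cudnnCreateTensorDescriptor(&x_desc));
  CHECK_CUDNN(cudnnSetTensor4dDescriptor(x_desc, CUDNN_TENSOR_NCHW,
                                         CUDNN_DATA_FLOAT, 1, 1, 1, n));

  // Reserve space: sized per tensor and carrying the dropout mask, it must
  // be handed unchanged to the matching cudnnDropoutBackward call.
  size_t reserve_bytes = 0;
  CHECK_CUDNN(cudnnDropoutGetReserveSpaceSize(x_desc, &reserve_bytes));
  void *reserve = nullptr;
  cudaMalloc(&reserve, reserve_bytes);

  float *x = nullptr, *y = nullptr;
  cudaMalloc(&x, n * sizeof(float));
  cudaMalloc(&y, n * sizeof(float));
  cudaMemset(x, 0, n * sizeof(float));

  // Forward pass: y = dropout(x); input and output share one shape, so the
  // same descriptor serves both arguments.
  CHECK_CUDNN(cudnnDropoutForward(handle, dropout_desc, x_desc, x,
                                  x_desc, y, reserve, reserve_bytes));

  cudaFree(x); cudaFree(y); cudaFree(reserve); cudaFree(states);
  CHECK_CUDNN(cudnnDestroyTensorDescriptor(x_desc));
  CHECK_CUDNN(cudnnDestroyDropoutDescriptor(dropout_desc));
  CHECK_CUDNN(cudnnDestroy(handle));
  return 0;
}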