From da56d291030a1f853f1e26c45b3a48b4fa0c5a1c Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sat, 12 Jan 2019 00:06:05 -0800 Subject: [PATCH 1/9] cudnn dropout --- src/operator/cudnn_rnn-inl.h | 4 +- src/operator/nn/dropout-inl.h | 219 +++++++++++++++++++++++++++++----- src/operator/nn/dropout.cc | 8 +- src/operator/nn/dropout.cu | 4 +- 4 files changed, 196 insertions(+), 39 deletions(-) diff --git a/src/operator/cudnn_rnn-inl.h b/src/operator/cudnn_rnn-inl.h index 7c450b77c7ec..cc8e4db404da 100644 --- a/src/operator/cudnn_rnn-inl.h +++ b/src/operator/cudnn_rnn-inl.h @@ -699,7 +699,7 @@ class CuDNNRNNOp : public Operator { if (param_.p > 0) { CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_)); dropout_size_ = dropout_byte_ / sizeof(DType); - dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU()); + dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id)); } else { dropout_states_ = {}; dropout_byte_ = 0; @@ -764,7 +764,7 @@ class CuDNNRNNOp : public Operator { &reserve_space_byte_)); workspace_size_ = workspace_byte_ / sizeof(DType); // Allocate the reserve space - reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU()); + reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id)); // Check that number of params are correct size_t cudnn_param_size; diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 9668c227b309..4c4616ecff3b 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -78,7 +78,7 @@ struct DropoutParam : public dmlc::Parameter { template class DropoutOp { -#if defined(USE_MKL) && defined(_OPENMP) +#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__) static void BernoulliGenerate(common::random::RandGenerator gen, int n, double p, int* r) { typename RandGenerator::Impl genImpl(&gen, 1); @@ -150,23 +150,7 @@ class DropoutOp { return false; } -#ifdef __CUDACC__ - // GPU never uses MKL - static bool MSHADOW_CINLINE MKLForward(mshadow::Stream *s, RandGenerator *pgen, - const double pkeep, - const std::vector &in_data, - const std::vector &out_data) { - return false; - } - static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream *s, const double pkeep, - const std::vector &in_grad, - const std::vector &out_data, - const std::vector &out_grad) { - return false; - } -#endif // __CUDACC__ - -#else // #if defined(USE_MKL) && defined(_OPENMP) +#else // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__) static bool MSHADOW_CINLINE MKLForward(mshadow::Stream *s, RandGenerator *pgen, const double pkeep, const std::vector &in_data, @@ -179,7 +163,7 @@ class DropoutOp { const std::vector &out_grad) { return false; } -#endif // #if defined(USE_MKL) && defined(_OPENMP) +#endif // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__) public: /*! 
@@ -227,12 +211,136 @@ class DropoutOp {
     }
   };

-  void Init(const DropoutParam &param) {
+  explicit DropoutOp(const DropoutParam &param, Context ctx) {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
     this->axes_ = param.axes;
+#if MXNET_USE_CUDNN == 1
+    this->ctx_ = ctx;
+    if (ctx.dev_type == kGPU && this->pkeep_ > 0) {
+      init_cudnn_ = false;
+      dtype_ = mshadow::DataType<DType>::kCudnnFlag;
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_desc_));
+      CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_));
+      CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
+    }
+#endif  // MXNET_USE_CUDNN == 1
+  }
+
+  ~DropoutOp() {
+#if MXNET_USE_CUDNN == 1
+    if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0) {
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
+      CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_));
+      CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
+      if (init_cudnn_) {
+        Storage::Get()->Free(dropout_states_);
+        Storage::Get()->Free(reserve_space_);
+      }
+    }
+#endif  // MXNET_USE_CUDNN == 1
   }

+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+  inline void CuDNNForward(const OpContext &ctx,
+                           const TBlob &in,
+                           const TBlob &out) {
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+
+    // set dropout state.
+    // TODO(szha): expensive call, should be cached and reused across operators.
+    if (!init_cudnn_) {
+      CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_state_byte_));
+      dropout_states_ = Storage::Get()->Alloc(dropout_state_byte_, Context::GPU(s->dev_id));
+      CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
+                                           1.0f - this->pkeep_,
+                                           dropout_states_.dptr, dropout_state_byte_,
+                                           seed_));
+    }
+
+    // describe input/output tensor
+    int dim[4], stride[4];
+    dim[0] = 1;
+    dim[1] = 1;
+    dim[2] = 1;
+    dim[3] = out.Size();
+    stride[0] = out.Size();
+    stride[1] = out.Size();
+    stride[2] = out.Size();
+    stride[3] = 1;
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(x_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(y_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+
+    // perform dropout with cudnn
+    CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, &dropout_reserve_byte_));
+    if (init_cudnn_ && dropout_reserve_byte_ > reserve_space_.size) {
+      Storage::Get()->Free(reserve_space_);
+      init_cudnn_ = false;
+    }
+    if (!init_cudnn_) {
+      reserve_space_ = Storage::Get()->Alloc(dropout_reserve_byte_, Context::GPU(s->dev_id));
+      init_cudnn_ = true;
+    }
+    CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_,
+                                   dropout_desc_,
+                                   x_desc_,
+                                   in.dptr<DType>(),
+                                   y_desc_,
+                                   out.dptr<DType>(),
+                                   reserve_space_.dptr,
+                                   dropout_reserve_byte_));
+  }
+
+  inline void CuDNNBackward(const OpContext &ctx,
+                            const TBlob &out_grad,
+                            const TBlob &in_grad) {
+    Stream<gpu> *s = ctx.get_stream<gpu>();
+
+    // describe input/output tensor
+    int dim[4], stride[4];
+    dim[0] = 1;
+    dim[1] = 1;
+    dim[2] = 1;
+    dim[3] = in_grad.Size();
+    stride[0] = in_grad.Size();
+    stride[1] = in_grad.Size();
+    stride[2] = in_grad.Size();
+    stride[3] = 1;
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+    CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_desc_,
+                                          dtype_,
+                                          4,
+                                          dim,
+                                          stride));
+
+    // perform dropout with cudnn
+    CUDNN_CALL(cudnnDropoutBackward(s->dnn_handle_,
+                                    dropout_desc_,
+                                    dy_desc_,
+                                    out_grad.dptr<DType>(),
+                                    dx_desc_,
+                                    in_grad.dptr<DType>(),
+                                    reserve_space_.dptr,
+                                    dropout_reserve_byte_));
+  }
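+
+  // cuDNN dropout is elementwise, so CuDNNForward and CuDNNBackward above both
+  // describe their tensors as a flattened 1x1x1xN shape with matching strides;
+  // only the total element count matters to cudnnDropoutForward/Backward.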
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+
   void Forward(const OpContext &ctx,
                const std::vector<TBlob> &in_data,
                const std::vector<OpReqType> &req,
                const std::vector<TBlob> &out_data) {
@@ -252,11 +360,23 @@ class DropoutOp {
         CHECK(req[dropout::kOut] != kAddTo);
         if (this->axes_.ndim() == 0) {
           // standard case for dropout
+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+          if (this->pkeep_ > 0) {  // existing dropout produces inf with pkeep=0
+            CuDNNForward(ctx, in_data[dropout::kData], out);
+          } else {
+            LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+                                          out.dptr<DType>(),
+                                          mask.dptr<DType>(),
+                                          in_data[dropout::kData].dptr<DType>(),
+                                          this->pkeep_);
+          }
+#else
           LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
-                                      out.dptr<DType>(),
-                                      mask.dptr<DType>(),
-                                      in_data[dropout::kData].dptr<DType>(),
-                                      this->pkeep_);
+                                        out.dptr<DType>(),
+                                        mask.dptr<DType>(),
+                                        in_data[dropout::kData].dptr<DType>(),
+                                        this->pkeep_);
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
           return;
         }
@@ -319,10 +439,21 @@ class DropoutOp {
         if (this->axes_.ndim() == 0) {
           // standard case for dropout
           CHECK_EQ(grad.Size(), mask.Size());
+#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+          if (this->pkeep_ > 0) {
+            CuDNNBackward(ctx, grad, gdata);
+          } else {
+            MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+                s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+            });
+          }
+#else
           MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
             mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
               s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
           });
+#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
           return;
         }
         // broadcast mul
@@ -367,29 +498,54 @@ class DropoutOp {
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
   TShape axes_;
+#if MXNET_USE_CUDNN == 1
+  Context ctx_;
+  cudnnDataType_t dtype_;
+  cudnnDropoutDescriptor_t dropout_desc_;
+  bool init_cudnn_;
+  uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
+  size_t dropout_state_byte_, dropout_reserve_byte_;
+  Storage::Handle dropout_states_, reserve_space_;
+  cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_;
+#endif  // MXNET_USE_CUDNN == 1
 };  // class DropoutOp

+static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs,
+                                     const Context ctx,
+                                     const std::vector<TShape> &in_shapes,
+                                     const std::vector<int> &in_types) {
+  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+  OpStatePtr state;
+  MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, {
+    if (ctx.dev_type == kGPU) {
+      state = OpStatePtr::Create<DropoutOp<gpu, DType>>(param, ctx);
+    } else {
+      state = OpStatePtr::Create<DropoutOp<cpu, DType>>(param, ctx);
+    }
+    return state;
+  });
+  LOG(FATAL) << "should never reach here";
+  return OpStatePtr();  // should never reach here
+}
+
 template<typename xpu>
-void DropoutCompute(const nnvm::NodeAttrs& attrs,
+void DropoutCompute(const OpStatePtr& state,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
-  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    DropoutOp<xpu, DType> op;
-    op.Init(param);
+    DropoutOp<xpu, DType>& op = state.get_state<DropoutOp<xpu, DType>>();
     op.Forward(ctx, inputs, req, outputs);
   });
 }

 template<typename xpu>
-void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
+void DropoutGradCompute(const OpStatePtr& state,
                         const OpContext& ctx,
                         const std::vector<TBlob>& inputs,
                         const std::vector<OpReqType>& req,
                         const std::vector<TBlob>& outputs) {
-  const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
   CHECK_EQ(inputs.size(), 2U);
   CHECK_EQ(outputs.size(), 1);
   CHECK_EQ(req.size(), 1);
@@ -399,8 +555,7 @@ void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
   out_data[dropout::kMask] = inputs[1];

   MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
-    DropoutOp<xpu, DType> op;
-
op.Init(param); + DropoutOp& op = state.get_state>(); op.Backward(ctx, out_grads, out_data, req, outputs); }); } diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 3021e0105b4f..27cecee906a0 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -119,25 +119,27 @@ Example:: for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype); return true; }) -.set_attr("FCompute", DropoutCompute) +.set_attr("FCreateOpState", CreateDropoutState) +.set_attr("FStatefulCompute", DropoutCompute) .set_attr("FGradient", DropoutGrad{"_backward_Dropout"}) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) .set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ ResourceRequest::kParallelRandom }; + return std::vector{ResourceRequest::kParallelRandom}; }) .add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.") .add_arguments(DropoutParam::__FIELDS__()); NNVM_REGISTER_OP(_backward_Dropout) .set_num_outputs(1) +.set_attr("TIsLayerOpBackward", true) .set_attr("TIsBackward", true) .set_attr_parser(ParamParser) .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) -.set_attr("FCompute", DropoutGradCompute); +.set_attr("FStatefulCompute", DropoutGradCompute); } // namespace op } // namespace mxnet diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu index 832490b08f1f..20c5714dd904 100644 --- a/src/operator/nn/dropout.cu +++ b/src/operator/nn/dropout.cu @@ -30,10 +30,10 @@ namespace mxnet { namespace op { NNVM_REGISTER_OP(Dropout) -.set_attr("FCompute", DropoutCompute); +.set_attr("FStatefulCompute", DropoutCompute); NNVM_REGISTER_OP(_backward_Dropout) -.set_attr("FCompute", DropoutGradCompute); +.set_attr("FStatefulCompute", DropoutGradCompute); } // namespace op } // namespace mxnet From abed0b130334cd4f85199bc9e1cca851c02c8f40 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Fri, 18 Jan 2019 18:42:49 -0800 Subject: [PATCH 2/9] test dropout as stateful op --- tests/cpp/include/test_core_op.h | 75 ++++++++++++++++++++++++++++++-- 1 file changed, 71 insertions(+), 4 deletions(-) diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index c39373b1b798..7fde28947af8 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -511,9 +511,25 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer function_ = common::GetFCompute(op_, "FCompute", ctx_.run_ctx.ctx); functionex_ = common::GetFCompute(op_, "FComputeEx", ctx_.run_ctx.ctx); + stateful_function_ = common::GetFCompute(op_, "FStatefulCompute", + ctx_.run_ctx.ctx); AttachResources(&ctx_, attrs_, op_); + auto& is_layer_backward = Op::GetAttr("TIsLayerOpBackward"); + auto& createop = nnvm::Op::GetAttr("FCreateOpState"); + if (createop.count(op_) || is_layer_backward.get(op_, false)) { + if (backward_for_op) { + state_ = backward_for_op->state_; + } + if (!state_) { + if (!create_state_) { + create_state_ = createop[op_]; + } + state_ = create_state_(attrs_, ctx_.run_ctx.ctx, input_shapes_, input_types); + } + } + if (!backward_for_op) { bool no_backward = false; // Set up backward @@ -561,8 +577,14 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer inline void forward(const size_t count) { perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), kForward, "Forward", count); mxnet::profiler::vtune::VTuneResume profile; - for (size_t i = 0; i < count; ++i) { - Execute(); + if (stateful_function_) 
{ + for (size_t i = 0; i < count; ++i) { + ExecuteStateful(); + } + } else { + for (size_t i = 0; i < count; ++i) { + Execute(); + } } } @@ -570,8 +592,14 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer CHECK(HasBackward()); perf::TimingItem timeF(&OperatorExecutorTiming::GetTiming(), kBackward, "Backward", count); mxnet::profiler::vtune::VTuneResume profile; - for (size_t i = 0; i < count; ++i) { - ExecuteBackward(); + if (stateful_function_) { + for (size_t i = 0; i < count; ++i) { + ExecuteBackwardStateful(); + } + } else { + for (size_t i = 0; i < count; ++i) { + ExecuteBackward(); + } } } @@ -595,6 +623,17 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer functionex_(attrs_, ctx_, inputs_, req_, outputs_); } + /*! + * \brief Execute the stateful operator + */ + void ExecuteStateful() { + CHECK_EQ(initialized_, true); + CHECK(state_); + CollectBlobs(inputs_, &blob_inputs_); + CollectBlobs(outputs_, &blob_outputs_); + stateful_function_(state_, ctx_, blob_inputs_, req_, blob_outputs_); + } + bool HasBackward() const { return !backward_.empty(); } @@ -631,6 +670,22 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer return false; } + /*! + * \brief Execute backward pass on stateful operator + */ + bool ExecuteBackwardStateful() { + CHECK_EQ(initialized_, true); + CHECK(HasBackward()); + if (!backward_.empty()) { + // Avoid locked ref count here + for (std::shared_ptr &p : backward_) { + p->ExecuteStateful(); + } + return true; + } + return false; + } + /*! * \brief Access input NDArray vector * \return reference to NDArray vector of forward inputs @@ -738,6 +793,18 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer * \brief Operator's FCompute function (for sparse tensors) */ FComputeEx functionex_; + /*! + * \brief Operator's FStatefulCompute function + */ + FStatefulCompute stateful_function_; + /*! + * \brief Operator's FCreateOpState function + */ + FCreateOpState create_state_; + /*! + * \brief Operator state + */ + OpStatePtr state_; /*! 
* \brief Backward executors (if any) From bc0927ed1b932f113bbd6723273f10e72b2b4878 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Sun, 20 Jan 2019 13:22:07 -0800 Subject: [PATCH 3/9] add cudnn_off --- src/operator/nn/dropout-inl.h | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 4c4616ecff3b..1cb324c4957c 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -62,6 +62,7 @@ struct DropoutParam : public dmlc::Parameter { float p; int mode; TShape axes; + dmlc::optional cudnn_off; DMLC_DECLARE_PARAMETER(DropoutParam) { DMLC_DECLARE_FIELD(p).set_default(0.5) .set_range(0, 1) @@ -73,6 +74,8 @@ struct DropoutParam : public dmlc::Parameter { .describe("Whether to only turn on dropout during training or to also turn on for inference."); DMLC_DECLARE_FIELD(axes).set_default(TShape()) .describe("Axes for variational dropout kernel."); + DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional(false)) + .describe("Whether to turn off cudnn in dropout operator."); } }; // struct DropoutParam @@ -216,6 +219,7 @@ class DropoutOp { this->mode_ = static_cast(param.mode); this->axes_ = param.axes; #if MXNET_USE_CUDNN == 1 + this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value(); this->ctx_ = ctx; if (ctx.dev_type == kGPU && this->pkeep_ > 0) { init_cudnn_ = false; @@ -361,9 +365,11 @@ class DropoutOp { if (this->axes_.ndim() == 0) { // standard case for dropout #if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) - if (this->pkeep_ > 0) { // existing dropout produces inf with pkeep=0 + if (this->pkeep_ > 0 && !this->cudnn_off_) { CuDNNForward(ctx, in_data[dropout::kData], out); } else { + // existing dropout produces inf with pkeep=0, + // thus revert to existing GPU kernel for consistency. 
          LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
                                        out.dptr<DType>(),
                                        mask.dptr<DType>(),
                                        in_data[dropout::kData].dptr<DType>(),
                                        this->pkeep_);
@@ -440,7 +446,7 @@ class DropoutOp {
           // standard case for dropout
           CHECK_EQ(grad.Size(), mask.Size());
 #if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
-          if (this->pkeep_ > 0) {
+          if (this->pkeep_ > 0 && !this->cudnn_off_) {
             CuDNNBackward(ctx, grad, gdata);
           } else {
             MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
@@ -499,6 +505,7 @@ class DropoutOp {
   dropout::DropoutOpMode mode_;
   TShape axes_;
 #if MXNET_USE_CUDNN == 1
+  bool cudnn_off_;
   Context ctx_;
   cudnnDataType_t dtype_;
   cudnnDropoutDescriptor_t dropout_desc_;

From ca834a0934f91a8569e2897207cd3ed487e81c1f Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Fri, 25 Jan 2019 16:00:12 -0800
Subject: [PATCH 4/9] refactor

---
 src/operator/nn/dropout-inl.h          | 275 ++++++++++++-------------
 src/operator/nn/dropout.cc             |  20 +-
 tests/cpp/include/test_core_op.h       |  16 +-
 tests/python/unittest/test_operator.py |  28 ++-
 4 files changed, 181 insertions(+), 158 deletions(-)

diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 1cb324c4957c..a1e554f554af 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -39,12 +39,15 @@
 #include "../random/sampler.h"
 #include "../tensor/elemwise_binary_broadcast_op.h"

-#if defined(USE_MKL) && defined(_OPENMP)
+#define MXNET_USE_MKL_DROPOUT defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+
+#if MXNET_USE_MKL_DROPOUT
 #include <omp.h>
 #include <mkl_vml_functions.h>
 #include <mkl_vsl.h>
-#endif  // USE_MKL && _OPENMP
+#endif  // MXNET_USE_MKL_DROPOUT
+
+#define MXNET_USE_CUDNN_DROPOUT MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7

 namespace dropout {
 enum DropoutOpInputs {kData};
@@ -74,14 +77,14 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .describe("Whether to only turn on dropout during training or to also turn on for inference.");
     DMLC_DECLARE_FIELD(axes).set_default(TShape())
     .describe("Axes for variational dropout kernel.");
-    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(false))
+    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(true))
     .describe("Whether to turn off cudnn in dropout operator.");
   }
 };  // struct DropoutParam

 template<typename xpu, typename DType>
 class DropoutOp {
-#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+#if MXNET_USE_MKL_DROPOUT
   static void BernoulliGenerate(common::random::RandGenerator<cpu, DType> gen,
                                 int n, double p, int* r) {
     typename RandGenerator<cpu, DType>::Impl genImpl(&gen, 1);
@@ -103,70 +106,55 @@ class DropoutOp {
       }
     }
   }
-
-  // MKL forward pass
-  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<cpu> *s, RandGenerator<cpu, DType> *pgen,
-                                         const double pkeep,
-                                         const std::vector<TBlob> &in_data,
-                                         const std::vector<TBlob> &out_data) {
+  static inline bool MKLAvailable() {
     // BernoulliGenerate expects an array of int, so for types smaller than int, the mask buffer
     // will be too small, so we can't use MKL in those cases
-    if (sizeof(DType) >= sizeof(int)) {
-      Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<cpu, DType>(s);
-      DType *outptr = out.dptr_;
-      DType *dataptr = data.dptr_;
-      auto maskptr = reinterpret_cast<int *>(mask.dptr_);
-      int count = mask.shape_[0] * mask.shape_[1];
-      BernoulliGenerate(*pgen, count, pkeep, maskptr);
-      const float pk_1 = 1.0f / pkeep;
+    return sizeof(DType) >= sizeof(int);
+  }
+
+  // MKL forward pass
+  inline void MKLForward(const OpContext &ctx,
+                         const std::vector<TBlob> &in_data,
+                         const std::vector<TBlob> &out_data) {
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    RandGenerator<cpu, DType> *pgen = ctx.requested[0].get_parallel_random<cpu, DType>();
+    CHECK_NOTNULL(pgen);
+    Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<cpu, DType>(s);
+    DType *outptr = out.dptr_;
+    DType *dataptr = data.dptr_;
+    auto maskptr = reinterpret_cast<int *>(mask.dptr_);
+    int count = mask.shape_[0] * mask.shape_[1];
+    BernoulliGenerate(*pgen, count, this->pkeep_, maskptr);
+    const float pk_1 = 1.0f / this->pkeep_;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        outptr[i] = dataptr[i] * maskptr[i] * pk_1;
-      }
-      return true;
+    for (int i = 0; i < count; ++i) {
+      outptr[i] = dataptr[i] * maskptr[i] * pk_1;
     }
-    return false;
   }

   // MKL backward pass
-  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<cpu> *s, const double pkeep,
-                                          const std::vector<TBlob> &in_grad,
-                                          const std::vector<TBlob> &out_data,
-                                          const std::vector<TBlob> &out_grad) {
-    if (sizeof(DType) >= sizeof(int)) {
-      Tensor<cpu, 2, DType> grad = out_grad[dropout::kOut].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<cpu, DType>(s);
-      DType *ingradptr = gdata.dptr_;
-      const DType *outgradptr = grad.dptr_;
-      auto maskptr = reinterpret_cast<int *>(mask.dptr_);
-      int count = mask.shape_[0] * mask.shape_[1];
-      const float pk_1 = 1.0f / pkeep;
+  inline void MKLBackward(const OpContext &ctx,
+                          const std::vector<TBlob> &in_grad,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &out_grad) {
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    Tensor<cpu, 2, DType> grad = out_grad[dropout::kOut].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<cpu, DType>(s);
+    DType *ingradptr = gdata.dptr_;
+    const DType *outgradptr = grad.dptr_;
+    auto maskptr = reinterpret_cast<int *>(mask.dptr_);
+    int count = mask.shape_[0] * mask.shape_[1];
+    const float pk_1 = 1.0f / this->pkeep_;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
-      }
-      return true;
+    for (int i = 0; i < count; ++i) {
+      ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
     }
-    return false;
   }

-#else  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
-  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<xpu> *s, RandGenerator<xpu, DType> *pgen,
-                                         const double pkeep,
-                                         const std::vector<TBlob> &in_data,
-                                         const std::vector<TBlob> &out_data) {
-    return false;
-  }
-  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<xpu> *s, const double pkeep,
-                                          const std::vector<TBlob> &in_grad,
-                                          const std::vector<TBlob> &out_data,
-                                          const std::vector<TBlob> &out_grad) {
-    return false;
-  }
-#endif  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+#endif  // #if MXNET_USE_MKL_DROPOUT

 public:
  /*!
@@ -218,10 +206,10 @@ class DropoutOp { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); this->axes_ = param.axes; -#if MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDNN_DROPOUT this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value(); this->ctx_ = ctx; - if (ctx.dev_type == kGPU && this->pkeep_ > 0) { + if (ctx.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) { init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_)); @@ -230,12 +218,12 @@ class DropoutOp { CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_)); CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_)); } -#endif // MXNET_USE_CUDNN == 1 +#endif // MXNET_USE_CUDNN_DROPOUT } ~DropoutOp() { -#if MXNET_USE_CUDNN == 1 - if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0) { +#if MXNET_USE_CUDNN_DROPOUT + if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) { CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_)); @@ -243,15 +231,19 @@ class DropoutOp { CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); if (init_cudnn_) { Storage::Get()->Free(dropout_states_); - Storage::Get()->Free(reserve_space_); } } -#endif // MXNET_USE_CUDNN == 1 +#endif // MXNET_USE_CUDNN_DROPOUT + } + +#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__) + inline bool CuDNNAvailable() { + return this->pkeep_ > 0 && !this->cudnn_off_; } -#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) inline void CuDNNForward(const OpContext &ctx, const TBlob &in, + const TBlob &mask, const TBlob &out) { Stream *s = ctx.get_stream(); @@ -264,6 +256,7 @@ class DropoutOp { 1.0f - this->pkeep_, dropout_states_.dptr, dropout_state_byte_, seed_)); + init_cudnn_ = true; } // describe input/output tensor @@ -289,26 +282,23 @@ class DropoutOp { // perform dropout with cudnn CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, &dropout_reserve_byte_)); - if (init_cudnn_ && dropout_reserve_byte_ > reserve_space_.size) { - Storage::Get()->Free(reserve_space_); - init_cudnn_ = false; - } - if (!init_cudnn_) { - reserve_space_ = Storage::Get()->Alloc(dropout_reserve_byte_, Context::GPU(s->dev_id)); - init_cudnn_ = true; - } + // cudnn uses bits to record the positions that are dropped, so reserve bytes is always + // 1/8 of input size. 
+ CHECK_GE(mask.Size() * sizeof(DType), dropout_reserve_byte_) << + "The size of the mask space is smaller than the required cudnn reserved space."; CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_, dropout_desc_, x_desc_, in.dptr(), y_desc_, out.dptr(), - reserve_space_.dptr, + mask.dptr(), dropout_reserve_byte_)); } inline void CuDNNBackward(const OpContext &ctx, const TBlob &out_grad, + const TBlob &mask, const TBlob &in_grad) { Stream *s = ctx.get_stream(); @@ -340,10 +330,10 @@ class DropoutOp { out_grad.dptr(), dx_desc_, in_grad.dptr(), - reserve_space_.dptr, + mask.dptr(), dropout_reserve_byte_)); } -#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) +#endif // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__) void Forward(const OpContext &ctx, const std::vector &in_data, @@ -355,50 +345,48 @@ class DropoutOp { CHECK_EQ(out_data.size(), 2U); } Stream *s = ctx.get_stream(); + const TBlob &in = in_data[dropout::kData]; const TBlob &out = out_data[dropout::kOut]; + const TBlob &mask = out_data[dropout::kMask]; if (ctx.is_train || this->mode_ == dropout::kAlways) { - RandGenerator *pgen = ctx.requested[0].get_parallel_random(); - CHECK_NOTNULL(pgen); - if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) { - const TBlob &mask = out_data[dropout::kMask]; - CHECK(req[dropout::kOut] != kAddTo); - if (this->axes_.ndim() == 0) { - // standard case for dropout -#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__) - if (this->pkeep_ > 0 && !this->cudnn_off_) { - CuDNNForward(ctx, in_data[dropout::kData], out); - } else { - // existing dropout produces inf with pkeep=0, - // thus revert to existing GPU kernel for consistency. - LaunchRNG(s, pgen, out.Size(), - out.dptr(), - mask.dptr(), - in_data[dropout::kData].dptr(), - this->pkeep_); - } -#else - LaunchRNG(s, pgen, out.Size(), - out.dptr(), - mask.dptr(), - in_data[dropout::kData].dptr(), - this->pkeep_); -#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__) + if (this->axes_.ndim() == 0) { +#if MXNET_USE_MKL_DROPOUT + if (MKLAvailable()) { + MKLForward(ctx, in_data, out_data); return; } - +#endif // MXNET_USE_MKL_DROPOUT +#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__) + if (CuDNNAvailable()) { + CuDNNForward(ctx, in, mask, out); + return; + } +#endif // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__) + RandGenerator *pgen = ctx.requested[0].get_parallel_random(); + CHECK_NOTNULL(pgen); + CHECK(req[dropout::kOut] != kAddTo); + LaunchRNG(s, pgen, out.Size(), + out.dptr(), + mask.dptr(), + in.dptr(), + this->pkeep_); + return; + } else { + RandGenerator *pgen = ctx.requested[0].get_parallel_random(); + CHECK_NOTNULL(pgen); // initialize the mask LaunchRNG(s, pgen, mask.Size(), mask.dptr(), this->pkeep_); // broadcast mul TShape new_lshape, new_rshape, new_oshape; - int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_, + int ndim = BinaryBroadcastShapeCompact(in.shape_, mask.shape_, out.shape_, &new_lshape, &new_rshape, &new_oshape); if (!ndim) { MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, { mxnet_op::Kernel, xpu>::Launch( - s, out.Size(), out.dptr(), in_data[dropout::kData].dptr(), + s, out.Size(), out.dptr(), in.dptr(), mask.dptr()); }); } else { @@ -410,21 +398,16 @@ class DropoutOp { mshadow_op::mul>, xpu>:: template LaunchEx(s, new_oshape.Size(), req[dropout::kOut], lstride, rstride, oshape, - in_data[dropout::kData].dptr(), + in.dptr(), mask.dptr(), out.dptr()); }); } } } else { - const TBlob& data = in_data[dropout::kData]; - if (req[dropout::kOut] == kWriteTo) { - mxnet_op::copy(s, out, 
data);
-      } else {
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
-            s, out.Size(), out.dptr<DType>(), data.dptr<DType>());
-        });
-      }
+      MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+          s, out.Size(), out.dptr<DType>(), in.dptr<DType>());
+      });
     }
   }

@@ -438,30 +421,30 @@ class DropoutOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
     if (ctx.is_train || mode_ == dropout::kAlways) {
-      if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
-        const TBlob &gdata = in_grad[dropout::kData];
-        const TBlob &grad = out_grad[dropout::kOut];
-        const TBlob &mask = out_data[dropout::kMask];
-        if (this->axes_.ndim() == 0) {
-          // standard case for dropout
-          CHECK_EQ(grad.Size(), mask.Size());
-#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
-          if (this->pkeep_ > 0 && !this->cudnn_off_) {
-            CuDNNBackward(ctx, grad, gdata);
-          } else {
-            MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-                s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-            });
-          }
-#else
+      const TBlob &gdata = in_grad[dropout::kData];
+      const TBlob &grad = out_grad[dropout::kOut];
+      const TBlob &mask = out_data[dropout::kMask];
+      if (this->axes_.ndim() == 0) {
+#if MXNET_USE_MKL_DROPOUT
+        if (MKLAvailable()) {
+          MKLBackward(ctx, in_grad, out_data, out_grad);
+          return;
+        }
+#endif  // MXNET_USE_MKL_DROPOUT
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        if (CuDNNAvailable()) {
+          CuDNNBackward(ctx, grad, mask, gdata);
           return;
         }
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        // standard case for dropout
+        CHECK_EQ(grad.Size(), mask.Size());
+        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+        });
+        return;
+      } else {
         // broadcast mul
         TShape new_lshape, new_rshape, new_oshape;
         int ndim = BinaryBroadcastShapeCompact(grad.shape_,
@@ -487,14 +470,10 @@ class DropoutOp {
     } else {
       const TBlob& gdata = in_grad[dropout::kData];
       const TBlob& grad = out_grad[dropout::kOut];
-      if (req[dropout::kData] == kWriteTo) {
-        mxnet_op::copy(s, gdata, grad);
-      } else {
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
-            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>());
-        });
-      }
+      MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+          s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>());
+      });
     }
   }

@@ -504,7 +483,7 @@ class DropoutOp {
  /*!
\brief Dropout mode */ dropout::DropoutOpMode mode_; TShape axes_; -#if MXNET_USE_CUDNN == 1 +#if MXNET_USE_CUDNN_DROPOUT bool cudnn_off_; Context ctx_; cudnnDataType_t dtype_; @@ -512,9 +491,9 @@ class DropoutOp { bool init_cudnn_; uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) size_t dropout_state_byte_, dropout_reserve_byte_; - Storage::Handle dropout_states_, reserve_space_; + Storage::Handle dropout_states_; cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_; -#endif // MXNET_USE_CUDNN == 1 +#endif // MXNET_USE_CUDNN_DROPOUT }; // class DropoutOp static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs, diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 27cecee906a0..1862a2e96a2f 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -125,9 +125,23 @@ Example:: .set_attr("FInplaceOption", [](const NodeAttrs& attrs){ return std::vector >{{0, 0}}; }) -.set_attr("FResourceRequest", [](const NodeAttrs& n) { - return std::vector{ResourceRequest::kParallelRandom}; -}) +.set_attr("FResourceRequestEx", + [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) { + std::vector request; + if (dev_mask == kGPU) { +#if MXNET_USE_CUDNN_DROPOUT + const DropoutParam& param = nnvm::get(attrs.parsed); + // if cudnn is used, parallel random is not needed. + if (1.0f - param.p > 0 + && !(param.cudnn_off && param.cudnn_off.value()) + && param.axes.ndim() == 0) { + return request; + } +#endif + } + request.emplace_back(ResourceRequest::kParallelRandom); + return request; + }) .add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.") .add_arguments(DropoutParam::__FIELDS__()); diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h index 7fde28947af8..224bcff4b190 100644 --- a/tests/cpp/include/test_core_op.h +++ b/tests/cpp/include/test_core_op.h @@ -168,10 +168,22 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer * \param op Pointer to nnvm Operator object */ void AttachResources(OpContext *ctx, const nnvm::NodeAttrs& attrs, const nnvm::Op *op) { + std::vector reqs; + std::vector& requested = ctx->requested; static auto& fresource = nnvm::Op::GetAttr("FResourceRequest"); if (fresource.count(op) != 0) { - std::vector& requested = ctx->requested; - auto reqs = fresource[op](attrs); + reqs = fresource[op](attrs); + } else { + static auto& fresourceex = nnvm::Op::GetAttr("FResourceRequestEx"); + if (fresourceex.count(op) != 0) { + if (this->function_ || this->stateful_function_) { + reqs = fresourceex[op](attrs, ctx->run_ctx.ctx.dev_mask(), DispatchMode::kFCompute); + } else { + reqs = fresourceex[op](attrs, ctx->run_ctx.ctx.dev_mask(), DispatchMode::kFComputeEx); + } + } + } + if (!reqs.empty()) { // Get the resource of temporal space. 
for (const ResourceRequest& req : reqs) { if (req.type == ResourceRequest::kTempSpace) { diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 670cc7eb15e0..4cddcfd2741c 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5909,10 +5909,10 @@ def check_correctness(executor, input, ratio): elif ratio == 0: assert output_zeroes == 0 - def check_dropout_ratio(ratio, shape): + def check_dropout_ratio(ratio, shape, cudnn_off=True): # test dropout x = mx.sym.var('data') - y = mx.sym.Dropout(x, p=ratio) + y = mx.sym.Dropout(x, p=ratio, cudnn_off=cudnn_off) exe = y.simple_bind(ctx=default_context(), data=shape) if ratio == 1: @@ -5950,7 +5950,7 @@ def check_dropout_ratio(ratio, shape): # test permanent dropout x = mx.sym.var('data') - y = mx.sym.Dropout(x, p=ratio, mode='always') + y = mx.sym.Dropout(x, p=ratio, mode='always', cudnn_off=cudnn_off) exe = y.simple_bind(ctx=default_context(), data=shape) exe.arg_arrays[0][:] = 1 @@ -5975,13 +5975,13 @@ def get_slice(x, axis, idx): ix += (slice(None, None, None),) return x[ix] - def check_dropout_axes(ratio, shape, axes): + def check_dropout_axes(ratio, shape, axes, cudnn_off=True): compactshape = list(shape) for axis in axes: compactshape[axis] = 1 compactx = mx.random.uniform(shape=tuple(compactshape)) broadcastx = compactx.broadcast_to(shape) - dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes) + dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes, cudnn_off=cudnn_off) for axis in axes: target = get_slice(dropouty, axis, 0).asnumpy() for i in range(1, shape[axis]): @@ -5993,6 +5993,11 @@ def check_dropout_axes(ratio, shape, axes): check_dropout_ratio(1.0, shape) check_dropout_ratio(0.75, shape) check_dropout_ratio(0.25, shape) + check_dropout_ratio(0.5, shape, cudnn_off=False) + check_dropout_ratio(0.0, shape, cudnn_off=False) + check_dropout_ratio(1.0, shape, cudnn_off=False) + check_dropout_ratio(0.75, shape, cudnn_off=False) + check_dropout_ratio(0.25, shape, cudnn_off=False) nshape = (10, 10, 10, 10) with mx.autograd.train_mode(): @@ -6009,6 +6014,19 @@ def check_dropout_axes(ratio, shape, axes): check_dropout_axes(0.25, nshape, axes = (0, 1, 2)) check_dropout_axes(0.25, nshape, axes = (0, 2, 3)) check_dropout_axes(0.25, nshape, axes = (1, 2, 3)) + check_dropout_axes(0.25, nshape, axes = (0,), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (1,), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (2,), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (3,), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (0, 1), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (0, 2), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (0, 3), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (1, 2), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (1, 3), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (2, 3), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (0, 1, 2), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (0, 2, 3), cudnn_off=False) + check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False) @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. 
tracked at https://github.com/apache/incubator-mxnet/issues/11290") From cb2a2b45aa9a4a83a91b5330292aea5247f7421c Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 28 Jan 2019 12:54:12 -0800 Subject: [PATCH 5/9] fix bug when using inf forward --- src/operator/nn/dropout-inl.h | 7 ++++++- tests/python/unittest/test_operator.py | 17 +++++++++++++++++ 2 files changed, 23 insertions(+), 1 deletion(-) diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index a1e554f554af..22e135846235 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -206,6 +206,7 @@ class DropoutOp { this->pkeep_ = 1.0f - param.p; this->mode_ = static_cast(param.mode); this->axes_ = param.axes; + this->dropout_passthrough_ = true; #if MXNET_USE_CUDNN_DROPOUT this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value(); this->ctx_ = ctx; @@ -339,6 +340,7 @@ class DropoutOp { const std::vector &in_data, const std::vector &req, const std::vector &out_data) { + this->dropout_passthrough_ = true; if (req[dropout::kOut] != kNullOp) { CHECK_EQ(in_data.size(), 1U); if (ctx.is_train) { @@ -349,6 +351,7 @@ class DropoutOp { const TBlob &out = out_data[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; if (ctx.is_train || this->mode_ == dropout::kAlways) { + this->dropout_passthrough_ = false; if (this->axes_.ndim() == 0) { #if MXNET_USE_MKL_DROPOUT if (MKLAvailable()) { @@ -420,7 +423,8 @@ class DropoutOp { using namespace mshadow; using namespace mshadow::expr; Stream *s = ctx.get_stream(); - if (ctx.is_train || mode_ == dropout::kAlways) { + if (!this->dropout_passthrough_) { + this->dropout_passthrough_ = true; const TBlob &gdata = in_grad[dropout::kData]; const TBlob &grad = out_grad[dropout::kOut]; const TBlob &mask = out_data[dropout::kMask]; @@ -483,6 +487,7 @@ class DropoutOp { /*! 
\brief Dropout mode */ dropout::DropoutOpMode mode_; TShape axes_; + bool dropout_passthrough_; #if MXNET_USE_CUDNN_DROPOUT bool cudnn_off_; Context ctx_; diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index 4cddcfd2741c..d75f3e938454 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5987,6 +5987,15 @@ def check_dropout_axes(ratio, shape, axes, cudnn_off=True): for i in range(1, shape[axis]): assert(get_slice(dropouty, axis, i).asnumpy() == target).all() + def check_passthrough(ratio, shape, cudnn_off=True): + # test inference_mode forward and then backward + a = mx.random.uniform(shape=shape) + a.attach_grad() + with mx.autograd.record(train_mode=False): + b = mx.nd.Dropout(a, ratio, cudnn_off=cudnn_off) # dropout acts as identity + b.backward() + assert_almost_equal(a.grad.asnumpy(), mx.nd.ones_like(b).asnumpy()) + shape = (100, 100) check_dropout_ratio(0.5, shape) check_dropout_ratio(0.0, shape) @@ -5999,6 +6008,13 @@ def check_dropout_axes(ratio, shape, axes, cudnn_off=True): check_dropout_ratio(0.75, shape, cudnn_off=False) check_dropout_ratio(0.25, shape, cudnn_off=False) + check_passthrough(0.5, shape) + check_passthrough(0.0, shape) + check_passthrough(1.0, shape) + check_passthrough(0.5, shape, cudnn_off=False) + check_passthrough(0.0, shape, cudnn_off=False) + check_passthrough(1.0, shape, cudnn_off=False) + nshape = (10, 10, 10, 10) with mx.autograd.train_mode(): check_dropout_axes(0.25, nshape, axes = (0,)) @@ -6029,6 +6045,7 @@ def check_dropout_axes(ratio, shape, axes, cudnn_off=True): check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False) + @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11290") @with_seed() def test_scatter_gather_nd(): From 7f765821fc3a8960c36789bb932ab5f4fc54150e Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Mon, 28 Jan 2019 15:14:53 -0800 Subject: [PATCH 6/9] turn on cudnn in gluon --- python/mxnet/gluon/nn/basic_layers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index 4d514c28317a..ace814275d61 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -262,7 +262,7 @@ def __init__(self, rate, axes=(), **kwargs): self._axes = axes def hybrid_forward(self, F, x): - return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd') + return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False) def __repr__(self): s = '{name}(p = {_rate}, axes={_axes})' From 1ad98ff8c950e188d91d7ce5b387457213c9d294 Mon Sep 17 00:00:00 2001 From: Sheng Zha Date: Tue, 29 Jan 2019 22:06:48 -0800 Subject: [PATCH 7/9] reuse dropout state space --- include/mxnet/resource.h | 20 ++++++++ src/executor/attach_op_resource_pass.cc | 39 ++++++++++----- src/imperative/imperative_utils.h | 6 +++ src/operator/nn/dropout-inl.h | 22 ++------- src/operator/nn/dropout.cc | 1 + src/resource.cc | 65 ++++++++++++++++++++++--- tests/cpp/include/test_core_op.h | 38 ++++++++++----- tests/cpp/include/test_legacy_op.h | 47 +++++++++++------- 8 files changed, 171 insertions(+), 67 deletions(-) diff --git a/include/mxnet/resource.h b/include/mxnet/resource.h index 67c14b66abdd..34c8f88d1ca9 100644 --- a/include/mxnet/resource.h +++ b/include/mxnet/resource.h @@ -44,6 +44,11 @@ struct ResourceRequest { kTempSpace, /*! 
\brief common::RandGenerator object, which can be used in GPU kernel functions */
   kParallelRandom
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+  ,
+  /*! \brief cudnnDropoutDescriptor_t object for GPU dropout kernel functions */
+  kCuDNNDropoutDesc
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
  };
  /*! \brief type of resources */
  Type type;
@@ -157,6 +162,21 @@ struct Resource {
         reinterpret_cast<DType*>(get_space_internal(shape.Size() * sizeof(DType))),
         shape, shape[ndim - 1], stream);
   }
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+  /*!
+   * \brief Get a cudnn dropout descriptor backed by the shared dropout state space.
+   *
+   * \param dropout_desc pointer to a previously created cudnn dropout descriptor.
+   * \param stream the stream whose cudnn handle is used to initialize the descriptor.
+   * \param dropout the dropout ratio to set on the descriptor.
+   * \param seed the seed for the dropout random number generator states.
+   */
+  void get_cudnn_dropout_desc(
+      cudnnDropoutDescriptor_t* dropout_desc,
+      mshadow::Stream<gpu> *stream,
+      const float dropout,
+      uint64_t seed) const;
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+
   /*!
    * \brief Get CPU space as mshadow Tensor in specified type.
    *  The caller can request arbitrary size.
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 56122cda6ff0..76ac4ec56fd6 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -64,20 +64,33 @@ void AttachOpResources(
         : fresource[op](inode.source->attrs);
     // Get the resource of temporary space.
     for (const ResourceRequest& req : reqs) {
-      if (req.type == ResourceRequest::kTempSpace) {
-        if (cached_temp.count(ctx) != 0) {
-          requested.push_back(cached_temp.at(ctx));
-        } else {
-          Resource r = ResourceManager::Get()->Request(ctx, req);
-          requested.push_back(r);
-          cached_temp[ctx] = r;
+      switch (req.type) {
+        case ResourceRequest::kTempSpace: {
+          if (cached_temp.count(ctx) != 0) {
+            requested.push_back(cached_temp.at(ctx));
+          } else {
+            Resource r = ResourceManager::Get()->Request(ctx, req);
+            requested.push_back(r);
+            cached_temp[ctx] = r;
+          }
+          break;
         }
-      } else if (req.type == ResourceRequest::kRandom) {
-        requested.push_back(ResourceManager::Get()->Request(ctx, req));
-      } else if (req.type == ResourceRequest::kParallelRandom) {
-        requested.push_back(ResourceManager::Get()->Request(ctx, req));
-      } else {
-        LOG(FATAL) << "resource type not yet supported";
+        case ResourceRequest::kRandom: {
+          requested.push_back(ResourceManager::Get()->Request(ctx, req));
+          break;
+        }
+        case ResourceRequest::kParallelRandom: {
+          requested.push_back(ResourceManager::Get()->Request(ctx, req));
+          break;
+        }
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+        case ResourceRequest::kCuDNNDropoutDesc: {
+          requested.push_back(ResourceManager::Get()->Request(ctx, req));
+          break;
+        }
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+        default:
+          LOG(FATAL) << "resource type " << req.type << " is not yet supported";
       }
     }
     CHECK(vdispatch[nid] != DispatchMode::kUndefined);
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index 4b0d13167356..6d4956228970 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -241,6 +241,12 @@ inline void SetDependency(const nnvm::NodeAttrs& attrs,
         requested.push_back(ResourceManager::Get()->Request(ctx, req));
         write_vars.push_back(requested.back().var);
         break;
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+      case ResourceRequest::kCuDNNDropoutDesc:
+        requested.push_back(ResourceManager::Get()->Request(ctx, req));
+        write_vars.push_back(requested.back().var);
+        break;
+#endif  // MXNET_USE_CUDNN == 1 &&
CUDNN_MAJOR >= 7 default: LOG(FATAL) << "resource type not yet supported"; } diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h index 22e135846235..e8a808350eb5 100644 --- a/src/operator/nn/dropout-inl.h +++ b/src/operator/nn/dropout-inl.h @@ -78,7 +78,8 @@ struct DropoutParam : public dmlc::Parameter { DMLC_DECLARE_FIELD(axes).set_default(TShape()) .describe("Axes for variational dropout kernel."); DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional(true)) - .describe("Whether to turn off cudnn in dropout operator."); + .describe("Whether to turn off cudnn in dropout operator. " + "This option is ignored if axes is specified."); } }; // struct DropoutParam @@ -211,7 +212,6 @@ class DropoutOp { this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value(); this->ctx_ = ctx; if (ctx.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) { - init_cudnn_ = false; dtype_ = mshadow::DataType::kCudnnFlag; CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_)); CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_)); @@ -230,9 +230,6 @@ class DropoutOp { CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_)); CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_)); CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_)); - if (init_cudnn_) { - Storage::Get()->Free(dropout_states_); - } } #endif // MXNET_USE_CUDNN_DROPOUT } @@ -249,16 +246,7 @@ class DropoutOp { Stream *s = ctx.get_stream(); // set dropout state. - // TODO(szha): expensive call, should be cached and reused across operators. - if (!init_cudnn_) { - CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_state_byte_)); - dropout_states_ = Storage::Get()->Alloc(dropout_state_byte_, Context::GPU(s->dev_id)); - CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_, - 1.0f - this->pkeep_, - dropout_states_.dptr, dropout_state_byte_, - seed_)); - init_cudnn_ = true; - } + ctx.requested[0].get_cudnn_dropout_desc(&dropout_desc_, s, 1.0f - this->pkeep_, seed_); // describe input/output tensor int dim[4], stride[4]; @@ -493,10 +481,8 @@ class DropoutOp { Context ctx_; cudnnDataType_t dtype_; cudnnDropoutDescriptor_t dropout_desc_; - bool init_cudnn_; uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn) - size_t dropout_state_byte_, dropout_reserve_byte_; - Storage::Handle dropout_states_; + size_t dropout_reserve_byte_; cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_; #endif // MXNET_USE_CUDNN_DROPOUT }; // class DropoutOp diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc index 1862a2e96a2f..3205fe9fb320 100644 --- a/src/operator/nn/dropout.cc +++ b/src/operator/nn/dropout.cc @@ -135,6 +135,7 @@ Example:: if (1.0f - param.p > 0 && !(param.cudnn_off && param.cudnn_off.value()) && param.axes.ndim() == 0) { + request.emplace_back(ResourceRequest::kCuDNNDropoutDesc); return request; } #endif diff --git a/src/resource.cc b/src/resource.cc index ba4ab7270bdb..80a5c0e444e1 100644 --- a/src/resource.cc +++ b/src/resource.cc @@ -34,6 +34,7 @@ #include #include "./common/lazy_alloc_array.h" #include "./common/utils.h" +#include "./common/cuda_utils.h" namespace mxnet { namespace resource { @@ -92,11 +93,14 @@ class ResourceManagerImpl : public ResourceManager { gpu_temp_space_copy_ = dmlc::GetEnv("MXNET_GPU_TEMP_COPY", 1); cpu_native_rand_copy_ = dmlc::GetEnv("MXNET_CPU_PARALLEL_RAND_COPY", 1); gpu_native_rand_copy_ = dmlc::GetEnv("MXNET_GPU_PARALLEL_RAND_COPY", 4); +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + gpu_cudnn_dropout_state_copy_ = 
dmlc::GetEnv("MXNET_GPU_CUDNN_DROPOUT_STATE_COPY", 4); +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 engine_ref_ = Engine::_GetSharedRef(); storage_ref_ = Storage::_GetSharedRef(); cpu_rand_.reset(new ResourceRandom( Context::CPU(), global_seed_)); - cpu_space_.reset(new ResourceTempSpace( + cpu_space_.reset(new ResourceTempSpace( Context::CPU(), cpu_temp_space_copy_)); cpu_parallel_rand_.reset(new ResourceParallelRandom( Context::CPU(), cpu_native_rand_copy_, global_seed_)); @@ -110,6 +114,9 @@ class ResourceManagerImpl : public ResourceManager { gpu_rand_.Clear(); gpu_space_.Clear(); gpu_parallel_rand_.Clear(); +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + gpu_cudnn_dropout_state_.Clear(); +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 #endif if (engine_ref_ != nullptr) { engine_ref_ = nullptr; @@ -139,7 +146,7 @@ class ResourceManagerImpl : public ResourceManager { } case ResourceRequest::kTempSpace: { return gpu_space_.Get(ctx.dev_id, [ctx, this]() { - return new ResourceTempSpace(ctx, gpu_temp_space_copy_); + return new ResourceTempSpace(ctx, gpu_temp_space_copy_); })->GetNext(); } case ResourceRequest::kParallelRandom: { @@ -147,6 +154,14 @@ class ResourceManagerImpl : public ResourceManager { return new ResourceParallelRandom(ctx, gpu_native_rand_copy_, global_seed_); })->GetNext(); } +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + case ResourceRequest::kCuDNNDropoutDesc: { + return gpu_cudnn_dropout_state_.Get(ctx.dev_id, [ctx, this]() { + return new ResourceTempSpace( + ctx, gpu_cudnn_dropout_state_copy_); + })->GetNext(); + } +#endif // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 default: LOG(FATAL) << "Unknown supported type " << req.type; } #else @@ -231,7 +246,8 @@ class ResourceManagerImpl : public ResourceManager { } }; - // temporal space resource. + // temporary space resource. + template struct ResourceTempSpace { /*! \brief the context of the device */ Context ctx; @@ -248,7 +264,7 @@ class ResourceManagerImpl : public ResourceManager { resource[i].var = Engine::Get()->NewVariable(); resource[i].id = static_cast(i); resource[i].ptr_ = &space[i]; - resource[i].req = ResourceRequest(ResourceRequest::kTempSpace); + resource[i].req = ResourceRequest(req); space[i].ctx = ctx; CHECK_EQ(space[i].handle.size, 0U); } @@ -372,16 +388,23 @@ class ResourceManagerImpl : public ResourceManager { /*! \brief CPU random number resources */ std::unique_ptr > cpu_rand_; /*! \brief CPU temp space resources */ - std::unique_ptr cpu_space_; + std::unique_ptr> cpu_space_; /*! \brief CPU parallel random number resources */ std::unique_ptr > cpu_parallel_rand_; #if MXNET_USE_CUDA /*! \brief random number generator for GPU */ common::LazyAllocArray > gpu_rand_; /*! \brief temp space for GPU */ - common::LazyAllocArray gpu_space_; + common::LazyAllocArray> gpu_space_; /*! \brief GPU parallel (on device) random number resources */ common::LazyAllocArray > gpu_parallel_rand_; +#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7 + /*! \brief number of copies in GPU cudnn dropout descriptor resources */ + int gpu_cudnn_dropout_state_copy_; + /*! 
\brief GPU cudnn dropout state resources */
+  common::LazyAllocArray<ResourceTempSpace<ResourceRequest::kCuDNNDropoutDesc>>
+      gpu_cudnn_dropout_state_;
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
 #endif
 };
 }  // namespace resource

@@ -394,6 +417,36 @@ void* Resource::get_host_space_internal(size_t size) const {
   return static_cast(ptr_)->GetHostSpace(size);
 }

+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+void Resource::get_cudnn_dropout_desc(
+    cudnnDropoutDescriptor_t* dropout_desc,
+    mshadow::Stream<gpu> *stream,
+    const float dropout,
+    uint64_t seed) const {
+
+  CHECK_EQ(req.type, ResourceRequest::kCuDNNDropoutDesc);
+  auto state_space = static_cast<resource::ResourceManagerImpl::ResourceTempSpace<
+      ResourceRequest::kCuDNNDropoutDesc>*>(ptr_);
+  CHECK_EQ(state_space->ctx.dev_id, stream->dev_id)
+    << "The device id of the cudnn dropout state space doesn't match that of the stream.";
+  if (!state_space->handle.size) {
+    // not initialized yet.
+    size_t dropout_state_size;
+    CUDNN_CALL(cudnnDropoutGetStatesSize(stream->dnn_handle_, &dropout_state_size));
+    CUDNN_CALL(cudnnSetDropoutDescriptor(*dropout_desc, stream->dnn_handle_,
+                                         dropout,
+                                         state_space->GetSpace(dropout_state_size),
+                                         dropout_state_size,
+                                         seed));
+  } else {
+    CUDNN_CALL(cudnnRestoreDropoutDescriptor(*dropout_desc, stream->dnn_handle_,
+                                             dropout,
+                                             state_space->handle.dptr,
+                                             state_space->handle.size,
+                                             seed));
+  }
+}
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+
 ResourceManager* ResourceManager::Get() {
   typedef dmlc::ThreadLocalStore<resource::ResourceManagerImpl> inst;
   return inst::Get();
 }
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 224bcff4b190..f59a9e8d74dc 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -186,20 +186,32 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer
     if (!reqs.empty()) {
       // Get the resource of temporary space.
       for (const ResourceRequest& req : reqs) {
-        if (req.type == ResourceRequest::kTempSpace) {
-          Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          requested.emplace_back(r);
-        } else if (req.type == ResourceRequest::kRandom) {
-          requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
-        } else if (req.type == ResourceRequest::kParallelRandom) {
-          Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
-            common::random::RandGenerator<cpu, DType>::AllocState(
-                rm.get_parallel_random<cpu, DType>());
+        switch (req.type) {
+          case ResourceRequest::kTempSpace: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
           }
-          requested.emplace_back(rm);
-        } else {
-          LOG(FATAL) << "resource type not yet supported";
+          case ResourceRequest::kRandom: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
+          }
+          case ResourceRequest::kParallelRandom: {
+            Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
+            if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
+              common::random::RandGenerator<cpu, DType>::AllocState(
+                  rm.get_parallel_random<cpu, DType>());
+            }
+            requested.emplace_back(rm);
+            break;
+          }
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          case ResourceRequest::kCuDNNDropoutDesc: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
+          }
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          default:
+            LOG(FATAL) << "resource type " << req.type << " is not yet supported";
         }
       }
     }
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 3e395a4cd0ec..7fd407e39807 100644
--- a/tests/cpp/include/test_legacy_op.h
+++
diff --git a/tests/cpp/include/test_core_op.h b/tests/cpp/include/test_core_op.h
index 224bcff4b190..f59a9e8d74dc 100644
--- a/tests/cpp/include/test_core_op.h
+++ b/tests/cpp/include/test_core_op.h
@@ -186,20 +186,32 @@ class CoreOpExecutor : public test::op::OperatorDataInitializer<DType>
     if (!reqs.empty()) {
       // Get the resource of temporal space.
       for (const ResourceRequest& req : reqs) {
-        if (req.type == ResourceRequest::kTempSpace) {
-          Resource r = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          requested.emplace_back(r);
-        } else if (req.type == ResourceRequest::kRandom) {
-          requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
-        } else if (req.type == ResourceRequest::kParallelRandom) {
-          Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
-          if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
-            common::random::RandGenerator<cpu, DType>::AllocState(
-                rm.get_parallel_random<cpu, DType>());
+        switch (req.type) {
+          case ResourceRequest::kTempSpace: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
           }
-          requested.emplace_back(rm);
-        } else {
-          LOG(FATAL) << "resource type not yet supported";
+          case ResourceRequest::kRandom: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
+          }
+          case ResourceRequest::kParallelRandom: {
+            Resource rm = ResourceManager::Get()->Request(ctx->run_ctx.ctx, req);
+            if (ctx->run_ctx.ctx.dev_mask() == Context::kCPU) {
+              common::random::RandGenerator<cpu, DType>::AllocState(
+                  rm.get_parallel_random<cpu, DType>());
+            }
+            requested.emplace_back(rm);
+            break;
+          }
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          case ResourceRequest::kCuDNNDropoutDesc: {
+            requested.emplace_back(ResourceManager::Get()->Request(ctx->run_ctx.ctx, req));
+            break;
+          }
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          default:
+            LOG(FATAL) << "resource type " << req.type << " is not yet supported";
         }
       }
     }
diff --git a/tests/cpp/include/test_legacy_op.h b/tests/cpp/include/test_legacy_op.h
index 3e395a4cd0ec..7fd407e39807 100644
--- a/tests/cpp/include/test_legacy_op.h
+++ b/tests/cpp/include/test_legacy_op.h
@@ -494,25 +494,38 @@ class LegacyOperatorExecutor : public OperatorDataInitializer<DType>
       ctx.dev_id = 0;

       for (const ResourceRequest& req : reqs) {
-        if (req.type == ResourceRequest::kTempSpace) {
-          if (cached_temp.count(ctx) != 0) {
-            opContext_.requested.emplace_back(cached_temp.at(ctx));
-          } else {
-            Resource r = ResourceManager::Get()->Request(ctx, req);
-            opContext_.requested.emplace_back(r);
-            cached_temp[ctx] = r;
+        switch (req.type) {
+          case ResourceRequest::kTempSpace: {
+            if (cached_temp.count(ctx) != 0) {
+              opContext_.requested.emplace_back(cached_temp.at(ctx));
+            } else {
+              Resource r = ResourceManager::Get()->Request(ctx, req);
+              opContext_.requested.emplace_back(r);
+              cached_temp[ctx] = r;
+            }
+            break;
+          }
+          case ResourceRequest::kRandom: {
+            opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
+            break;
+          }
+          case ResourceRequest::kParallelRandom: {
+            Resource rm = ResourceManager::Get()->Request(ctx, req);
+            if (ctx.dev_mask() == Context::kCPU) {
+              common::random::RandGenerator<cpu, DType>::AllocState(
+                  rm.get_parallel_random<cpu, DType>());
+            }
+            opContext_.requested.emplace_back(rm);
+            break;
           }
-        } else if (req.type == ResourceRequest::kRandom) {
-          opContext_.requested.emplace_back(ResourceManager::Get()->Request(ctx, req));
-        } else if (req.type == ResourceRequest::kParallelRandom) {
-          Resource rm = ResourceManager::Get()->Request(ctx, req);
-          if (ctx.dev_mask() == Context::kCPU) {
-            common::random::RandGenerator<cpu, DType>::AllocState(
-                rm.get_parallel_random<cpu, DType>());
+#if MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          case ResourceRequest::kCuDNNDropoutDesc: {
+            opContext_.requested.push_back(ResourceManager::Get()->Request(ctx, req));
+            break;
           }
-          opContext_.requested.emplace_back(rm);
-        } else {
-          LOG(FATAL) << "resource type not yet supported";
+#endif  // MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
+          default:
+            LOG(FATAL) << "resource type " << req.type << " is not yet supported";
         }
       }
     }

From 2c1ea12472621d4c7dd2047affe563878d73f358 Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Wed, 30 Jan 2019 21:53:53 -0800
Subject: [PATCH 8/9] dropout passthrough

---
 python/mxnet/gluon/nn/basic_layers.py | 5 ++++-
 src/operator/nn/dropout-inl.h         | 2 +-
 src/operator/nn/dropout.cc            | 3 ++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py
index ace814275d61..f8566dd05aa5 100644
--- a/python/mxnet/gluon/nn/basic_layers.py
+++ b/python/mxnet/gluon/nn/basic_layers.py
@@ -262,7 +262,10 @@ def __init__(self, rate, axes=(), **kwargs):
         self._axes = axes

     def hybrid_forward(self, F, x):
-        return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+        if self._rate > 0:
+            return F.Dropout(x, p=self._rate, axes=self._axes, name='fwd', cudnn_off=False)
+        else:
+            return F.identity(x)

     def __repr__(self):
         s = '{name}(p = {_rate}, axes={_axes})'
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index e8a808350eb5..55fb03283e55 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -338,7 +338,7 @@ class DropoutOp {
     const TBlob &in = in_data[dropout::kData];
     const TBlob &out = out_data[dropout::kOut];
     const TBlob &mask = out_data[dropout::kMask];
-    if (ctx.is_train || this->mode_ == dropout::kAlways) {
+    if (this->pkeep_ < 1 && (ctx.is_train || this->mode_ == dropout::kAlways)) {
       this->dropout_passthrough_ = false;
       if (this->axes_.ndim() == 0) {
 #if MXNET_USE_MKL_DROPOUT
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 3205fe9fb320..d6cbeb4e561d 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -128,9 +128,10 @@ Example::
 .set_attr<FResourceRequestEx>("FResourceRequestEx",
   [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
     std::vector<ResourceRequest> request;
+    const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+    if (param.p == 0) return request;
     if (dev_mask == kGPU) {
 #if MXNET_USE_CUDNN_DROPOUT
-      const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
       // if cudnn is used, parallel random is not needed.
       if (1.0f - param.p > 0
           && !(param.cudnn_off && param.cudnn_off.value())
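
For an operator author, the new request type composes with the existing FResourceRequestEx mechanism shown in the dropout.cc hunk above. A hedged sketch of what a hypothetical GPU operator could do with it (`MyOp` and the surrounding registration are illustrative names, not part of this patch):

    // Hypothetical operator registration; mirrors the FResourceRequestEx
    // hunk above. MyOp is an illustrative name, not code from this patch.
    #include <mxnet/op_attr_types.h>
    #include <nnvm/op.h>

    NNVM_REGISTER_OP(MyOp)
    .set_attr<FResourceRequestEx>("FResourceRequestEx",
      [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
        std::vector<ResourceRequest> request;
        if (dev_mask == kGPU)
          request.emplace_back(ResourceRequest::kCuDNNDropoutDesc);
        return request;
      });

    // In the GPU compute function, the requested resource fills a descriptor
    // the operator owns, while the state buffer stays with the resource:
    //   cudnnDropoutDescriptor_t desc;
    //   CUDNN_CALL(cudnnCreateDropoutDescriptor(&desc));
    //   ctx.requested[0].get_cudnn_dropout_desc(&desc, s, dropout_rate, seed);
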
From ba725eb8ba31dad84caeb536afb6d41d2468d62f Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Mon, 4 Feb 2019 20:07:38 -0800
Subject: [PATCH 9/9] address comments

---
 docs/faq/env_var.md                     | 20 ++++++++++++++++++++
 src/executor/attach_op_resource_pass.cc |  1 +
 src/operator/nn/dropout-inl.h           |  4 ++++
 3 files changed, 25 insertions(+)

diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md
index 99ebae21d61f..83368bf4d0c3 100644
--- a/docs/faq/env_var.md
+++ b/docs/faq/env_var.md
@@ -227,6 +227,26 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca
   - Maximum value is 60.
   - This variable controls how many weights will be updated in a single call to optimizer (for optimizers that support aggregation, currently limited to SGD).

+* MXNET_CPU_TEMP_COPY
+  - Values: Int ```(default=4)```
+  - This variable controls how many temporary memory resources are created for the CPU context for use in operators.
+
+* MXNET_GPU_TEMP_COPY
+  - Values: Int ```(default=1)```
+  - This variable controls how many temporary memory resources are created for each GPU context for use in operators.
+
+* MXNET_CPU_PARALLEL_RAND_COPY
+  - Values: Int ```(default=1)```
+  - This variable controls how many parallel random number generator resources are created for the CPU context for use in operators.
+
+* MXNET_GPU_PARALLEL_RAND_COPY
+  - Values: Int ```(default=4)```
+  - This variable controls how many parallel random number generator resources are created for each GPU context for use in operators.
+
+* MXNET_GPU_CUDNN_DROPOUT_STATE_COPY
+  - Values: Int ```(default=4)```
+  - This variable controls how many cuDNN dropout state resources are created for each GPU context for use in operators.
+
 Settings for Minimum Memory Usage
 ---------------------------------
  - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1```
diff --git a/src/executor/attach_op_resource_pass.cc b/src/executor/attach_op_resource_pass.cc
index 76ac4ec56fd6..939adbbcedd3 100644
--- a/src/executor/attach_op_resource_pass.cc
+++ b/src/executor/attach_op_resource_pass.cc
@@ -66,6 +66,7 @@ void AttachOpResources(
     for (const ResourceRequest& req : reqs) {
       switch (req.type) {
         case ResourceRequest::kTempSpace: {
+          // the scope is needed because a new variable is declared inside the case.
           if (cached_temp.count(ctx) != 0) {
             requested.push_back(cached_temp.at(ctx));
           } else {
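
The comment added in attach_op_resource_pass.cc points at a general C++ rule: control may not jump past the initialization of a local with automatic storage, so a `case` label that declares one must open its own block. A self-contained illustration (plain C++, no MXNet types):

    // A case label that declares a variable must introduce a scope; without
    // the braces, the other labels would jump past the initialization of
    // 'name' and the program would be ill-formed.
    #include <iostream>
    #include <string>

    int main() {
      int type = 0;
      switch (type) {
        case 0: {                      // braces required:
          std::string name = "temp";   // jumping past this init would be illegal
          std::cout << name << '\n';
          break;
        }                              // 'name' is destroyed here
        default:
          std::cout << "other\n";
      }
      return 0;
    }
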
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 55fb03283e55..2a828994fb44 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -474,7 +474,9 @@ class DropoutOp {
   real_t pkeep_;
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
+  /*! \brief Axes on which dropout mask is shared in the form of broadcast multiply */
   TShape axes_;
+  /*! \brief Flag to record whether forward is executed in pass-through mode */
   bool dropout_passthrough_;
 #if MXNET_USE_CUDNN_DROPOUT
   bool cudnn_off_;
@@ -539,4 +541,6 @@ void DropoutGradCompute(const OpStatePtr& state,
 }  // namespace op
 }  // namespace mxnet
+
+#undef MXNET_USE_MKL_DROPOUT
 #endif  // MXNET_OPERATOR_NN_DROPOUT_INL_H_
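
With patches 8 and 9 combined, Dropout(p=0) becomes identity end to end: Gluon emits F.identity, the operator forward records pass-through mode, and no random resource is requested at all. For reference, this is the inverted-dropout semantics the non-trivial path implements, with the p=0 shortcut made explicit (a standalone sketch, not MXNet code):

    // Reference inverted dropout on the CPU: keep each element with
    // probability pkeep and scale by 1/pkeep so the expectation is
    // unchanged; pkeep == 1 (p == 0) is a pure pass-through, which is the
    // case the patch short-circuits.
    #include <cstdio>
    #include <random>
    #include <vector>

    void dropout_forward(const std::vector<float>& in, std::vector<float>* out,
                         std::vector<float>* mask, float pkeep, std::mt19937* rng) {
      if (pkeep >= 1.0f) {  // pass-through: no mask generation, no RNG draw
        *out = in;
        mask->assign(in.size(), 1.0f);
        return;
      }
      std::bernoulli_distribution keep(pkeep);
      out->resize(in.size());
      mask->resize(in.size());
      for (size_t i = 0; i < in.size(); ++i) {
        (*mask)[i] = keep(*rng) ? 1.0f / pkeep : 0.0f;  // pre-scaled mask
        (*out)[i] = in[i] * (*mask)[i];
      }
    }

    int main() {
      std::mt19937 rng(17);
      std::vector<float> x = {1.f, 2.f, 3.f, 4.f}, y, m;
      dropout_forward(x, &y, &m, 0.5f, &rng);  // train mode, p = 0.5
      for (float v : y) std::printf("%g ", v);
      std::printf("\n");
      dropout_forward(x, &y, &m, 1.0f, &rng);  // p = 0: identity
      for (float v : y) std::printf("%g ", v);
      std::printf("\n");
      return 0;
    }
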