Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
add cudnn_off
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jan 28, 2019
1 parent abed0b1 commit bc0927e
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
float p;
int mode;
TShape axes;
dmlc::optional<bool> cudnn_off;
DMLC_DECLARE_PARAMETER(DropoutParam) {
DMLC_DECLARE_FIELD(p).set_default(0.5)
.set_range(0, 1)
Expand All @@ -73,6 +74,8 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
.describe("Whether to only turn on dropout during training or to also turn on for inference.");
DMLC_DECLARE_FIELD(axes).set_default(TShape())
.describe("Axes for variational dropout kernel.");
DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(false))
.describe("Whether to turn off cudnn in dropout operator.");
}
}; // struct DropoutParam

Expand Down Expand Up @@ -216,6 +219,7 @@ class DropoutOp {
this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
this->axes_ = param.axes;
#if MXNET_USE_CUDNN == 1
this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value();
this->ctx_ = ctx;
if (ctx.dev_type == kGPU && this->pkeep_ > 0) {
init_cudnn_ = false;
Expand Down Expand Up @@ -361,9 +365,11 @@ class DropoutOp {
if (this->axes_.ndim() == 0) {
// standard case for dropout
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
if (this->pkeep_ > 0) { // existing dropout produces inf with pkeep=0
if (this->pkeep_ > 0 && !this->cudnn_off_) {
CuDNNForward(ctx, in_data[dropout::kData], out);
} else {
// the cuDNN dropout path produces inf when pkeep=0, so
// fall back to the existing GPU kernel for consistency.
LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
out.dptr<DType>(),
mask.dptr<DType>(),
Expand Down Expand Up @@ -440,7 +446,7 @@ class DropoutOp {
// standard case for dropout
CHECK_EQ(grad.Size(), mask.Size());
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
if (this->pkeep_ > 0) {
if (this->pkeep_ > 0 && !this->cudnn_off_) {
CuDNNBackward(ctx, grad, gdata);
} else {
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
Expand Down Expand Up @@ -499,6 +505,7 @@ class DropoutOp {
dropout::DropoutOpMode mode_;
TShape axes_;
#if MXNET_USE_CUDNN == 1
bool cudnn_off_;
Context ctx_;
cudnnDataType_t dtype_;
cudnnDropoutDescriptor_t dropout_desc_;
Expand Down

0 comments on commit bc0927e

Please sign in to comment.