Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
cudnn dropout
Browse files Browse the repository at this point in the history
  • Loading branch information
szha committed Jan 16, 2019
1 parent 6a4bac6 commit 03453c4
Show file tree
Hide file tree
Showing 4 changed files with 181 additions and 39 deletions.
4 changes: 2 additions & 2 deletions src/operator/cudnn_rnn-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@ class CuDNNRNNOp : public Operator {
if (param_.p > 0) {
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_byte_));
dropout_size_ = dropout_byte_ / sizeof(DType);
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU());
dropout_states_ = Storage::Get()->Alloc(dropout_byte_, Context::GPU(s->dev_id));
} else {
dropout_states_ = {};
dropout_byte_ = 0;
Expand Down Expand Up @@ -764,7 +764,7 @@ class CuDNNRNNOp : public Operator {
&reserve_space_byte_));
workspace_size_ = workspace_byte_ / sizeof(DType);
// Allocate the reserve space
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU());
reserve_space_ = Storage::Get()->Alloc(reserve_space_byte_, Context::GPU(s->dev_id));

// Check that number of params are correct
size_t cudnn_param_size;
Expand Down
204 changes: 172 additions & 32 deletions src/operator/nn/dropout-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {

template<typename xpu, typename DType>
class DropoutOp {
#if defined(USE_MKL) && defined(_OPENMP)
#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
static void BernoulliGenerate(common::random::RandGenerator<cpu, DType> gen,
int n, double p, int* r) {
typename RandGenerator<xpu, DType>::Impl genImpl(&gen, 1);
Expand Down Expand Up @@ -150,23 +150,7 @@ class DropoutOp {
return false;
}

#ifdef __CUDACC__
// GPU never uses MKL
static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<gpu> *s, RandGenerator<gpu, DType> *pgen,
const double pkeep,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data) {
return false;
}
static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<gpu> *s, const double pkeep,
const std::vector<TBlob> &in_grad,
const std::vector<TBlob> &out_data,
const std::vector<TBlob> &out_grad) {
return false;
}
#endif // __CUDACC__

#else // #if defined(USE_MKL) && defined(_OPENMP)
#else // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<xpu> *s, RandGenerator<xpu, DType> *pgen,
const double pkeep,
const std::vector<TBlob> &in_data,
Expand All @@ -179,7 +163,7 @@ class DropoutOp {
const std::vector<TBlob> &out_grad) {
return false;
}
#endif // #if defined(USE_MKL) && defined(_OPENMP)
#endif // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)

public:
/*!
Expand Down Expand Up @@ -227,12 +211,136 @@ class DropoutOp {
}
};

void Init(const DropoutParam &param) {
explicit DropoutOp(const DropoutParam &param, Context ctx) {
this->pkeep_ = 1.0f - param.p;
this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
this->axes_ = param.axes;
#if MXNET_USE_CUDNN == 1
this->ctx_ = ctx;
if (ctx.dev_type == kGPU) {
init_cudnn_ = false;
dtype_ = mshadow::DataType<DType>::kCudnnFlag;
CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&y_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dx_desc_));
CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_));
CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
}
#endif // MXNET_USE_CUDNN == 1
}

~DropoutOp() {
#if MXNET_USE_CUDNN == 1
if (this->ctx_.dev_type == kGPU) {
CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
CUDNN_CALL(cudnnDestroyTensorDescriptor(dy_desc_));
CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
if (init_cudnn_) {
Storage::Get()->Free(dropout_states_);
Storage::Get()->Free(reserve_space_);
}
}
#endif // MXNET_USE_CUDNN == 1
}

#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
inline void CuDNNForward(const OpContext &ctx,
const TBlob &in,
const TBlob &out) {
Stream<xpu> *s = ctx.get_stream<xpu>();

// set dropout state.
// TODO(szha): expensive call, should be cached and reused across operators.
if (!init_cudnn_) {
CUDNN_CALL(cudnnDropoutGetStatesSize(s->dnn_handle_, &dropout_state_byte_));
dropout_states_ = Storage::Get()->Alloc(dropout_state_byte_, Context::GPU(s->dev_id));
CUDNN_CALL(cudnnSetDropoutDescriptor(dropout_desc_, s->dnn_handle_,
this->pkeep_,
dropout_states_.dptr, dropout_state_byte_,
seed_));
}

// describe input/output tensor
int dim[4], stride[4];
dim[0] = 1;
dim[1] = 1;
dim[2] = 1;
dim[3] = out.Size();
stride[0] = out.Size();
stride[1] = out.Size();
stride[2] = out.Size();
stride[3] = 1;
CUDNN_CALL(cudnnSetTensorNdDescriptor(x_desc_,
dtype_,
4,
dim,
stride));
CUDNN_CALL(cudnnSetTensorNdDescriptor(y_desc_,
dtype_,
4,
dim,
stride));

// perform dropout with cudnn
CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, &dropout_reserve_byte_));
if (init_cudnn_ && dropout_reserve_byte_ > reserve_space_.size) {
Storage::Get()->Free(reserve_space_);
init_cudnn_ = false;
}
if (!init_cudnn_) {
reserve_space_ = Storage::Get()->Alloc(dropout_reserve_byte_, Context::GPU(s->dev_id));
init_cudnn_ = true;
}
CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_,
dropout_desc_,
x_desc_,
in.dptr<DType>(),
y_desc_,
out.dptr<DType>(),
reserve_space_.dptr,
dropout_reserve_byte_));
}

inline void CuDNNBackward(const OpContext &ctx,
const TBlob &out_grad,
const TBlob &in_grad) {
Stream<xpu> *s = ctx.get_stream<xpu>();

// describe input/output tensor
int dim[4], stride[4];
dim[0] = 1;
dim[1] = 1;
dim[2] = 1;
dim[3] = in_grad.Size();
stride[0] = in_grad.Size();
stride[1] = in_grad.Size();
stride[2] = in_grad.Size();
stride[3] = 1;
CUDNN_CALL(cudnnSetTensorNdDescriptor(dy_desc_,
dtype_,
4,
dim,
stride));
CUDNN_CALL(cudnnSetTensorNdDescriptor(dx_desc_,
dtype_,
4,
dim,
stride));

// perform dropout with cudnn
CUDNN_CALL(cudnnDropoutBackward(s->dnn_handle_,
dropout_desc_,
dy_desc_,
out_grad.dptr<DType>(),
dx_desc_,
in_grad.dptr<DType>(),
reserve_space_.dptr,
dropout_reserve_byte_));
}
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)

void Forward(const OpContext &ctx,
const std::vector<TBlob> &in_data,
const std::vector<OpReqType> &req,
Expand All @@ -252,11 +360,15 @@ class DropoutOp {
CHECK(req[dropout::kOut] != kAddTo);
if (this->axes_.ndim() == 0) {
// standard case for dropout
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
CuDNNForward(ctx, in_data[dropout::kData], out);
#else
LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
out.dptr<DType>(),
mask.dptr<DType>(),
in_data[dropout::kData].dptr<DType>(),
this->pkeep_);
out.dptr<DType>(),
mask.dptr<DType>(),
in_data[dropout::kData].dptr<DType>(),
this->pkeep_);
#endif // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
return;
}

Expand Down Expand Up @@ -319,10 +431,14 @@ class DropoutOp {
if (this->axes_.ndim() == 0) {
// standard case for dropout
CHECK_EQ(grad.Size(), mask.Size());
#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
CuDNNBackward(ctx, grad, gdata);
#else
MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
});
#endif // MXNET_USE_CUDNN == 1 & defined(__CUDACC__)
return;
}
// broardcast mul
Expand Down Expand Up @@ -367,29 +483,54 @@ class DropoutOp {
/*! \brief Dropout mode */
dropout::DropoutOpMode mode_;
TShape axes_;
#if MXNET_USE_CUDNN == 1
Context ctx_;
cudnnDataType_t dtype_;
cudnnDropoutDescriptor_t dropout_desc_;
bool init_cudnn_;
uint64_t seed_ = 17 + rand() % 4096; // NOLINT(runtime/threadsafe_fn)
size_t dropout_state_byte_, dropout_reserve_byte_;
Storage::Handle dropout_states_, reserve_space_;
cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_;
#endif // MXNET_USE_CUDNN == 1
}; // class DropoutOp

static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs,
const Context ctx,
const std::vector<TShape> &in_shapes,
const std::vector<int> &in_types) {
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
OpStatePtr state;
MSHADOW_REAL_TYPE_SWITCH(in_types[dropout::kData], DType, {
if (ctx.dev_type == kGPU) {
state = OpStatePtr::Create<DropoutOp<gpu, DType>>(param, ctx);
} else {
state = OpStatePtr::Create<DropoutOp<cpu, DType>>(param, ctx);
}
return state;
});
LOG(FATAL) << "should never reach here";
return OpStatePtr(); // should never reach here
}

template<typename xpu>
void DropoutCompute(const nnvm::NodeAttrs& attrs,
void DropoutCompute(const OpStatePtr& state,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
DropoutOp<xpu, DType> op;
op.Init(param);
DropoutOp<xpu, DType>& op = state.get_state<DropoutOp<xpu, DType>>();
op.Forward(ctx, inputs, req, outputs);
});
}

template<typename xpu>
void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
void DropoutGradCompute(const OpStatePtr& state,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1);
CHECK_EQ(req.size(), 1);
Expand All @@ -399,8 +540,7 @@ void DropoutGradCompute(const nnvm::NodeAttrs& attrs,
out_data[dropout::kMask] = inputs[1];

MSHADOW_REAL_TYPE_SWITCH(inputs[0].type_flag_, DType, {
DropoutOp<xpu, DType> op;
op.Init(param);
DropoutOp<xpu, DType>& op = state.get_state<DropoutOp<xpu, DType>>();
op.Backward(ctx, out_grads, out_data, req, outputs);
});
}
Expand Down
8 changes: 5 additions & 3 deletions src/operator/nn/dropout.cc
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,27 @@ Example::
for (size_t i = 0; i < nout; ++i) out_type->push_back(dtype);
return true;
})
.set_attr<FCompute>("FCompute<cpu>", DropoutCompute<cpu>)
.set_attr<FCreateOpState>("FCreateOpState", CreateDropoutState)
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", DropoutCompute<cpu>)
.set_attr<nnvm::FGradient>("FGradient", DropoutGrad{"_backward_Dropout"})
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
return std::vector<ResourceRequest>{ ResourceRequest::kParallelRandom };
return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom};
})
.add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.")
.add_arguments(DropoutParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_Dropout)
.set_num_outputs(1)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<DropoutParam>)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
return std::vector<std::pair<int, int> >{{0, 0}};
})
.set_attr<FCompute>("FCompute<cpu>", DropoutGradCompute<cpu>);
.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", DropoutGradCompute<cpu>);

} // namespace op
} // namespace mxnet
4 changes: 2 additions & 2 deletions src/operator/nn/dropout.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,10 @@ namespace mxnet {
namespace op {

NNVM_REGISTER_OP(Dropout)
.set_attr<FCompute>("FCompute<gpu>", DropoutCompute<gpu>);
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);

NNVM_REGISTER_OP(_backward_Dropout)
.set_attr<FCompute>("FCompute<gpu>", DropoutGradCompute<gpu>);
.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutGradCompute<gpu>);

} // namespace op
} // namespace mxnet
Expand Down

0 comments on commit 03453c4

Please sign in to comment.