From 94a48f9c287d5272982229127247de45b272a577 Mon Sep 17 00:00:00 2001
From: Sheng Zha
Date: Fri, 25 Jan 2019 16:00:12 -0800
Subject: [PATCH] refactor

---
 src/operator/nn/dropout-inl.h          | 275 ++++++++++++-------------
 src/operator/nn/dropout.cc             |  20 +-
 tests/python/unittest/test_operator.py |  28 ++-
 3 files changed, 167 insertions(+), 156 deletions(-)

diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 1cb324c4957c..a1e554f554af 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -39,12 +39,15 @@
 #include "../random/sampler.h"
 #include "../tensor/elemwise_binary_broadcast_op.h"
 
-#if defined(USE_MKL) && defined(_OPENMP)
+#define MXNET_USE_MKL_DROPOUT defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+#if MXNET_USE_MKL_DROPOUT
 #include <omp.h>
 #include <mkl_vml_functions.h>
 #include <mkl_vsl.h>
-#endif  // USE_MKL && _OPENMP
+#endif  // MXNET_USE_MKL_DROPOUT
+
+#define MXNET_USE_CUDNN_DROPOUT MXNET_USE_CUDNN == 1 && CUDNN_MAJOR >= 7
 
 namespace dropout {
 enum DropoutOpInputs {kData};
@@ -74,14 +77,14 @@ struct DropoutParam : public dmlc::Parameter<DropoutParam> {
     .describe("Whether to only turn on dropout during training or to also turn on for inference.");
     DMLC_DECLARE_FIELD(axes).set_default(TShape())
     .describe("Axes for variational dropout kernel.");
-    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(false))
+    DMLC_DECLARE_FIELD(cudnn_off).set_default(dmlc::optional<bool>(true))
     .describe("Whether to turn off cudnn in dropout operator.");
   }
 };  // struct DropoutParam
 
 template<typename xpu, typename DType>
 class DropoutOp {
-#if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+#if MXNET_USE_MKL_DROPOUT
   static void BernoulliGenerate(common::random::RandGenerator<cpu, DType> gen,
                                 int n, double p, int* r) {
     typename RandGenerator<cpu, DType>::Impl genImpl(&gen, 1);
@@ -103,70 +106,55 @@ class DropoutOp {
       }
     }
   }
-
-  // MKL forward pass
-  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<cpu> *s, RandGenerator<cpu, DType> *pgen,
-                                         const double pkeep,
-                                         const std::vector<TBlob> &in_data,
-                                         const std::vector<TBlob> &out_data) {
+  static inline bool MKLAvailable() {
     // BernoulliGenerate expects an array int, so for types smaller than int, the mask buffer
     // will be too small, so we can't use MKL in those cases
-    if (sizeof(DType) >= sizeof(int)) {
-      Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<cpu, DType>(s);
-      DType *outptr = out.dptr_;
-      DType *dataptr = data.dptr_;
-      auto maskptr = reinterpret_cast<int *>(mask.dptr_);
-      int count = mask.shape_[0] * mask.shape_[1];
-      BernoulliGenerate(*pgen, count, pkeep, maskptr);
-      const float pk_1 = 1.0f / pkeep;
+    return sizeof(DType) >= sizeof(int);
+  }
+
+  // MKL forward pass
+  inline void MKLForward(const OpContext &ctx,
+                         const std::vector<TBlob> &in_data,
+                         const std::vector<TBlob> &out_data) {
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    RandGenerator<cpu, DType> *pgen = ctx.requested[0].get_parallel_random<cpu, DType>();
+    CHECK_NOTNULL(pgen);
+    Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> data = in_data[dropout::kData].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> out = out_data[dropout::kOut].FlatTo2D<cpu, DType>(s);
+    DType *outptr = out.dptr_;
+    DType *dataptr = data.dptr_;
+    auto maskptr = reinterpret_cast<int *>(mask.dptr_);
+    int count = mask.shape_[0] * mask.shape_[1];
+    BernoulliGenerate(*pgen, count, this->pkeep_, maskptr);
+    const float pk_1 = 1.0f / this->pkeep_;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        outptr[i] = dataptr[i] * maskptr[i] * pk_1;
-      }
-      return true;
-    }
-    return false;
+    for (int i = 0; i < count; ++i) {
+      outptr[i] = dataptr[i] * maskptr[i] * pk_1;
+    }
   }
 
   // MKL backward pass
-  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<cpu> *s, const double pkeep,
-                                          const std::vector<TBlob> &in_grad,
-                                          const std::vector<TBlob> &out_data,
-                                          const std::vector<TBlob> &out_grad) {
-    if (sizeof(DType) >= sizeof(int)) {
-      Tensor<cpu, 2, DType> grad = out_grad[dropout::kOut].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
-      Tensor<cpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<cpu, DType>(s);
-      DType *ingradptr = gdata.dptr_;
-      const DType *outgradptr = grad.dptr_;
-      auto maskptr = reinterpret_cast<int *>(mask.dptr_);
-      int count = mask.shape_[0] * mask.shape_[1];
-      const float pk_1 = 1.0f / pkeep;
+  inline void MKLBackward(const OpContext &ctx,
+                          const std::vector<TBlob> &in_grad,
+                          const std::vector<TBlob> &out_data,
+                          const std::vector<TBlob> &out_grad) {
+    Stream<cpu> *s = ctx.get_stream<cpu>();
+    Tensor<cpu, 2, DType> grad = out_grad[dropout::kOut].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> mask = out_data[dropout::kMask].FlatTo2D<cpu, DType>(s);
+    Tensor<cpu, 2, DType> gdata = in_grad[dropout::kData].FlatTo2D<cpu, DType>(s);
+    DType *ingradptr = gdata.dptr_;
+    const DType *outgradptr = grad.dptr_;
+    auto maskptr = reinterpret_cast<int *>(mask.dptr_);
+    int count = mask.shape_[0] * mask.shape_[1];
+    const float pk_1 = 1.0f / this->pkeep_;
 #pragma omp parallel for num_threads(engine::OpenMP::Get()->GetRecommendedOMPThreadCount())
-      for (int i = 0; i < count; ++i) {
-        ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
-      }
-      return true;
-    }
-    return false;
+    for (int i = 0; i < count; ++i) {
+      ingradptr[i] = outgradptr[i] * maskptr[i] * pk_1;
+    }
   }
 
-#else  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
-  static bool MSHADOW_CINLINE MKLForward(mshadow::Stream<cpu> *s, RandGenerator<cpu, DType> *pgen,
-                                         const double pkeep,
-                                         const std::vector<TBlob> &in_data,
-                                         const std::vector<TBlob> &out_data) {
-    return false;
-  }
-  static bool MSHADOW_CINLINE MKLBackward(mshadow::Stream<cpu> *s, const double pkeep,
-                                          const std::vector<TBlob> &in_grad,
-                                          const std::vector<TBlob> &out_data,
-                                          const std::vector<TBlob> &out_grad) {
-    return false;
-  }
-#endif  // #if defined(USE_MKL) && defined(_OPENMP) && !defined(__CUDACC__)
+#endif  // #if MXNET_USE_MKL_DROPOUT
 
  public:
   /*!
@@ -218,10 +206,10 @@ class DropoutOp {
     this->pkeep_ = 1.0f - param.p;
     this->mode_ = static_cast<dropout::DropoutOpMode>(param.mode);
     this->axes_ = param.axes;
-#if MXNET_USE_CUDNN == 1
+#if MXNET_USE_CUDNN_DROPOUT
     this->cudnn_off_ = param.cudnn_off && param.cudnn_off.value();
     this->ctx_ = ctx;
-    if (ctx.dev_type == kGPU && this->pkeep_ > 0) {
+    if (ctx.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) {
       init_cudnn_ = false;
       dtype_ = mshadow::DataType<DType>::kCudnnFlag;
       CUDNN_CALL(cudnnCreateTensorDescriptor(&x_desc_));
@@ -230,12 +218,12 @@ class DropoutOp {
       CUDNN_CALL(cudnnCreateTensorDescriptor(&dy_desc_));
       CUDNN_CALL(cudnnCreateDropoutDescriptor(&dropout_desc_));
     }
-#endif  // MXNET_USE_CUDNN == 1
+#endif  // MXNET_USE_CUDNN_DROPOUT
   }
 
   ~DropoutOp() {
-#if MXNET_USE_CUDNN == 1
-    if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0) {
+#if MXNET_USE_CUDNN_DROPOUT
+    if (this->ctx_.dev_type == kGPU && this->pkeep_ > 0 && !this->cudnn_off_) {
       CUDNN_CALL(cudnnDestroyTensorDescriptor(x_desc_));
       CUDNN_CALL(cudnnDestroyTensorDescriptor(y_desc_));
       CUDNN_CALL(cudnnDestroyTensorDescriptor(dx_desc_));
@@ -243,15 +231,19 @@ class DropoutOp {
       CUDNN_CALL(cudnnDestroyDropoutDescriptor(dropout_desc_));
       if (init_cudnn_) {
         Storage::Get()->Free(dropout_states_);
-        Storage::Get()->Free(reserve_space_);
       }
     }
-#endif  // MXNET_USE_CUDNN == 1
+#endif  // MXNET_USE_CUDNN_DROPOUT
+  }
+
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+  inline bool CuDNNAvailable() {
+    return this->pkeep_ > 0 && !this->cudnn_off_;
   }
 
-#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
   inline void CuDNNForward(const OpContext &ctx,
                            const TBlob &in,
+                           const TBlob &mask,
                            const TBlob &out) {
     Stream<gpu> *s = ctx.get_stream<gpu>();
 
@@ -264,6 +256,7 @@ class DropoutOp {
                                            1.0f - this->pkeep_,
                                            dropout_states_.dptr,
                                            dropout_state_byte_,
                                            seed_));
+      init_cudnn_ = true;
     }
 
     // describe input/output tensor
@@ -289,26 +282,23 @@ class DropoutOp {
 
     // perform dropout with cudnn
     CUDNN_CALL(cudnnDropoutGetReserveSpaceSize(x_desc_, &dropout_reserve_byte_));
-    if (init_cudnn_ && dropout_reserve_byte_ > reserve_space_.size) {
-      Storage::Get()->Free(reserve_space_);
-      init_cudnn_ = false;
-    }
-    if (!init_cudnn_) {
-      reserve_space_ = Storage::Get()->Alloc(dropout_reserve_byte_, Context::GPU(s->dev_id));
-      init_cudnn_ = true;
-    }
+    // cudnn uses bits to record the positions that are dropped, so reserve bytes is always
+    // 1/8 of input size.
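+    // (Illustrative arithmetic, not from the original patch: a 1024-element float32
+    // input needs 1024 / 8 = 128 reserve bytes, while the mask blob provides
+    // 1024 * 4 = 4096 bytes, so the check below holds for any DType of at least one byte.)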
+    CHECK_GE(mask.Size() * sizeof(DType), dropout_reserve_byte_) <<
+      "The size of the mask space is smaller than the required cudnn reserved space.";
     CUDNN_CALL(cudnnDropoutForward(s->dnn_handle_,
                                    dropout_desc_,
                                    x_desc_,
                                    in.dptr<DType>(),
                                    y_desc_,
                                    out.dptr<DType>(),
-                                   reserve_space_.dptr,
+                                   mask.dptr<DType>(),
                                    dropout_reserve_byte_));
   }
 
   inline void CuDNNBackward(const OpContext &ctx,
                             const TBlob &out_grad,
+                            const TBlob &mask,
                             const TBlob &in_grad) {
     Stream<gpu> *s = ctx.get_stream<gpu>();
 
@@ -340,10 +330,10 @@ class DropoutOp {
                                     out_grad.dptr<DType>(),
                                     dx_desc_,
                                     in_grad.dptr<DType>(),
-                                    reserve_space_.dptr,
+                                    mask.dptr<DType>(),
                                     dropout_reserve_byte_));
   }
-#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
 
   void Forward(const OpContext &ctx,
                const std::vector<TBlob> &in_data,
@@ -355,50 +345,48 @@ class DropoutOp {
       CHECK_EQ(out_data.size(), 2U);
     }
     Stream<xpu> *s = ctx.get_stream<xpu>();
+    const TBlob &in = in_data[dropout::kData];
     const TBlob &out = out_data[dropout::kOut];
+    const TBlob &mask = out_data[dropout::kMask];
     if (ctx.is_train || this->mode_ == dropout::kAlways) {
-      RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
-      CHECK_NOTNULL(pgen);
-      if (this->axes_.ndim() != 0 || !MKLForward(s, pgen, this->pkeep_, in_data, out_data)) {
-        const TBlob &mask = out_data[dropout::kMask];
-        CHECK(req[dropout::kOut] != kAddTo);
-        if (this->axes_.ndim() == 0) {
-          // standard case for dropout
-#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
-          if (this->pkeep_ > 0 && !this->cudnn_off_) {
-            CuDNNForward(ctx, in_data[dropout::kData], out);
-          } else {
-            // existing dropout produces inf with pkeep=0,
-            // thus revert to existing GPU kernel for consistency.
-            LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
-                                          out.dptr<DType>(),
-                                          mask.dptr<DType>(),
-                                          in_data[dropout::kData].dptr<DType>(),
-                                          this->pkeep_);
-          }
-#else
-          LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
-                                        out.dptr<DType>(),
-                                        mask.dptr<DType>(),
-                                        in_data[dropout::kData].dptr<DType>(),
-                                        this->pkeep_);
-#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+      if (this->axes_.ndim() == 0) {
+#if MXNET_USE_MKL_DROPOUT
+        if (MKLAvailable()) {
+          MKLForward(ctx, in_data, out_data);
           return;
         }
-
+#endif  // MXNET_USE_MKL_DROPOUT
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        if (CuDNNAvailable()) {
+          CuDNNForward(ctx, in, mask, out);
+          return;
+        }
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
+        CHECK_NOTNULL(pgen);
+        CHECK(req[dropout::kOut] != kAddTo);
+        LaunchRNG<DropoutKernel, xpu>(s, pgen, out.Size(),
+                                      out.dptr<DType>(),
+                                      mask.dptr<DType>(),
+                                      in.dptr<DType>(),
+                                      this->pkeep_);
+        return;
+      } else {
+        RandGenerator<xpu, DType> *pgen = ctx.requested[0].get_parallel_random<xpu, DType>();
+        CHECK_NOTNULL(pgen);
         // initialize the mask
         LaunchRNG<BernoulliKernel, xpu>(s, pgen, mask.Size(),
                                         mask.dptr<DType>(),
                                         this->pkeep_);
         // broadcast mul
         TShape new_lshape, new_rshape, new_oshape;
-        int ndim = BinaryBroadcastShapeCompact(in_data[dropout::kData].shape_,
+        int ndim = BinaryBroadcastShapeCompact(in.shape_,
                                                mask.shape_, out.shape_,
                                                &new_lshape, &new_rshape, &new_oshape);
         if (!ndim) {
           MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
             mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-              s, out.Size(), out.dptr<DType>(), in_data[dropout::kData].dptr<DType>(),
+              s, out.Size(), out.dptr<DType>(), in.dptr<DType>(),
               mask.dptr<DType>());
           });
         } else {
@@ -410,21 +398,16 @@ class DropoutOp {
                                                    mshadow_op::mul>, xpu>::
             template LaunchEx(s, new_oshape.Size(), req[dropout::kOut],
                               lstride, rstride, oshape,
-                              in_data[dropout::kData].dptr<DType>(),
+                              in.dptr<DType>(),
                               mask.dptr<DType>(), out.dptr<DType>());
           });
         }
       }
     } else {
-      const TBlob& data = in_data[dropout::kData];
-      if (req[dropout::kOut] == kWriteTo) {
-        mxnet_op::copy(s, out, data);
-      } else {
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
-            s, out.Size(), out.dptr<DType>(), data.dptr<DType>());
-        });
-      }
+      MXNET_ASSIGN_REQ_SWITCH(req[dropout::kOut], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+          s, out.Size(), out.dptr<DType>(), in.dptr<DType>());
+      });
     }
   }
 
@@ -438,30 +421,30 @@ class DropoutOp {
     using namespace mshadow::expr;
     Stream<xpu> *s = ctx.get_stream<xpu>();
    if (ctx.is_train || mode_ == dropout::kAlways) {
-      if (this->axes_.ndim() != 0 || !MKLBackward(s, this->pkeep_, in_grad, out_data, out_grad)) {
-        const TBlob &gdata = in_grad[dropout::kData];
-        const TBlob &grad = out_grad[dropout::kOut];
-        const TBlob &mask = out_data[dropout::kMask];
-        if (this->axes_.ndim() == 0) {
-          // standard case for dropout
-          CHECK_EQ(grad.Size(), mask.Size());
-#if MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
-          if (this->pkeep_ > 0 && !this->cudnn_off_) {
-            CuDNNBackward(ctx, grad, gdata);
-          } else {
-            MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-              mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-                s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-            });
-          }
-#else
-          MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-            mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
-              s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
-          });
-#endif  // MXNET_USE_CUDNN == 1 && defined(__CUDACC__)
+      const TBlob &gdata = in_grad[dropout::kData];
+      const TBlob &grad = out_grad[dropout::kOut];
+      const TBlob &mask = out_data[dropout::kMask];
+      if (this->axes_.ndim() == 0) {
+#if MXNET_USE_MKL_DROPOUT
+        if (MKLAvailable()) {
+          MKLBackward(ctx, in_grad, out_data, out_grad);
+          return;
+        }
+#endif  // MXNET_USE_MKL_DROPOUT
+#if MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        if (CuDNNAvailable()) {
+          CuDNNBackward(ctx, grad, mask, gdata);
           return;
         }
+#endif  // MXNET_USE_CUDNN_DROPOUT && defined(__CUDACC__)
+        // standard case for dropout
+        CHECK_EQ(grad.Size(), mask.Size());
+        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::mul, Req>, xpu>::Launch(
+            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>(), mask.dptr<DType>());
+        });
+        return;
+      } else {
         // broadcast mul
         TShape new_lshape, new_rshape, new_oshape;
         int ndim = BinaryBroadcastShapeCompact(grad.shape_,
@@ -487,14 +470,10 @@ class DropoutOp {
     } else {
       const TBlob& gdata = in_grad[dropout::kData];
       const TBlob& grad = out_grad[dropout::kOut];
-      if (req[dropout::kData] == kWriteTo) {
-        mxnet_op::copy(s, gdata, grad);
-      } else {
-        MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
-          mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
-            s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>());
-        });
-      }
+      MXNET_ASSIGN_REQ_SWITCH(req[dropout::kData], Req, {
+        mxnet_op::Kernel<mxnet_op::op_with_req<mshadow_op::identity, Req>, xpu>::Launch(
+          s, gdata.Size(), gdata.dptr<DType>(), grad.dptr<DType>());
+      });
     }
   }
 
@@ -504,7 +483,7 @@ class DropoutOp {
   /*! \brief Dropout mode */
   dropout::DropoutOpMode mode_;
   TShape axes_;
-#if MXNET_USE_CUDNN == 1
+#if MXNET_USE_CUDNN_DROPOUT
   bool cudnn_off_;
   Context ctx_;
   cudnnDataType_t dtype_;
@@ -512,9 +491,9 @@ class DropoutOp {
   bool init_cudnn_;
   uint64_t seed_ = 17 + rand() % 4096;  // NOLINT(runtime/threadsafe_fn)
   size_t dropout_state_byte_, dropout_reserve_byte_;
-  Storage::Handle dropout_states_, reserve_space_;
+  Storage::Handle dropout_states_;
   cudnnTensorDescriptor_t x_desc_, y_desc_, dx_desc_, dy_desc_;
-#endif  // MXNET_USE_CUDNN == 1
+#endif  // MXNET_USE_CUDNN_DROPOUT
 };  // class DropoutOp
 
 static OpStatePtr CreateDropoutState(const nnvm::NodeAttrs &attrs,
diff --git a/src/operator/nn/dropout.cc b/src/operator/nn/dropout.cc
index 27cecee906a0..1862a2e96a2f 100644
--- a/src/operator/nn/dropout.cc
+++ b/src/operator/nn/dropout.cc
@@ -125,9 +125,23 @@ Example::
 .set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
   return std::vector<std::pair<int, int> >{{0, 0}};
 })
-.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
-  return std::vector<ResourceRequest>{ResourceRequest::kParallelRandom};
-})
+.set_attr<FResourceRequestEx>("FResourceRequestEx",
+  [](const NodeAttrs& attrs, const int dev_mask, const DispatchMode dispatch_mode) {
+    std::vector<ResourceRequest> request;
+    if (dev_mask == kGPU) {
+#if MXNET_USE_CUDNN_DROPOUT
+      const DropoutParam& param = nnvm::get<DropoutParam>(attrs.parsed);
+      // if cudnn is used, parallel random is not needed.
+      if (1.0f - param.p > 0
+          && !(param.cudnn_off && param.cudnn_off.value())
+          && param.axes.ndim() == 0) {
+        return request;
+      }
+#endif
+    }
+    request.emplace_back(ResourceRequest::kParallelRandom);
+    return request;
+  })
 .add_argument("data", "NDArray-or-Symbol", "Input array to which dropout will be applied.")
 .add_arguments(DropoutParam::__FIELDS__());
 
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 670cc7eb15e0..b648d2507798 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -5909,10 +5909,10 @@ def check_correctness(executor, input, ratio):
         elif ratio == 0:
             assert output_zeroes == 0
 
-    def check_dropout_ratio(ratio, shape):
+    def check_dropout_ratio(ratio, shape, cudnn_off=True):
         # test dropout
         x = mx.sym.var('data')
-        y = mx.sym.Dropout(x, p=ratio)
+        y = mx.sym.Dropout(x, p=ratio, cudnn_off=cudnn_off)
         exe = y.simple_bind(ctx=default_context(), data=shape)
 
         if ratio == 1:
@@ -5950,7 +5950,7 @@ def check_dropout_ratio(ratio, shape):
 
         # test permanent dropout
         x = mx.sym.var('data')
-        y = mx.sym.Dropout(x, p=ratio, mode='always')
+        y = mx.sym.Dropout(x, p=ratio, mode='always', cudnn_off=cudnn_off)
         exe = y.simple_bind(ctx=default_context(), data=shape)
 
         exe.arg_arrays[0][:] = 1
@@ -5975,13 +5975,13 @@ def get_slice(x, axis, idx):
             ix += (slice(None, None, None),)
         return x[ix]
 
-    def check_dropout_axes(ratio, shape, axes):
+    def check_dropout_axes(ratio, shape, axes, cudnn_off=False):
         compactshape = list(shape)
         for axis in axes:
             compactshape[axis] = 1
         compactx = mx.random.uniform(shape=tuple(compactshape))
         broadcastx = compactx.broadcast_to(shape)
-        dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes)
+        dropouty = mx.nd.Dropout(broadcastx, p=ratio, axes=axes, cudnn_off=cudnn_off)
         for axis in axes:
             target = get_slice(dropouty, axis, 0).asnumpy()
             for i in range(1, shape[axis]):
@@ -5993,6 +5993,11 @@ def check_dropout_axes(ratio, shape, axes):
     check_dropout_ratio(1.0, shape)
     check_dropout_ratio(0.75, shape)
     check_dropout_ratio(0.25, shape)
+    check_dropout_ratio(0.5, shape, cudnn_off=False)
+    check_dropout_ratio(0.0, shape, cudnn_off=False)
+    check_dropout_ratio(1.0, shape, cudnn_off=False)
+    check_dropout_ratio(0.75, shape, cudnn_off=False)
+    check_dropout_ratio(0.25, shape, cudnn_off=False)
 
     nshape = (10, 10, 10, 10)
     with mx.autograd.train_mode():
@@ -6009,6 +6014,19 @@ def check_dropout_axes(ratio, shape, axes):
         check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
         check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
         check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
+        check_dropout_axes(0.25, nshape, axes = (0,), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (1,), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (2,), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (3,), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (0, 1), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (0, 2), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (0, 3), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (1, 2), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (1, 3), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (2, 3), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (0, 1, 2), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (0, 2, 3), cudnn_off=False)
+        check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False)
 
 @unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11290")
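
For reference, all three forward paths above (MKL, cuDNN, and the fallback DropoutKernel) implement inverted dropout: surviving activations are scaled by 1 / pkeep at training time so that inference needs no rescaling, and the pkeep_ > 0 guard in CuDNNAvailable() exists because that scale diverges at pkeep = 0. A minimal sketch of how this invariant can be checked from Python, assuming only an MXNet build that includes these changes (variable names below are illustrative, not part of the patch):

    import mxnet as mx

    # Under inverted dropout, every surviving entry of an all-ones input equals
    # exactly 1 / (1 - p); dropped entries are exactly zero.
    p = 0.25
    x = mx.nd.ones((100, 100))
    with mx.autograd.train_mode():
        y = mx.nd.Dropout(x, p=p, cudnn_off=True)
    kept = y.asnumpy()[y.asnumpy() != 0]
    assert (abs(kept - 1.0 / (1.0 - p)) < 1e-6).all()

Running the same check with cudnn_off=False on a GPU context exercises the cuDNN path, which is what the added cudnn_off=False cases in test_operator.py do more systematically.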