From 06142862e5676aeaab6a8113687b826ccfb0556b Mon Sep 17 00:00:00 2001 From: reminisce Date: Fri, 28 Jun 2019 01:08:12 -0700 Subject: [PATCH] [numpy] Fix several places in numpy (#15398) * Fix * More fix --- include/mxnet/base.h | 4 +++- python/mxnet/contrib/text/embedding.py | 2 +- python/mxnet/gluon/nn/basic_layers.py | 4 ++-- python/mxnet/numpy/multiarray.py | 10 ++++++---- python/mxnet/numpy_extension/image.py | 2 ++ 5 files changed, 14 insertions(+), 8 deletions(-) diff --git a/include/mxnet/base.h b/include/mxnet/base.h index c1e2da7a9db3..25d9ba8c32a0 100644 --- a/include/mxnet/base.h +++ b/include/mxnet/base.h @@ -451,7 +451,9 @@ inline int32_t Context::GetGPUCount() { } int32_t count; cudaError_t e = cudaGetDeviceCount(&count); - if (e == cudaErrorNoDevice) { + // TODO(junwu): Remove e == 35 + // This is skipped for working around wheel build system with older CUDA driver. + if (e == cudaErrorNoDevice || e == 35) { return 0; } CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e); diff --git a/python/mxnet/contrib/text/embedding.py b/python/mxnet/contrib/text/embedding.py index da20fbed1cbf..979ba2afbeb5 100644 --- a/python/mxnet/contrib/text/embedding.py +++ b/python/mxnet/contrib/text/embedding.py @@ -405,7 +405,7 @@ def get_vecs_by_tokens(self, tokens, lower_case_backup=False): for token in tokens] if is_np_array(): - embedding_fn = _mx_npx.Embedding + embedding_fn = _mx_npx.embedding array_fn = _mx_np.array else: embedding_fn = nd.Embedding diff --git a/python/mxnet/gluon/nn/basic_layers.py b/python/mxnet/gluon/nn/basic_layers.py index b99d5ef1499e..d7f599de66c5 100644 --- a/python/mxnet/gluon/nn/basic_layers.py +++ b/python/mxnet/gluon/nn/basic_layers.py @@ -436,9 +436,9 @@ class Flatten(HybridBlock): def __init__(self, **kwargs): super(Flatten, self).__init__(**kwargs) - @_adapt_np_array def hybrid_forward(self, F, x): - return F.Flatten(x) + flatten = F.npx.batch_flatten if is_np_array() else F.flatten + return flatten(x) def __repr__(self): return self.__class__.__name__ diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py index 97571ef4465d..10cfe7d73926 100644 --- a/python/mxnet/numpy/multiarray.py +++ b/python/mxnet/numpy/multiarray.py @@ -466,10 +466,12 @@ def __repr__(self): """ array_str = self.asnumpy().__repr__() dtype = self.dtype - if dtype == _np.float64: - array_str = array_str[:-1] + ', dtype=float64)' - elif dtype == _np.float32: - array_str = array_str[:array_str.rindex(', dtype=')] + ')' + if 'dtype=' in array_str: + if dtype == _np.float32: + array_str = array_str[:array_str.rindex(',')] + ')' + elif dtype != _np.float32: + array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__) + context = self.context if context.device_type == 'cpu': return array_str diff --git a/python/mxnet/numpy_extension/image.py b/python/mxnet/numpy_extension/image.py index b3bd27fc503c..00a028b3c18f 100644 --- a/python/mxnet/numpy_extension/image.py +++ b/python/mxnet/numpy_extension/image.py @@ -17,4 +17,6 @@ """Image pre-processing operators.""" +from ..image import * # pylint: disable=wildcard-import, unused-wildcard-import + __all__ = []