Skip to content

Commit

Permalink
[numpy] Fix several places in numpy (apache#15398)
Browse files Browse the repository at this point in the history
* Fix

* More fix
  • Loading branch information
reminisce committed Aug 1, 2019
1 parent 8cfeeb7 commit 0614286
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 8 deletions.
4 changes: 3 additions & 1 deletion include/mxnet/base.h
Original file line number Diff line number Diff line change
Expand Up @@ -451,7 +451,9 @@ inline int32_t Context::GetGPUCount() {
}
int32_t count;
cudaError_t e = cudaGetDeviceCount(&count);
if (e == cudaErrorNoDevice) {
// TODO(junwu): Remove e == 35
// This is skipped for working around wheel build system with older CUDA driver.
if (e == cudaErrorNoDevice || e == 35) {
return 0;
}
CHECK_EQ(e, cudaSuccess) << " CUDA: " << cudaGetErrorString(e);
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/contrib/text/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -405,7 +405,7 @@ def get_vecs_by_tokens(self, tokens, lower_case_backup=False):
for token in tokens]

if is_np_array():
embedding_fn = _mx_npx.Embedding
embedding_fn = _mx_npx.embedding
array_fn = _mx_np.array
else:
embedding_fn = nd.Embedding
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/gluon/nn/basic_layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -436,9 +436,9 @@ class Flatten(HybridBlock):
def __init__(self, **kwargs):
super(Flatten, self).__init__(**kwargs)

@_adapt_np_array
def hybrid_forward(self, F, x):
return F.Flatten(x)
flatten = F.npx.batch_flatten if is_np_array() else F.flatten
return flatten(x)

def __repr__(self):
return self.__class__.__name__
Expand Down
10 changes: 6 additions & 4 deletions python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,10 +466,12 @@ def __repr__(self):
"""
array_str = self.asnumpy().__repr__()
dtype = self.dtype
if dtype == _np.float64:
array_str = array_str[:-1] + ', dtype=float64)'
elif dtype == _np.float32:
array_str = array_str[:array_str.rindex(', dtype=')] + ')'
if 'dtype=' in array_str:
if dtype == _np.float32:
array_str = array_str[:array_str.rindex(',')] + ')'
elif dtype != _np.float32:
array_str = array_str[:-1] + ', dtype={})'.format(dtype.__name__)

context = self.context
if context.device_type == 'cpu':
return array_str
Expand Down
2 changes: 2 additions & 0 deletions python/mxnet/numpy_extension/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,6 @@

"""Image pre-processing operators."""

from ..image import * # pylint: disable=wildcard-import, unused-wildcard-import

__all__ = []

0 comments on commit 0614286

Please sign in to comment.