Skip to content

Commit

Permalink
Pin mxnet version in response to mx.metric reorg (dmlc#1310)
Browse files Browse the repository at this point in the history
* mx.metric

* pin version

* fix lint

* reduce parallel test iter/bs
  • Loading branch information
zhreshold authored and chongruo committed May 26, 2020
1 parent d2cecf5 commit d004183
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 15 deletions.
2 changes: 1 addition & 1 deletion gluoncv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .utils.version import _require_mxnet_version, _deprecate_python2
_deprecate_python2()
_require_mxnet_version('1.4.0')
_require_mxnet_version('1.4.0', '2.0.0')

from . import data
from . import model_zoo
Expand Down
23 changes: 13 additions & 10 deletions gluoncv/utils/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,24 +27,27 @@ def check_version(min_version, warning_only=False):
raise AssertionError(msg)


def _require_mxnet_version(mx_version):
def _require_mxnet_version(mx_version, max_mx_version='2.0.0'):
try:
import mxnet as mx
from distutils.version import LooseVersion
if LooseVersion(mx.__version__) < LooseVersion(mx_version):
if LooseVersion(mx.__version__) < LooseVersion(mx_version) or \
LooseVersion(mx.__version__) >= LooseVersion(max_mx_version):
version_str = '>={},<{}'.format(mx_version, max_mx_version)
msg = (
"Legacy mxnet-mkl=={} detected, some new modules may not work properly. "
"mxnet-mkl>={} is required. You can use pip to upgrade mxnet "
"`pip install -U --pre mxnet -f https://dist.mxnet.io/python/mkl` "
"or `pip install -U --pre mxnet -f https://dist.mxnet.io/python/cu100mkl`\
").format(mx.__version__, mx_version)
raise ImportError(msg)
"Legacy mxnet-mkl=={0} detected, some modules may not work properly. "
"mxnet-mkl{1} is required. You can use pip to upgrade mxnet "
"`pip install -U 'mxnet-mkl{1}'` "
"or `pip install -U 'mxnet-cu100mkl{1}'`\
").format(mx.__version__, version_str)
raise RuntimeError(msg)
except ImportError:
raise ImportError(
"Unable to import dependency mxnet. "
"A quick tip is to install via "
"`pip install --pre mxnet -f https://dist.mxnet.io/python/cu100mkl`. "
"please refer to https://gluon-cv.mxnet.io/#installation for details.")
"`pip install 'mxnet-cu100mkl<{}'`. "
"please refer to https://gluon-cv.mxnet.io/#installation for details.".format(
max_mx_version))

def _deprecate_python2():
if sys.version_info[0] < 3:
Expand Down
9 changes: 5 additions & 4 deletions tests/unittests/test_utils_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,19 @@ def test_net_sync(net, criterion, sync, nDevices):
ctx_list = [mx.cpu(0) for i in range(nDevices)]
net = DataParallelModel(net, ctx_list, sync=sync)
criterion = DataParallelCriterion(criterion, ctx_list, sync=sync)
iters = 100
iters = 10
bs = 2
# train mode
for i in range(iters):
x = mx.random.uniform(shape=(8, 1, 28, 28))
t = nd.ones(shape=(8))
x = mx.random.uniform(shape=(bs, 1, 28, 28))
t = nd.ones(shape=(bs))
with autograd.record():
y = net(x)
loss = criterion(y, t)
autograd.backward(loss)
# evaluation mode
for i in range(iters):
x = mx.random.uniform(shape=(8, 1, 28, 28))
x = mx.random.uniform(shape=(bs, 1, 28, 28))
y = net(x)
nd.waitall()

Expand Down

0 comments on commit d004183

Please sign in to comment.