From d00418308f1d67ba2b32d1b6b48f305017f6f2c2 Mon Sep 17 00:00:00 2001 From: "Joshua Z. Zhang" Date: Thu, 21 May 2020 16:32:17 -0700 Subject: [PATCH] Pin mxnet version in response to mx.metric reorg (#1310) * mx.metric * pin version * fix lint * reduce parallel test iter/bs --- gluoncv/__init__.py | 2 +- gluoncv/utils/version.py | 23 +++++++++++++---------- tests/unittests/test_utils_parallel.py | 9 +++++---- 3 files changed, 19 insertions(+), 15 deletions(-) diff --git a/gluoncv/__init__.py b/gluoncv/__init__.py index 556dfcbdac..fb8d8f1875 100644 --- a/gluoncv/__init__.py +++ b/gluoncv/__init__.py @@ -7,7 +7,7 @@ from .utils.version import _require_mxnet_version, _deprecate_python2 _deprecate_python2() -_require_mxnet_version('1.4.0') +_require_mxnet_version('1.4.0', '2.0.0') from . import data from . import model_zoo diff --git a/gluoncv/utils/version.py b/gluoncv/utils/version.py index 2103200c6c..7ab3a6b12b 100644 --- a/gluoncv/utils/version.py +++ b/gluoncv/utils/version.py @@ -27,24 +27,27 @@ def check_version(min_version, warning_only=False): raise AssertionError(msg) -def _require_mxnet_version(mx_version): +def _require_mxnet_version(mx_version, max_mx_version='2.0.0'): try: import mxnet as mx from distutils.version import LooseVersion - if LooseVersion(mx.__version__) < LooseVersion(mx_version): + if LooseVersion(mx.__version__) < LooseVersion(mx_version) or \ + LooseVersion(mx.__version__) >= LooseVersion(max_mx_version): + version_str = '>={},<{}'.format(mx_version, max_mx_version) msg = ( - "Legacy mxnet-mkl=={} detected, some new modules may not work properly. " - "mxnet-mkl>={} is required. You can use pip to upgrade mxnet " - "`pip install -U --pre mxnet -f https://dist.mxnet.io/python/mkl` " - "or `pip install -U --pre mxnet -f https://dist.mxnet.io/python/cu100mkl`\ - ").format(mx.__version__, mx_version) - raise ImportError(msg) + "Legacy mxnet-mkl=={0} detected, some modules may not work properly. " + "mxnet-mkl{1} is required. You can use pip to upgrade mxnet " + "`pip install -U 'mxnet-mkl{1}'` " + "or `pip install -U 'mxnet-cu100mkl{1}'`\ + ").format(mx.__version__, version_str) + raise RuntimeError(msg) except ImportError: raise ImportError( "Unable to import dependency mxnet. " "A quick tip is to install via " - "`pip install --pre mxnet -f https://dist.mxnet.io/python/cu100mkl`. " - "please refer to https://gluon-cv.mxnet.io/#installation for details.") + "`pip install 'mxnet-cu100mkl<{}'`. " + "please refer to https://gluon-cv.mxnet.io/#installation for details.".format( + max_mx_version)) def _deprecate_python2(): if sys.version_info[0] < 3: diff --git a/tests/unittests/test_utils_parallel.py b/tests/unittests/test_utils_parallel.py index f6fb201629..25d7818b18 100644 --- a/tests/unittests/test_utils_parallel.py +++ b/tests/unittests/test_utils_parallel.py @@ -26,18 +26,19 @@ def test_net_sync(net, criterion, sync, nDevices): ctx_list = [mx.cpu(0) for i in range(nDevices)] net = DataParallelModel(net, ctx_list, sync=sync) criterion = DataParallelCriterion(criterion, ctx_list, sync=sync) - iters = 100 + iters = 10 + bs = 2 # train mode for i in range(iters): - x = mx.random.uniform(shape=(8, 1, 28, 28)) - t = nd.ones(shape=(8)) + x = mx.random.uniform(shape=(bs, 1, 28, 28)) + t = nd.ones(shape=(bs)) with autograd.record(): y = net(x) loss = criterion(y, t) autograd.backward(loss) # evaluation mode for i in range(iters): - x = mx.random.uniform(shape=(8, 1, 28, 28)) + x = mx.random.uniform(shape=(bs, 1, 28, 28)) y = net(x) nd.waitall()