From 811e777eac4dd5659a52a06a796d5af276b7c37e Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Tue, 14 Aug 2018 13:08:36 -0700 Subject: [PATCH 1/4] add test to check binded is not when exception thrown --- tests/python/unittest/test_module.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index a21527a5a4ad..e5aaad0faa18 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -44,6 +44,18 @@ def test_module_dtype(): assert x.dtype == dtype +def test_module_bind(): + sym = mx.sym.Variable('data') + sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC') + + mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)]) + assertRaises(mod.bind, data_shapes=[('data', mx.nd.array([10,10]))]) + mod.binded = False + + mod.bind(data_shapes=[('data', (10,10))]) + mod.binded = True + + @with_seed() def test_module_input_grads(): a = mx.sym.Variable('a', __layout__='NC') From f090785a0911755ba7d049d832bbae91bf12b454 Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Tue, 14 Aug 2018 13:11:13 -0700 Subject: [PATCH 2/4] add error to assertion --- tests/python/unittest/test_module.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index e5aaad0faa18..cc699f6a7d99 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -49,7 +49,7 @@ def test_module_bind(): sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC') mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)]) - assertRaises(mod.bind, data_shapes=[('data', mx.nd.array([10,10]))]) + assertRaises(TypeError, mod.bind, data_shapes=[('data', mx.nd.array([10,10]))]) mod.binded = False mod.bind(data_shapes=[('data', (10,10))]) From ead7e8fdf098927e0c55e1a41171e30af92e724a Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Wed, 15 Aug 2018 10:29:56 -0700 Subject: [PATCH 3/4] fix typo by adding asserts --- tests/python/unittest/test_module.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index cc699f6a7d99..b6212d8747f9 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -50,10 +50,10 @@ def test_module_bind(): mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)]) assertRaises(TypeError, mod.bind, data_shapes=[('data', mx.nd.array([10,10]))]) - mod.binded = False + assert mod.binded == False mod.bind(data_shapes=[('data', (10,10))]) - mod.binded = True + assert mod.binded == True @with_seed() From cea8e14d4abb862cd144f0ea19a8f8e41cf83f1e Mon Sep 17 00:00:00 2001 From: Alexander Zai Date: Thu, 16 Aug 2018 16:54:41 -0700 Subject: [PATCH 4/4] retrigger