diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 5e60989489f6..39fcd81642d3 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -44,6 +44,18 @@ def test_module_dtype(): assert x.dtype == dtype +def test_module_bind(): + sym = mx.sym.Variable('data') + sym = mx.sym.Activation(data=sym, act_type='relu', __layout__='TNC') + + mod = mx.mod.Module(sym, ('data',), None, context=[mx.cpu(0), mx.cpu(1)]) + assertRaises(TypeError, mod.bind, data_shapes=[('data', mx.nd.array([10,10]))]) + assert mod.binded == False + + mod.bind(data_shapes=[('data', (10,10))]) + assert mod.binded == True + + @with_seed() def test_module_input_grads(): a = mx.sym.Variable('a', __layout__='NC')