diff --git a/tests/python/unittest/test_autograd.py b/tests/python/unittest/test_autograd.py index 3fa4a743cc25..bfa3f04156c7 100644 --- a/tests/python/unittest/test_autograd.py +++ b/tests/python/unittest/test_autograd.py @@ -89,29 +89,42 @@ def autograd_assert(*args, **kwargs): assert same(a.asnumpy(), b.asnumpy()) def test_unary_func(): - x = nd.uniform(shape=(4, 5)) - f_exp = lambda x: nd.exp(x) - f_exp_grad = lambda x: [nd.exp(x)] - autograd_assert(x, func=f_exp, grad_func=f_exp_grad) - f_half = lambda x: x/2 - f_half_grad = lambda x: [nd.ones(x.shape) * 0.5] - autograd_assert(x, func=f_half, grad_func=f_half_grad) - f_square = lambda x: x**2 - f_square_grad = lambda x: [2*x] - autograd_assert(x, func=f_square, grad_func=f_square_grad) + def check_unary_func(x): + f_exp = lambda x: nd.exp(x) + f_exp_grad = lambda x: [nd.exp(x)] + autograd_assert(x, func=f_exp, grad_func=f_exp_grad) + f_half = lambda x: x/2 + f_half_grad = lambda x: [nd.ones(x.shape) * 0.5] + autograd_assert(x, func=f_half, grad_func=f_half_grad) + f_square = lambda x: x**2 + f_square_grad = lambda x: [2*x] + autograd_assert(x, func=f_square, grad_func=f_square_grad) + uniform = nd.uniform(shape=(4, 5)) + stypes = ['row_sparse', 'csr', 'default'] + for stype in stypes: + x = mx.nd.cast_storage(uniform, stype=stype) + check_unary_func(x) def test_binary_func(): - x = nd.uniform(shape=(4, 5)) - y = nd.uniform(shape=(4, 5)) - f_add = lambda x, y: x+y - f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)] - autograd_assert(x, y, func=f_add, grad_func=f_add_grad) - f_mul = lambda x, y: x*y - f_mul_grad = lambda x, y: [y, x] - autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad) - f_compose = lambda x, y: x+x*y - f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x] - autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad) + def check_binary_func(x, y): + f_add = lambda x, y: x+y + f_add_grad = lambda x, y: [nd.ones(x.shape), nd.ones(y.shape)] + autograd_assert(x, y, func=f_add, grad_func=f_add_grad) + f_mul = lambda x, y: x*y + f_mul_grad = lambda x, y: [y, x] + autograd_assert(x, y, func=f_mul, grad_func=f_mul_grad) + f_compose = lambda x, y: x+x*y + f_compose_grad = lambda x, y: [nd.ones(x.shape) + y, x] + autograd_assert(x, y, func=f_compose, grad_func=f_compose_grad) + uniform_x = nd.uniform(shape=(4, 5)) + uniform_y = nd.uniform(shape=(4, 5)) + stypes = ['row_sparse', 'csr', 'default'] + for stype_x in stypes: + for stype_y in stypes: + x = mx.nd.cast_storage(uniform_x, stype=stype_x) + y = mx.nd.cast_storage(uniform_y, stype=stype_y) + check_binary_func(x, y) + def test_operator_with_state(): def f_fc(a, b, weight, bias): @@ -235,14 +248,19 @@ def test_retain_grad(): def test_attach_grad(): - x = mx.nd.zeros((10,)) - assert x.grad is None - x.attach_grad() - with record(): - y = x * 2 - assert y.grad is None - y.backward() - assert (x.grad.asnumpy() == 2).all() + def check_attach_grad(x): + assert x.grad is None + x.attach_grad() + with record(): + y = x * 2 + assert y.grad is None + y.backward() + assert (x.grad.asnumpy() == 2).all() + zeros = mx.nd.zeros((10, 10)) + stypes = ['default', 'row_sparse', 'csr'] + for stype in stypes: + x = mx.nd.cast_storage(zeros, stype=stype) + check_attach_grad(x) if __name__ == "__main__":