From 76374fd8009065f351c8abd88dbed699a2d78271 Mon Sep 17 00:00:00 2001 From: Joe Evans Date: Mon, 27 Jul 2020 20:53:44 -0700 Subject: [PATCH] Add unit tests for potri and potrf backward and check output shape in unit tests. --- tests/nightly/test_large_array.py | 40 ++++++++++++++++++------------- 1 file changed, 24 insertions(+), 16 deletions(-) diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 306c827bab9f..21d5cd28d041 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1170,26 +1170,34 @@ def check_correctness(mxnet_op, numpy_op, atol=1e-3): def test_linalg(): def check_potrf(): - # creating an identity matrix input - A = nd.zeros((LARGE_SQ_X, LARGE_SQ_X)) - for i in range(LARGE_SQ_X): - A[i,i] = 1 + def run_potrf(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.potrf(inp) + return inp.grad, out - out = nd.linalg.potrf(A) - # output should be an identity matrix - for i in range(LARGE_SQ_X): - assert out[i,i] == 1 + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_potrf(A) + assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0] == 1) + out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == 0.5) def check_potri(): - # creating an identity matrix input - A = nd.zeros((LARGE_SQ_X, LARGE_SQ_X)) - for i in range(LARGE_SQ_X): - A[i,i] = 1 + def run_potri(inp): + inp.attach_grad() + with mx.autograd.record(): + out = mx.nd.linalg.potri(inp) + return inp.grad, out - out = nd.linalg.potri(A) - # output should be an identity matrix - for i in range(LARGE_SQ_X): - assert out[i,i] == 1 + A = get_identity_mat(LARGE_SQ_X) + grad, out = run_potri(A) + assert(out.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(out[0, 0] == 1) + out.backward() + assert(grad.shape == (LARGE_SQ_X, LARGE_SQ_X)) + assert(grad[0, 0] == -2) def check_syrk_batch(): # test both forward and backward