From e0b3ef0f65d0272edcbd461965c2354625303cdc Mon Sep 17 00:00:00 2001 From: Roy Frostig Date: Fri, 19 Mar 2021 19:13:45 -0700 Subject: [PATCH] fix broadcasted 1x1 cofactor solve, called by linalg.det jvp --- jax/_src/numpy/linalg.py | 2 +- tests/linalg_test.py | 6 ++++++ 2 files changed, 7 insertions(+), 1 deletion(-) diff --git a/jax/_src/numpy/linalg.py b/jax/_src/numpy/linalg.py index e4ec6be3c40b..7df4825ff7a6 100644 --- a/jax/_src/numpy/linalg.py +++ b/jax/_src/numpy/linalg.py @@ -202,7 +202,7 @@ def _cofactor_solve(a, b): "a=[..., m, m] and b=[..., m, m]; got a={} and b={}") raise ValueError(msg.format(a_shape, b_shape)) if a_shape[-1] == 1: - return a[0, 0], b + return a[..., 0, 0], b # lu contains u in the upper triangular matrix and l in the strict lower # triangular matrix. # The diagonal of l is set to ones without loss of generality. diff --git a/tests/linalg_test.py b/tests/linalg_test.py index bf016c92f2e0..eda71122827c 100644 --- a/tests/linalg_test.py +++ b/tests/linalg_test.py @@ -119,6 +119,12 @@ def testDetGrad(self, shape, dtype): a[0] = 0 jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1) + def testDetGradIssue6121(self): + f = lambda x: jnp.linalg.det(x).sum() + x = jnp.ones((16, 1, 1)) + jax.grad(f)(x) + jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1) + def testDetGradOfSingularMatrixCorank1(self): # Rank 2 matrix with nonzero gradient a = jnp.array([[ 50, -30, 45],