Skip to content

Commit

Permalink
Merge pull request #6144 from google:issue6121
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 364018304
  • Loading branch information
jax authors committed Mar 20, 2021
2 parents 4f8814a + e0b3ef0 commit f8c36d9
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 1 deletion.
2 changes: 1 addition & 1 deletion jax/_src/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,7 @@ def _cofactor_solve(a, b):
"a=[..., m, m] and b=[..., m, m]; got a={} and b={}")
raise ValueError(msg.format(a_shape, b_shape))
if a_shape[-1] == 1:
return a[0, 0], b
return a[..., 0, 0], b
# lu contains u in the upper triangular matrix and l in the strict lower
# triangular matrix.
# The diagonal of l is set to ones without loss of generality.
Expand Down
6 changes: 6 additions & 0 deletions tests/linalg_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ def testDetGrad(self, shape, dtype):
a[0] = 0
jtu.check_grads(jnp.linalg.det, (a,), 1, atol=1e-1, rtol=1e-1)

def testDetGradIssue6121(self):
f = lambda x: jnp.linalg.det(x).sum()
x = jnp.ones((16, 1, 1))
jax.grad(f)(x)
jtu.check_grads(f, (x,), 2, atol=1e-1, rtol=1e-1)

def testDetGradOfSingularMatrixCorank1(self):
# Rank 2 matrix with nonzero gradient
a = jnp.array([[ 50, -30, 45],
Expand Down

0 comments on commit f8c36d9

Please sign in to comment.