Lower aten::_linalg_eigh #7674
Conversation
cc @vanbasten23
if (!compute_v) {
  // Fallback to aten in case of `eigvalsh`, which does not compute
  // eigenvectors but requires numerically stable gradients.
  return at::native::call_fallback_fn<&xla_fallback,
I understand we only need to lower torch.linalg.eigh. But in case we need to lower eigvalsh later, would the need for numerically stable gradients be a blocker that forces us to fall back to aten?
So I read the PyTorch docs again and I think I misunderstood them initially. What the PyTorch doc suggests is that the gradients of the eigenvectors are unstable. Therefore, if the user calls eigvalsh, they only get eigenvalues, and thus the gradients are stable. I removed this misleading comment.
If we want to support eigvalsh later, the simplest way is probably to discard the eigenvectors from XLA and figure out what to return for the second at::Tensor tuple member.
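As a Python-level illustration of that idea (not the C++ lowering in this PR, just a sketch of the semantics): eigvalsh is the eigenvalue half of eigh, so a future lowering could run the same decomposition and simply drop the eigenvector output.

```python
import torch

# Illustration only: eigvalsh returns exactly the eigenvalues that eigh
# computes; the eigenvectors can simply be discarded.
a = torch.randn(4, 4, dtype=torch.float64)
a = (a + a.T) / 2  # make the input symmetric

eigenvalues, eigenvectors = torch.linalg.eigh(a)  # full decomposition
eigenvalues_only = torch.linalg.eigvalsh(a)       # eigenvalues only

assert torch.allclose(eigenvalues, eigenvalues_only)
```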
torch_xla/csrc/ops/eigh.cpp (outdated)
std::array<xla::XlaOp, 2> LowerImpl(xla::XlaOp input, bool lower) {
  auto [eigenvectors, eigenvalues] =
      xla::SelfAdjointEig(input, lower, /* max_iter */ 64, /* tol */ 1e-6);
What are the max_iter and tol values used for here?
Also, could you add a comment explaining why the default values are being changed?
When testing I discovered that the default settings lead to very low accuracy in the reconstructed matrix (e.g. if we decompose A into A' = V @ Q @ V_T, then A and A' differ by more than 0.1 in some elements).
After looking at what JAX does, I think it is simpler to align with JAX: https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/_src/lax/linalg.py#L726. It looks like they use the same tolerance but a higher max_iter.
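For reference, here is a rough Python sketch of the kind of accuracy check described above (an assumed workflow, not the PR's actual test; the ~0.1 figure is the error reported on XLA with the looser default settings):

```python
import torch

# Reconstruct A from its eigendecomposition and measure the element-wise error.
# On CPU this error is tiny; errors above ~0.1 were reported on XLA with the
# looser default max_iter/tol settings.
a = torch.randn(8, 8, dtype=torch.float32)
a = (a + a.T) / 2  # symmetric input

w, v = torch.linalg.eigh(a)       # A = V @ diag(w) @ V^T
a_rec = v @ torch.diag(w) @ v.T   # reconstructed A'

print((a - a_rec).abs().max().item())  # maximum element-wise difference
```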
Fixes #6017

There's an xla::SelfAdjointEig function, so we lower aten::_linalg_eigh to that.

I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, despite passing a small tolerance and a large max_iter. The unit test therefore uses a hardcoded tensor value copied from https://android.googlesource.com/platform/external/tensorflow/+/f2a058296dd/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc#149
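A minimal usage sketch of what this lowering enables (assumes a working torch_xla install; the device setup below is illustrative and not part of the PR):

```python
import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# Build a symmetric matrix on the XLA device and decompose it.
a = torch.randn(4, 4, device=device)
a = (a + a.T) / 2

# With this PR, torch.linalg.eigh lowers to xla::SelfAdjointEig instead of
# falling back to the CPU aten implementation.
eigenvalues, eigenvectors = torch.linalg.eigh(a)
```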