Lower aten::_linalg_eigh #7674

Merged
merged 1 commit into from Jul 15, 2024
Conversation

tengyifei (Collaborator) commented on Jul 12, 2024:

Fixes #6017

There is an xla::SelfAdjointEig function, so we lower aten::_linalg_eigh to that.

I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, even with a small tolerance and a large max_iter. The unit test therefore uses a hardcoded tensor value.
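For context, here is a minimal sketch (not part of this PR) of the user-facing call this lowering serves: torch.linalg.eigh on an XLA device, followed by the kind of reconstruction check used to judge accuracy. The torch_xla import and xm.xla_device() call are assumed test-setup boilerplate, not code from this change.

import torch
import torch_xla.core.xla_model as xm  # assumed torch_xla setup, not from this PR

device = xm.xla_device()

# Build a symmetric (self-adjoint) matrix so eigh is well defined.
a = torch.randn(4, 4)
a = a + a.T
a_xla = a.to(device)

# On an XLA device this dispatches to aten::_linalg_eigh, which this PR
# lowers onto xla::SelfAdjointEig.
eigenvalues, eigenvectors = torch.linalg.eigh(a_xla)

# Reconstruct A' = V @ diag(w) @ V^T and compare against the input.
reconstructed = eigenvectors @ torch.diag(eigenvalues) @ eigenvectors.T
print((reconstructed.cpu() - a).abs().max())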

tengyifei (Collaborator, Author) commented:

cc @vanbasten23

if (!compute_v) {
  // Fallback to aten in case of `eigvalsh`, which does not compute
  // eigenvectors but requires numerically stable gradients.
  return at::native::call_fallback_fn<&xla_fallback,
A collaborator commented:

I understand we only need to lower torch.linalg.eigh. But if we need to lower eigvalsh later, would the need for numerically stable gradients be a blocker that forces us to fall back to aten?

tengyifei (Collaborator, Author) replied:

So I read the PyTorch docs again and I think I misunderstood them initially. What the docs suggest is that the gradients of the eigenvectors are unstable. Therefore, if the user calls eigvalsh, they only get eigenvalues, and the gradients will be stable. I removed this misleading comment.

If we want to support eigvalsh later, the simplest way is probably to discard the eigenvectors from XLA and figure out what to return for the second at::Tensor tuple member.
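As a rough illustration of the eigh/eigvalsh distinction discussed above (plain PyTorch on CPU; nothing here is code from this PR): eigvalsh returns only eigenvalues, so gradients flow only through the eigenvalues and stay well behaved, while eigh also returns eigenvectors, whose gradients involve 1/(w_i - w_j) terms and can blow up when eigenvalues are nearly degenerate.

import torch

a = torch.randn(3, 3, dtype=torch.float64, requires_grad=True)
sym = a + a.T  # differentiable symmetric input

# eigvalsh: eigenvalues only; the backward pass only needs eigenvalue
# gradients, which are numerically stable.
w = torch.linalg.eigvalsh(sym)
w.sum().backward()

# eigh: also returns eigenvectors; differentiating through the
# eigenvectors is where the instability warning in the PyTorch docs applies.
w2, v = torch.linalg.eigh(sym)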


std::array<xla::XlaOp, 2> LowerImpl(xla::XlaOp input, bool lower) {
  auto [eigenvectors, eigenvalues] =
      xla::SelfAdjointEig(input, lower, /* max_iter */ 64, /* tol */ 1e-6);
A collaborator commented:

What are max_iter and tol used for? They seem opaque to me.
Also, could you add a comment on why the default values were changed?

tengyifei (Collaborator, Author) replied:

When testing I discovered that the default settings lead to very low accuracy in the reconstructed matrix (e.g. if we decompose A into A' = V @ Q @ V_T, then A and A' differ by more than 0.1 in some elements).

After looking at what JAX does, I think it is simpler to align with JAX: https://github.com/google/jax/blob/a8b425cac50c842f66f36903dfb93fe6ad5a2a5b/jax/_src/lax/linalg.py#L726. It looks like they use the same tolerance but a higher max_iter.
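A sketch of the reconstruction-error measurement described above (illustrative only; the random matrix and the helper below are not the PR's actual unit test, which uses a hardcoded tensor from the XLA self_adjoint_eig test linked in the commit message below):

import torch

def max_reconstruction_error(a: torch.Tensor) -> torch.Tensor:
    """Return max |A - V @ diag(w) @ V^T| for a symmetric matrix A."""
    w, v = torch.linalg.eigh(a)
    reconstructed = v @ torch.diag(w) @ v.T
    return (a - reconstructed).abs().max()

a = torch.randn(8, 8)
a = a + a.T
# On CPU this error is tiny; the observation above was that the XLA
# lowering with default max_iter/tol could leave elementwise errors > 0.1.
print(max_reconstruction_error(a))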

Commit message (f975ad6):

There's an xla::SelfAdjointEig function so we lower it to that.

I discovered that the XLA implementation of eigenvalue decomposition is not as numerically stable as numpy or torch, despite passing a small tolerance and large max_iter. The unit test thus uses a hardcoded tensor value copied from

https://android.googlesource.com/platform/external/tensorflow/+/f2a058296dd/tensorflow/compiler/xla/client/lib/self_adjoint_eig_test.cc#149
@tengyifei merged commit f975ad6 into pytorch:master on Jul 15, 2024. 23 checks passed.
@miladm assigned miladm and tengyifei and unassigned miladm on Jul 16, 2024.
Linked issue: not lowered: aten::_linalg_eigh (#6017)