Switched from reverse mode to forward mode where possible. #61
Conversation
As for the complex conjugate, isn't that expected? The complex derivative here is given by `df/dz*`.
This should fix the problem
optimistix/_search.py (outdated):

```python
conj_residual = jtu.tree_map(jnp.conj, self.residual)
conj_jac = lx.conj(self.jac)
return (
    0.5
    * (
        self.jac.transpose().mv(conj_residual) ** ω
        + conj_jac.transpose().mv(self.residual) ** ω
    )
).ω
```
Suggested change:

```diff
-conj_residual = jtu.tree_map(jnp.conj, self.residual)
-conj_jac = lx.conj(self.jac)
-return (
-    0.5
-    * (
-        self.jac.transpose().mv(conj_residual) ** ω
-        + conj_jac.transpose().mv(self.residual) ** ω
-    )
-).ω
+conj_jac = lx.conj(self.jac)
+return (conj_jac.transpose().mv(self.residual) ** ω).ω
```
Something similar is needed in `grad_dot`. Maybe you want to pre-compute `conj_jac`. Note again that this is `df/dz*`, but that's the one you need for optimization. Alternatively, you can calculate `df/dz` and conjugate later; there is no difference for real-valued functions.
For this one in particular my belief is that we really do have to compute two different terms: the derivative of the square absolute value `|z|^2 = z conj(z)` is, via the chain rule, `dz/dtheta conj(z) + z dconj(z)/dtheta` (for whatever `theta` we are differentiating with respect to). I don't see a way to avoid that?
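For what it's worth, here is a small sketch (illustrative only; `z`, `f` and `theta` are made up, not Optimistix code) checking that two-term chain rule against `jax.grad`:

```python
import jax
import jax.numpy as jnp

# Illustrative: z(theta) is an arbitrary complex-valued function of a real
# parameter theta, and f(theta) = |z(theta)|^2.
def z(theta):
    return theta + 1j * theta**2

def f(theta):
    zt = z(theta)
    return (zt * jnp.conj(zt)).real

theta = 1.5
zt = z(theta)
dz = jax.jacfwd(z)(theta)  # dz/dtheta (well defined since theta is real)

# Chain rule with both terms: dz/dtheta * conj(z) + z * dconj(z)/dtheta,
# where dconj(z)/dtheta = conj(dz/dtheta) because theta is real.
manual = (dz * jnp.conj(zt) + zt * jnp.conj(dz)).real
auto = jax.grad(f)(theta)
print(manual, auto)  # the two values agree
```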
```python
assert tree_allclose(grad2, true_grad2)
assert tree_allclose((grad1.real, grad1.imag), grad2)
assert tree_allclose(grad1, true_grad1)
```
Suggested change:

```diff
-assert tree_allclose(grad1, true_grad1)
+assert tree_allclose(grad1, jnp.conj(true_grad1))
```
It's in fact the conjugate of the gradient, though.
Indeed, gradient descent should take steps in the direction of the conjugate Wirtinger derivative. That's an optimisation thing though, and that's not what we're doing here: we're just computing the gradient. And when it comes to computing the gradient, JAX and PyTorch do different things: AFAIK JAX computes the (unconjugated) Wirtinger derivative, whilst PyTorch computes the conjugate Wirtinger derivative. So what we're seeing here is that we (Optimistix) are sometimes computing the conjugate Wirtinger derivative, and sometimes we're computing the (unconjugated) Wirtinger derivative. We should always compute the latter?
When you calculate the derivative using real and imaginary parts, it is up to you how to combine them. JAX returns the unconjugated combination `df/dx - i*df/dy`, so there is an extra sign here. The same goes for calculating the gradient; of course it depends on how you plan to use it, but you sum two conjugate values, so basically you take the real part.
Right, but I don't think that's the case:

```python
> jax.grad(lambda z: z**2, holomorphic=True)(jax.numpy.array(1+1j))
Array(2.+2.j, dtype=complex64, weak_type=True)
```

Contrast the same computation in PyTorch, which does additionally conjugate:

```python
> import torch
> x = torch.tensor(1+1j, requires_grad=True)
> y = x ** 2
> torch.autograd.backward(y, grad_tensors=(torch.tensor(1+0j),))
> x.grad
tensor(2.-2.j)
```

So we still have a conjugation bug somewhere.
Again, this is convention rather than a bug. PyTorch returns `(df/dz)*`.
I agree it's a convention! And it's a convention we're breaking: I believe we're returning `(df/dz)*` where JAX itself would return `df/dz`. (Do you agree we're returning `(df/dz)*`?)
Return where, sorry? What does your current code return? Maybe let's try to specify where exactly you think the bug is. What is the minimal failing test?
Oh, I think I understand where the confusion comes from. More precisely, PyTorch returns `(df/dz)*`, the conjugate Wirtinger derivative.
@patrick-kidger are you satisfied with this answer?
Sorry, this has been sitting on my backlog for a while!

First of all! The thing which was so confusing to me -- and that I did not realise until now -- is that JAX actually uses a non-Wirtinger derivative for nonholomorphic functions! This is most easily seen with the square absolute value function:

```python
import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real

print(jax.grad(f)(1+1j))
# 2 - 2j
```

In contrast the Wirtinger derivative of `|z|^2` is `conj(z)`, which here would be `1 - 1j`.

To be fair, PyTorch does the same thing, up to the difference in conjugate convention:

```python
import torch

x = torch.tensor(1+1j, requires_grad=True)
y = x.abs().pow(2)
y.backward()
print(x.grad)
# 2 + 2j
```

(On which note, the fact that PyTorch computes the conjugate of what JAX computes is just the convention difference discussed above.)

As such I've now corrected much of this PR, essentially in line with your suggestions above -- along with copious comments explaining what's going on here! Basically it mostly boils down to the fact that given a function `C -> R`, we have to be careful about which conjugation convention is used at each step.

I'd like to thank you for your patience with me, your explanations above were really useful in clearing a lot of this up for me. There are still a couple of details that remain unresolved, and which I'd appreciate your input on.
WDYT?
So, I think these two questions are basically the same, or at least have the same motivation; see e.g. `optimistix/optimistix/_solver/trust_region.py`, lines 264 to 276 at bc52e76.
We have some function `f: C -> R`, and we want to find the linear approximation of `f`. Since `f` is not holomorphic, the approximation `f(z) ≈ f(0) + df/dz(0) * z` is invalid. The correct approximation is indeed the one obtained by representing `f` as a function of real arguments `z = x + iy`; then `f(x, y) ≈ f(0) + df/dx(0) * x + df/dy(0) * y`. So we should conjugate the gradient before doing the dot product, I think (a small sketch of this is included after this comment). This is not a mere convention anymore, since after discarding the imaginary part some information is lost; I believe this is the "useful" value in the context of real-valued functions and specifically linear approximations (as a side note, this is exactly the reason PyTorch calculates `(df/dz)*`). Do you have any other applications of `compute_grad_dot` in mind?
An alternative is to add an argument to `compute_grad_dot`. As for silent bugs, I'd expect the strict typing to fail on such a comparison. If it doesn't, I'd start by feature-requesting it in JAX :) I think it is reasonable to expect that with strict typing, any binary operation between complex and real fails. Could it be that this part is untested? It may be the case that I encountered this bug and didn't add the tests since they failed, and I didn't figure it out.
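To illustrate the linear-approximation point above, here is a minimal sketch (not Optimistix code; `f`, `z0` and `dz` are made up for illustration) checking that the first-order change of a `C -> R` function along a complex step `dz` is `Re(conj(g) * dz)`, where `g = df/dx + i*df/dy`:

```python
import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real  # |z|^2, a C -> R function

z0 = jnp.asarray(1.0 + 2.0j)
dz = jnp.asarray(0.01 - 0.02j)

# Gradient assembled from the real partial derivatives: g = df/dx + i*df/dy.
f_xy = lambda x, y: f(x + 1j * y)
gx = jax.grad(f_xy, argnums=0)(z0.real, z0.imag)
gy = jax.grad(f_xy, argnums=1)(z0.real, z0.imag)
g = gx + 1j * gy

linear_term = (jnp.conj(g) * dz).real  # Re(conj(g) dz) = df/dx*dx + df/dy*dy
finite_diff = f(z0 + dz) - f(z0)       # actual change, first order in dz
print(linear_term, finite_diff)        # approximately equal
```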
I don't think so. But my hope is that this really can do what it says on the tin -- compute a dot product -- and that downstream consumers can then conjugate as appropriate. (In this case to compute a linear approximation.) I think that's probably important for keeping track of what everything does. What I haven't quite worked out is exactly where we need to put conjugates in order for that to be the case.

As for strict typing, doing such a comparison does not in fact raise: from my testing I believe it computes the inequality with respect to the real parts of each array. I wonder if we could request a JAX-level flag to error out when performing comparisons on complex numbers.
If we only ever do the "correct" (conjugated) dot product, then we don't need to expose the plain one. As for strict typing, this bit did indeed give a type error in #71, so it was just a lack of tests.
Honestly I think I like keeping the plain dot product.
For the tests, the cleanest thing would be to add a minus sign to the real version, I think, if we're testing a specific function.
Another problem that I'm currently not sure how to deal with is that the JVP of the `C -> R` function is not (complex) linear, but rather linear in the real and imaginary parts separately. I think that means we need to rework how we handle this.
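To illustrate that point, a small sketch (illustrative only, again using the square absolute value) showing that the JVP of a `C -> R` function is not complex-linear in its tangent:

```python
import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real  # |z|^2, a C -> R function

z0 = jnp.asarray(1.0 + 1.0j)
v = jnp.asarray(1.0 + 0.0j)

_, out_v = jax.jvp(f, (z0,), (v,))
_, out_iv = jax.jvp(f, (z0,), (1j * v,))

# Complex linearity would require out_iv == 1j * out_v, but both output
# tangents are real here, so the JVP is only linear over the reals,
# i.e. in the real and imaginary parts of the tangent separately.
print(out_v, out_iv, 1j * out_v)
```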
This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment). Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!
Force-pushed from bc52e76 to a59ab5d.
Okay! It seems like getting complex numbers working reliably might still be trickier than just the things we're facing here. What I'm going to do for now is to merge this PR to get the forward-mode improvements, and skip the failing complex parts of the tests. I've also added a warning around complex numbers whilst we figure out how best to handle them.
This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment).

Whilst I'm here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the new `test_least_squares.py::test_residual_jac` that I've added actually fails! I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!

Here's some context:

- A least-squares problem is specified by residuals `r(y)`. The goal is then to optimise the minimisation problem `0.5*||r(y)||^2`, where `||.||` denotes the 2-norm. Whilst you could do this via `optx.minimise`, in practice there are specialised least-squares algorithms that exploit having access to `r`. (Indeed computing this is what occurs in `FunctionInfo.ResidualJac.as_min`, in cases where directly working with the residuals `r` is not necessary.)
- For complex residuals, the objective being minimised is `f(y) = 0.5 * r(y)^T conj(r(y))`.
- `FunctionInfo.ResidualJac.compute_grad` computes the derivative `df/dy = 0.5 * (r^T dconj(r)/dy + dr^T/dy conj(r))`. The Jacobian `dr/dy` is available as `FunctionInfo.ResidualJac.jac`.
- The result of `compute_grad` and the result of applying `jax.grad` directly on `FunctionInfo.ResidualJac.as_min` differ by a conjugate! This is the problem :( (A standalone sketch of this check is included after this comment.)
- (There is also `Function.ResidualJac.compute_grad_dot`, which computes the dot-product `df/dy . z` against some vector `z`, i.e. `df/dy^T conj(z)`. I haven't debugged/looked at this at all yet due to the earlier error.)

Tagging @Randl -- do you have a clearer idea of what's going on here?
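For reference, a minimal standalone sketch of the gradient formula above (illustrative only: `r` here is an arbitrary made-up residual function, not the actual `test_residual_jac`, and `y` is kept real so that the conjugation conventions don't come into play; the discrepancy described in this PR arises for complex inputs):

```python
import jax
import jax.numpy as jnp

def r(y):
    # Arbitrary complex-valued residuals of a real parameter y (illustrative).
    return jnp.array([y + 1j * y**2, jnp.sin(y) + 2j * y])

def f(y):
    # f(y) = 0.5 * r(y)^T conj(r(y)), which is real-valued.
    ry = r(y)
    return 0.5 * jnp.sum(ry * jnp.conj(ry)).real

y0 = 0.7
ry = r(y0)
jac = jax.jacfwd(r)(y0)  # dr/dy, a complex vector since y is real

# df/dy = 0.5 * (r^T dconj(r)/dy + dr^T/dy conj(r)), with dconj(r)/dy = conj(dr/dy).
manual = 0.5 * (jnp.sum(ry * jnp.conj(jac)) + jnp.sum(jac * jnp.conj(ry)))
auto = jax.grad(f)(y0)
print(manual.real, auto)  # these agree for a real parameter y
```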