Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switched from reverse mode to forward mode where possible. #61

Merged
merged 1 commit into from
Aug 17, 2024

Conversation

patrick-kidger
Copy link
Owner

This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment).

Whilst I"m here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the new test_least_squares.py::test_residual_jac that I've added actually fails! I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!

Here's some context:

  • a least-squares problem returns a "residual vector" r(y). The goal is then to optimise the minimisation problem 0.5*||r(y)||^2 where ||.|| denotes the 2-norm. Whilst you could do this via optx.minimise, in practice there are specialised least-squares algorithms that exploit having access to r. (Indeed computing this is what occurs in FunctionInfo.ResidualJac.as_min, in cases where directly working with the residuals r is not necessary.)
  • In the complex case this minimisation problem can be written f(y) = 0.5 * r(y)^T conj(r(y)).
  • The new method FunctionInfo.ResidualJac.compute_grad computes the derivative df/dy = 0.5 * (r^T dconj(r)/dy + dr^T/dy conj(r)). The jacobian dr/dy is available as FunctionInfo.ResidualJac.jac
  • The test provided implements the same problem using both complex and real (real+imag held separate) numbers. Weirdly, the two reference implementations themselves differ -- not just the version tested in Optimistix! To be precise, the values returned by calling jax.grad directly on FunctionInfo.ResidualJac.as_min differ by a conjugate! This is the problem :(
  • (There is also a method Function.ResidualJac.compute_grad_dot that computes the dot-product df/dy . z against some vector z, i.e. df/dy^T conj(z). I haven't debugged/looked at this at all yet due to the earlier error.)

Tagging @Randl -- do you have a clearer idea of what's going on here?

@Randl
Copy link
Contributor

Randl commented May 26, 2024

As for complex conjugate, isn't that expected? The complex derivative is given by d/dz=d/dx-i*d/dy
See, e.g., https://pytorch.org/docs/stable/notes/autograd.html#complex-autograd-doc or https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html#complex-numbers-and-differentiation

We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).

Copy link
Contributor

@Randl Randl left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should fix the problem

Comment on lines 167 to 175
conj_residual = jtu.tree_map(jnp.conj, self.residual)
conj_jac = lx.conj(self.jac)
return (
0.5
* (
self.jac.transpose().mv(conj_residual) ** ω
+ conj_jac.transpose().mv(self.residual) ** ω
)
).ω
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
conj_residual = jtu.tree_map(jnp.conj, self.residual)
conj_jac = lx.conj(self.jac)
return (
0.5
* (
self.jac.transpose().mv(conj_residual) ** ω
+ conj_jac.transpose().mv(self.residual) ** ω
)
).ω
conj_jac = lx.conj(self.jac)
return (conj_jac.transpose().mv(self.residual) ** ω).ω

Something similar is needed in grad_dot. Maybe you want to pre-compute conj_jac. Note again that this is df/dz*, but that's the one you need for optimization. Alternatively, you can calculate df/dz and conjugate later, there is no difference for real-valued functions

Copy link
Owner Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For this one in particular my belief is that we really do have to compute two different terms: the derivative of the square absolute value |z|^2 = z conj(z) is via the chain rule dz/dtheta conj(z) + z dconj(z)/dtheta (for whatever theta we are differentiating with respect to). I don't see a way to avoid that?


assert tree_allclose(grad2, true_grad2)
assert tree_allclose((grad1.real, grad1.imag), grad2)
assert tree_allclose(grad1, true_grad1)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
assert tree_allclose(grad1, true_grad1)
assert tree_allclose(grad1, jnp.conj(true_grad1))

It's in fact conjugate gradient though.

@patrick-kidger
Copy link
Owner Author

We can use grad to optimize functions, like real-valued loss functions of complex parameters x, by taking steps in the direction of the conjugate of grad(f)(x).

Indeed, gradient descent should be performed by performing steps in the direction of the conjugate Wirtinger derivative. That's an optimisation thing though, and that's not what we're doing here: we're just computing the gradient.

And when it comes to computing the gradient, then JAX and PyTorch do different things: AFAIK JAX computes the (unconjugated) Wirtinger derivative, whilst PyTorch computes the conjugate Wirtinger derivative.

So what we're seeing here is that we (Optimistix) are sometimes computing the conjugate Wirtinger derivative, and sometimes we're computing the (unconjugated) Wirtinger derivative. We should always compute the latter?

@Randl
Copy link
Contributor

Randl commented May 29, 2024

When you calculate the derivative using real and imaginary parts it is up to you how to combine them. Jax returns d/dz=d/dx-i*d/dy and you have a tuple of d/dx, d/dy:

assert tree_allclose((grad1.real, grad1.imag), grad2)

So there is an extra sign here.

Same in calculating gradient, of course it depends on how you plan to use it but you sum two conjugate values so basically take the real part.

@patrick-kidger
Copy link
Owner Author

When you calculate the derivative using real and imaginary parts it is up to you how to combine them. Jax returns d/dz=d/dx-i*d/dy

Right, but I don't think that's the case:

> jax.grad(lambda z: z**2, holomorphic=True)(jax.numpy.array(1+1j))
Array(2.+2.j, dtype=complex64, weak_type=True)

contrast the same computation in PyTorch, which does additionally conjugate:

> import torch
> x = torch.tensor(1+1j, requires_grad=True)
> y = x ** 2
> torch.autograd.backward(y, grad_tensors=(torch.tensor(1+0j),))
> x.grad
tensor(2.-2.j)

So we still have a conjugation bug somewhere.

@Randl
Copy link
Contributor

Randl commented Jun 15, 2024

Again, this is convention rather than bug. Pytorch returns df/dz* while Jax returns df/dz.

@patrick-kidger
Copy link
Owner Author

patrick-kidger commented Jun 15, 2024

I agree it's a convention! And it's a convention we're breaking. I believe we're returning df/dz*, despite being in JAX, and that therefore we should be returning a df/dz.

(Do you agree we're returning df/dz*? Or do you think we're returning df/dz?)

@Randl
Copy link
Contributor

Randl commented Jun 15, 2024

Return where, sorry? Your current compute_grad function computes the real part of the gradient, the real-valued version computes the conjugate gradient (more precisely real and imaginary parts of it), and Jax computes the regular gradient, as far as I can tell. If you want to follow the Jax conventions, the compute_grad should return a regular gradient, and the tests against the real version should have an extra minus sign.

Maybe let's try to specify where exactly you think the bug is. What is the minimal failing test?

@Randl
Copy link
Contributor

Randl commented Jun 16, 2024

Oh, I think I understand where the confusion comes from. More precisely, PyTorch returns (df/dz)*, which equals df/dz* only for C->R functions. If your z^2 function is part of a larger computation with real output, this is ok (but not in general).
In your example, f is C->C. In general, we can't represent gradient as a single number in this case, only if f is holomorphic (i.e., depend only on z and not z*) or anti-holomorphic. Your f is holomorphic, and indeed d/dz z^2 = 2z (what Jax returns), while d/dz* z^2 = 0. What PyTorch returns is (d/dz z^2)* = 2z*, which matches the convention.

@Randl
Copy link
Contributor

Randl commented Aug 3, 2024

@patrick-kidger are you satisfied with this answer?

@patrick-kidger
Copy link
Owner Author

patrick-kidger commented Aug 6, 2024

Sorry, this has been sitting on my backlog for a while!

First of all! The thing which was so confusing to me -- and that I did not realise until now -- is that JAX actually uses a non-Wirtinger derivative for nonholomorphic functions! This is most easily seen with the square absolute value function:

import jax
import jax.numpy as jnp

def f(z):
    return (z * jnp.conj(z)).real

print(jax.grad(f)(1+1j))
# 2 - 2j

In contrast the Wirtinger derivative of zz* is just z*, not 2z* as we have here.

To be fair, PyTorch does the same thing, up to the difference in conjugate convention:

import torch

x = torch.tensor(1+1j, requires_grad=True)
y = x.abs().pow(2)
y.backward()
print(x.grad)
# 2 + 2j

(On which note, the fact that PyTorch computes (df/dz)* and not df/d(z*) -- at least on holomorphic functions -- was not a detail I appreciated until now. In retrospect this should have been obvious as df/d(z*) = 0 for all holomorphic functions. Thank you for clearing that up for me!)

As such I've now corrected much of this PR, essentially in line with your suggestions above -- along with copious comments explaining what's going on here! Basically it mostly boils down to the fact that given a function f = u+iv: C -> C, that jax.grad computes du/dx - i du/dy, which agrees with the Wirtinger derivative df/dz = 0.5 (df/dx - df/dy) on holomorphic functions (by the Cauchy-Riemann equations), but is notably not the same on nonholomorphic functions. And it completely ignores the imaginary component. (I think you already know this, but it's helpful for me to write out for my own reference.)

I'd like to thank you for your patience with me, your explanations above were really useful in clearing a lot of this up for me.


There are still a couple of details that remain unresolved, and which I'd appreciate your input on.

  1. The gradient tests are now passing, hurrrah! However the final tests for the dot products are failing. Now the real part of a complex dot product is equal to the real dot product:
    (a + bi)* . (c + di) = ac + bd + i(ad - bc)
    (a, b) . (c, d) = ac + bd
    
    but we actually have an additional conjugate on top of this! JAX combines (du/dx, du/dy) into du/dx - i du/dy. This means that our dot product computations seem to be going awry. I'm not sure what change is appropriate to ensure that both (a) the dot products considered in isolation are doing the right thing, and (b) the downstream use cases of those dot products also do the right thing.
  2. Making the changes here has made me realise that I think we may still have some silent bugs. The one I can see in particular is here:
    predicted_reduction = f_info.compute_grad_dot(y_diff)
    # Terminate when the Armijo condition is satisfied. That is, `fn(y_eval)`
    # must do better than its linear approximation:
    # `fn(y_eval) < fn(y) + grad•y_diff`
    f_min = f_info.as_min()
    f_min_eval = f_eval_info.as_min()
    f_min_diff = f_min_eval - f_min # This number is probably negative
    satisfies_armijo = f_min_diff <= self.slope * predicted_reduction

    where predicted_reduction is actually a complex number -- but it should probably be a real number, especially given how we use it in a comparison later. I'm not confident that any of these interactions are doing the right thing.
    My feeling is that we would need to go through and make sure that many quantites we would like to be real are in fact real.

WDYT?

@Randl
Copy link
Contributor

Randl commented Aug 6, 2024

So, I think these two questions are basically the same or at least have the same motivation, e.g.,

The expected decrease in loss from moving from `y0` to `y0 + y_diff`.
"""
if isinstance(
f_info,
(
FunctionInfo.EvalGrad,
FunctionInfo.EvalGradHessian,
FunctionInfo.EvalGradHessianInv,
FunctionInfo.ResidualJac,
),
):
return f_info.compute_grad_dot(y_diff)

We have some function f: C->R, and we want to find the linear approximation of f. Since f is not holomorphic, the approximation f(z) = f(0) + df/dz(0) * z is invalid. The correct approximation is indeed one acquired from representing f as a function of real arguments z=x+iy; then f(x,y) = f(0) + df/dx(0) * x + df/dy(0) * y. So, we should conjugate the gradient before doing the dot product, I think. This is not a mere convention anymore since after discarding imaginary part, some information is lost; I believe this is the "useful" value in the context of real-valued functions and specifically linear approximations (as a side note, this is exactly the reason pytorch calculates (df/dz)*). Do you have any other applications of compute_grad_dot in mind?

An alternative is to add an argument to compute_grad_dot to decide whether grad is conjugated, keeping our options open for the future. It still makes sense to conjugate by default, I think, and then it doesn't really matter if we add the parameter now or later.

As for silent bugs, I'd expect the strict typing to fail on such a comparison. If it doesn't, I'd start from feature requesting it in Jax :) I think it is reasonable to expect that with strict typing, any binary operation of complex and real fails.

Could it be that this part is untested? It may be the case that I've encountered this bug and didn't add the tests since they failed, and I didn't figure it out.

@patrick-kidger
Copy link
Owner Author

Do you have any other applications of compute_grad_dot in mind?

I don't think so. But my hope is that this really can do what it states on the tin -- compute a dot product -- and that downstream consumers can then conjugate as appropriate. (In this case to compute a linear approximation.) I think that's probably important for keeping track of what everything does.

What I haven't quite worked out is exactly where we need to put conjugates in order for that to be the case.


As for strict typing, doing complex_array < complex_array will still pass under strict typing. Despite being a fairly questionable kind of thing to compute!

From my testing I believe it computes the inequality wrt the real parts of each array. I wonder if we could request a JAX level flag to error out if performing comparisons on complex numbers.

@Randl
Copy link
Contributor

Randl commented Aug 11, 2024

If we do only "correct" grad_dot, we'll need to do something like grad_dot(vec.conj()) downstream for minimization which is a bit awkward I think. We can also just implement grad_dot and conj_grad_dot.

As for strict typing, this bit indeed gave type error in #71 so it was just lack of tests

@patrick-kidger
Copy link
Owner Author

Honestly I think I like grad_dot(vec.conj()) as a solution, actually. (Well, tree_map(jnp.conj, vec).)
I think making the tests pass with the 'correct' grad_dot is something I'm still not super clear on. Are we just going to have to insert a manual conj there too?

@Randl
Copy link
Contributor

Randl commented Aug 11, 2024

For tests, the cleanest thing would be to add minus to real version I think if we're testing specific function

@Randl
Copy link
Contributor

Randl commented Aug 12, 2024

Another problem that I'm currently not sure how to deal with is that the JVP of the C->R function is not (complex) linear, but rather linear in the real and imaginary parts separately. I think that means we need to rework ResidualJac to work properly in complex case. Specifically, lineax is not expecting such a function I believe in construction of the operator.

This commit switches some functions that unnecessarily use reverse-mode autodiff to using forward-mode autodiff. In particular this is to fix #51 (comment).

Whilst I"m here, I noticed what looks like some incorrect handling of complex numbers. I've tried fixing those up, but at least as of this commit the test I've added fails. I've poked at this a bit but not yet been able to resolve this. It seems something is still awry!
@patrick-kidger
Copy link
Owner Author

Okay! It seems like getting complex numbers working reliably might still be trickier than just the things we're facing here.

What I'm going to do for now is to merge this PR to get the forward-mode improvements, and skip the failing complex parts of the tests for now. I've also added a warning around complex numbers whilst we figure out how best to handle them.

@patrick-kidger patrick-kidger merged commit 4fc86e0 into main Aug 17, 2024
2 checks passed
@patrick-kidger patrick-kidger deleted the forward-complex branch August 17, 2024 08:42
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants