value_and_grad(kernel_fn) not equal to kernel_fn with standard parameterization #123
Comments
Thanks for the repro and good find! It's indeed a bug in our custom differentiation rule for the square root: we clip the derivative around zero, but accidentally clipped the outputs as well. I've sent a change to fix it, but it needs code review, so it will likely land tomorrow. In the meantime, this is what the change looks like (`neural_tangents/stax.py`, line 4343 at 94e7498):

```diff
  def _sqrt_jvp(tol, primals, tangents):
    x, = primals
    x_dot, = tangents
    safe_tol = max(tol, 1e-30)
    square_root = _sqrt(x, safe_tol)
+   square_root_out = _sqrt(x, tol)
-   return square_root, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
+   return square_root_out, np.where(x > safe_tol, x_dot / (2 * square_root), 0.)
```
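For readers unfamiliar with custom differentiation rules, here is a minimal, self-contained sketch of the same pattern using `jax.custom_jvp`. It is illustrative only (the names `safe_sqrt` and `_TOL` are made up, not the library's code), but it shows the idea of clipping only the derivative near zero while leaving the primal output untouched:

```python
import jax
import jax.numpy as np

_TOL = 1e-30  # hypothetical tolerance, for illustration only

@jax.custom_jvp
def safe_sqrt(x):
  # Primal output: the plain square root, with no clipping.
  return np.sqrt(x)

@safe_sqrt.defjvp
def _safe_sqrt_jvp(primals, tangents):
  x, = primals
  x_dot, = tangents
  # Clip only inside the derivative so d/dx sqrt(x) stays finite at x == 0;
  # the primal returned to the caller is left unclipped.
  square_root = np.sqrt(np.maximum(x, _TOL))
  return np.sqrt(x), np.where(x > _TOL, x_dot / (2 * square_root), 0.)

print(jax.value_and_grad(safe_sqrt)(0.))  # (0.0, 0.0): value matches np.sqrt, gradient is finite
```

With this split, `value_and_grad(safe_sqrt)(0.)` returns the same value as `safe_sqrt(0.)` while the derivative at zero is clipped to a finite value, which is the property the fix restores.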
Wow, thanks for quickly determining the issue!
Hmm, after pulling 8b7917f, I see that the values match now, but the gradient still contains NaNs.
Thanks, I'll need to look into this. In the meantime, I suspect it's only happening for zero-valued inputs and generally shouldn't be a problem otherwise (but perhaps I'm wrong, so it's worth double-checking whether there are still NaNs or discrepancies for normal inputs like images).
A much smaller example reproducing the NaNs:

```python
import jax
from jax import *
from neural_tangents.stax import *

def f(x):
  return serial(Conv(1, (3, 3)), Relu(), Flatten())[2](x, x, "ntk")[0][0]

print(grad(f)(jax.numpy.zeros((1, 32, 32, 3))))
```

I guess this no longer has anything to do with standard parameterization.
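For completeness, a generic way to check programmatically whether such a gradient contains NaNs (this reuses `f` and the imports from the snippet above; `jnp.isnan` is standard JAX, nothing here is specific to neural-tangents):

```python
import jax.numpy as jnp

g = grad(f)(jax.numpy.zeros((1, 32, 32, 3)))
print(jnp.isnan(g).any())  # True while the bug is present; expected False once fixed
```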
I think it's probably the same issue, likely related to differentiating kernel functions at zero-valued inputs.
7d01d65: I made a change to default to having no bias variable instead of `b_std=0`, but in standard parameterization having no bias variable (`b_std=None`) is different from having a zero-variance bias variable (`b_std=0`). This is also related to #123. Also update tests to catch this error (scale down `W_std`, which in standard parameterization dwarfed the bias contribution), add absolute tolerance to testing and logging, and make `stax` tests deterministic. PiperOrigin-RevId: 421074614
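To illustrate the distinction (a hypothetical comparison, not code from the commit): in standard parameterization a zero-variance bias is still a trainable parameter, so it still contributes to the NTK through its gradient, whereas `b_std=None` removes the parameter entirely.

```python
import jax.numpy as jnp
from neural_tangents import stax

x = jnp.ones((2, 3))

# No bias parameter at all.
ntk_no_bias = stax.Dense(1, parameterization='standard', b_std=None)[2](x, x, 'ntk')

# A bias parameter initialized with zero variance: it is still trained, so in
# standard parameterization it adds a constant term to the NTK.
ntk_zero_bias = stax.Dense(1, parameterization='standard', b_std=0.)[2](x, x, 'ntk')

print(ntk_no_bias)
print(ntk_zero_bias)  # expected to differ from the line above
```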
…ble nonlinearities at 0 (#123). Currently we have NaNs and/or values at 0 that are inconsistent with the limit of finite-width empirical kernels.

First, note that `np.sign(0) == 0`, therefore we must have `T_sign(0, x2) = T_sign(x1, 0) = T_sign(0, 0) = 0`, but we currently have it equal to 1, i.e. we incorrectly assume in our infinite-width limit that `np.sign(0) == 1`.

Secondly, JAX defines the gradient of ABRelu at 0 to be `(a + b) / 2`, i.e. the mean subgradient. This means that we must have `Tdot_abrelu(0, x2) = Tdot_abrelu(x1, 0) = Tdot_abrelu(0, 0) = [(a + b) / 2]^2`, but we currently have it equal to `(a^2 + b^2) / 2`, which is equivalent to assuming that the gradient is `[(a^2 + b^2) / 2]^0.5`, i.e. for Relu a gradient at 0 of `1 / 2^0.5` instead of `1/2`.

We fix the above issues by extending `np.arctan2(0, 0) := np.pi / 2` (mathematically the function is undefined there, and by default JAX/NumPy have it be 0, but `np.pi / 2` gives us the correct values above). Finally, we also extend the gradient of `np.arctan2` at `(0, 0)` to `(0, 0)`. The gradient at 0 is by default undefined, and earlier we had NaN gradients at zero inputs to nonlinearities. While the gradient can't be extended continuously at `(0, 0)`, setting it to `(0, 0)` at least makes it continuous along `x = 0` or `y = 0`, and helps fix a lot of NaNs.

Also add more tests, including comparisons of gradients of infinite-width kernels with MC estimates. Make matrix comparison tests fail on NaNs or infinities. PiperOrigin-RevId: 427073332
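To make the conventions the commit refers to concrete, here is a quick check of the relevant JAX defaults (illustrative only; this shows the values the fix works around, not the fix itself):

```python
import jax
import jax.numpy as jnp

print(jnp.sign(0.))                                # 0.0: sign at zero is zero
print(jax.grad(lambda x: jnp.maximum(x, 0.))(0.))  # 0.5: ties in max get the mean subgradient
print(jnp.arctan2(0., 0.))                         # 0.0: default value at the undefined point
print(jax.grad(jnp.arctan2)(0., 0.))               # nan: the gradient is undefined at (0, 0)
```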
Thank you for your patience here! I think you were right that there were actually two bugs. One was the wrong treatment of biases with standard parameterization (`b_std=None` vs. `b_std=0`, fixed in the first commit above). The other was that the derivative at 0 of several nonlinearities was NaN or inconsistent with the finite-width limit (fixed in the second commit above). Hope this helps!
Thanks so much for the thorough fix! All of the gradient-related anomalies I've been seeing have gone away. I'll open new issues if I run into more problems in the future.
I am confused by the behavior of the following snippet of code (the WideResNet from the README with standard parameterization):
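A minimal stand-in for the kind of snippet described (this is a hypothetical sketch using a small convolutional model rather than the actual WideResNet code from the README; it only illustrates the comparison between `kernel_fn` and `value_and_grad(kernel_fn)` under `parameterization='standard'`):

```python
import jax
from neural_tangents import stax

# Small stand-in architecture with standard parameterization.
_, _, kernel_fn = stax.serial(
    stax.Conv(64, (3, 3), padding='SAME', parameterization='standard'),
    stax.Relu(),
    stax.Flatten(),
    stax.Dense(1, parameterization='standard'))

x = jax.random.normal(jax.random.PRNGKey(0), (1, 32, 32, 3))

def f(x):
  # Scalar NTK entry for a single input.
  return kernel_fn(x, x, 'ntk')[0, 0]

print(f(x))                         # plain kernel evaluation
print(jax.value_and_grad(f)(x)[0])  # should print the same value
```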
My understanding is that the two printed values should be the same. However, when I run it, I get two totally different values.
Is my understanding correct? I have not yet found a simpler network that features this behavior.
Versions:
- jax 0.2.20
- jaxlib 0.1.71+cuda111
- neural-tangents 0.3.7