
How to implement an immediately updated partial gradient #24265

Answered by mattjj
hrbigelow asked this question in General

Thanks for the question!

> $\Delta w \equiv \dfrac{dL}{dw}$ and $\Delta a \equiv \dfrac{dL}{da}$ for each layer. $\Delta a$ is evaluated at the current setting for $w$, but what if I would instead like to evaluate it at the updated setting for $w$?

Is the function $L$ linear in $w$? If not, you'll need to re-linearize the function at the updated point. That is, if you had something like loss_fn(w1, w2, x), you'd end up writing something like:

w2_grad = jax.grad(loss_fn, 1)(w1, w2, x)
w2 = w2 - 1e-3 * w2_grad  # or whatever update
w1_grad = jax.grad(loss_fn, 0)(w1, w2, x)  # re-linearized at the updated w2
w1 = w1 - 1e-3 * w1_grad
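
To connect this back to the $\Delta a$ in the question: if the activations are exposed as an explicit argument, the same re-linearization pattern applies. A minimal sketch, where layer_loss and the shapes are made up for illustration and not from the discussion:

import jax
import jax.numpy as jnp

def layer_loss(w, a):  # hypothetical toy loss, for illustration only
    return jnp.sum(jnp.tanh(a @ w) ** 2)

w = jnp.ones((4, 4))
a = jnp.ones((2, 4))

w_grad = jax.grad(layer_loss, 0)(w, a)      # Δw, at the current w
w_new = w - 1e-3 * w_grad                   # or whatever update
a_grad = jax.grad(layer_loss, 1)(w_new, a)  # Δa, evaluated at the updated w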

There is likely some redundant work being done here. To avoid that, you can either put the computation under a jax.jit…
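
A minimal sketch of the jax.jit route, with a made-up two-layer loss_fn standing in for the real one; when both grad calls are traced into a single jitted function, the compiler may be able to share the forward work they would otherwise repeat:

import jax
import jax.numpy as jnp

def loss_fn(w1, w2, x):  # hypothetical toy loss, for illustration only
    return jnp.sum((jnp.tanh(x @ w1) @ w2) ** 2)

@jax.jit
def sequential_update(w1, w2, x):
    w2_grad = jax.grad(loss_fn, 1)(w1, w2, x)
    w2 = w2 - 1e-3 * w2_grad
    w1_grad = jax.grad(loss_fn, 0)(w1, w2, x)  # re-linearized at the updated w2
    w1 = w1 - 1e-3 * w1_grad
    return w1, w2

w1, w2, x = jnp.ones((3, 3)), jnp.ones((3, 3)), jnp.ones((2, 3))
w1, w2 = sequential_update(w1, w2, x)

Because the whole update is one traced computation, XLA can often deduplicate subexpressions shared between the two backward passes, though how much is actually shared is up to the compiler.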
