Hi, typically in Jax you'd have something like:

```python
def loss_fn(params, data):
    pass

value_and_grad_fn = jax.value_and_grad(loss_fn)
loss, grad_m = value_and_grad_fn(params, data)   # per-device gradients
grad = jax.lax.pmean(grad_m, axis_name='batch')  # all-reduce across devices
updates, opt_state = tx.update(grad, opt_state)
params = optax.apply_updates(params, updates)
```

Internally, the backward produces … And, in particular, while evaluating this on, say, the 8 cores of a TPU, I'd want to first do a … What would be the easiest way to do this in Jax?
---
Thanks for the question!
Is the function $L$ linear in $w$? If not, then you'll need to re-linearize the function. That is, if you had something like `loss_fn(w1, w2, x)`, you'd end up writing something like

```python
w2_grad = jax.grad(L, 1)(w1, w2, x)
w2 = w2 - 1e-3 * w2_grad  # or whatever update
w1_grad = jax.grad(L, 0)(w1, w2, x)  # gradient w.r.t. w1 at the *updated* w2
w1 = w1 - 1e-3 * w1_grad
```

There is likely some redundant work being done here. To avoid that, you can either put the computation under a `ja…` … Say your loss function is structured like

```python
def loss_fn(w1, w2, x):
    x2 = f1(w1, x)
    x3 = f2(w2, x2)
    return g(x3)
```

Then you might write something like

```python
def loss_fn_update(w1, w2, x):
    x2, f1_vjp_0 = jax.vjp(lambda w1: f1(w1, x), w1)
    x3, f2_vjp_0 = jax.vjp(lambda w2: f2(w2, x2), w2)
    x3_grad = jax.grad(g)(x3)
    w2_grad, = f2_vjp_0(x3_grad)
    w2 = w2 - 1e-3 * w2_grad  # or whatever update
    x2_grad = jax.grad(lambda x2: g(f2(w2, x2)))(x2)  # re-linearize with the updated w2
    w1_grad, = f1_vjp_0(x2_grad)
    w1 = w1 - 1e-3 * w1_grad
    x_grad = jax.grad(lambda x: g(f2(w2, f1(w1, x))))(x)  # uses both updated weights
    return w1, w2, x_grad
```

That can be abstracted a bit, but hopefully it gets the point across.
I didn't think about the collective-reduction (pmean) part, but I think it should be orthogonal to the autodiff question. What do you think?
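To make "that can be abstracted a bit" concrete, here is one hedged sketch (not from the reply) of the same idea as a loop over layers, with the cross-device `pmean` from the original question folded in. `immediate_update_step`, `layers`, and `apply_fn` are illustrative names; each layer's weights are assumed to be a single array; and, like the two-layer example, it re-runs the already-updated suffix of the network to re-linearize, so there is redundant work:

```python
import jax

def immediate_update_step(layers, g, x, lr=1e-3, axis_name=None):
    """layers: list of (weights, apply_fn) pairs; g: scalar loss head on the last activation."""
    # Forward pass, saving each activation and each layer's VJP w.r.t. its weights.
    acts, w_vjps = [x], []
    for w, apply_fn in layers:
        y, vjp_w = jax.vjp(lambda w_, a=acts[-1], f=apply_fn: f(w_, a), w)
        acts.append(y)
        w_vjps.append(vjp_w)

    new_layers = list(layers)
    out_grad = jax.grad(g)(acts[-1])
    for i in reversed(range(len(layers))):
        w, apply_fn = new_layers[i]
        w_grad, = w_vjps[i](out_grad)
        if axis_name is not None:
            w_grad = jax.lax.pmean(w_grad, axis_name=axis_name)  # all-reduce this layer's grad now
        new_layers[i] = (w - lr * w_grad, apply_fn)               # update this layer now
        if i > 0:
            # Re-linearize through the already-updated suffix to get the gradient
            # flowing into layer i-1 (this is the redundant recomputation).
            def suffix_loss(a, start=i):
                for w_, f_ in new_layers[start:]:
                    a = f_(w_, a)
                return g(a)
            out_grad = jax.grad(suffix_loss)(acts[i - 1])
    return new_layers
```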
---
Thanks @mattjj. Yes, that makes sense. I'm not that familiar with Jax internals, but what I was hoping for is to implement, at some lower-level Jax API, a "custom" version of `jax.value_and_grad` and/or `jax.vjp` which expresses these two bits of logic: (1) use the updated weights to compute the other gradients, and (2) perform some collective reduction operation to get the updates to the weights. So then I could just call (super hand-wavy):

```python
value_fn, imm_grad_update_fn = jax.value_and_immediate_grad(loss_fn, param_args=(0,))
params = imm_grad_update_fn(params, data, collective=jax.lax.pmean)
```

This hypothetical function transformation needs to know which of the function's inputs are to be treated as weights, and would also need some way to designate the collective reduction to apply. Mainly, the point of this is to see whether this new way of computing "gradients" would speed up training. Theoretically, it could be implemented with an identical amount of computation; the only difference is that the weight updates, and the all-reduces, are performed at each step of backprop. But it would then be important for it to be implemented efficiently, otherwise it would defeat the purpose. Is this something that could be done so that I could then just plug any model in, the way you can plug any function into `jax.value_and_grad`? If you could point me to the appropriate API level to begin contemplating this, that would be much appreciated! Or, any pointers as to why it wouldn't work, or what a better option would be. But mainly I would like to avoid any requirement for model-specific code.
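For what it's worth, a hedged sketch of how such a transformation might slot into a `pmap`-style training step; `value_and_immediate_grad`, `param_args`, and `collective` are the hypothetical names proposed above, not a real Jax API:

```python
import functools

import jax

@functools.partial(jax.pmap, axis_name='batch')
def train_step(params, data):
    # Hypothetical: the transformation would run the backward pass itself,
    # applying the collective and the weight update layer by layer as it goes.
    _, imm_grad_update_fn = value_and_immediate_grad(loss_fn, param_args=(0,))
    new_params = imm_grad_update_fn(params, data, collective=jax.lax.pmean)
    return new_params
```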
---
Okay, upon studying your code in … For context, I'm trying to make this idea work for my transformer model. It's complicated enough that I'd much rather keep the custom code at least tightly coupled with the "primal" functions. Would this idea work? Abuse the `jax.custom_vjp` mechanism as follows:

```python
@jax.custom_vjp
def f(data, weights):
    ...
    return out

def f_fwd(data, weights):
    res = (weights, ...)  # will need weights for f_bwd
    return f(data, weights), res

def f_bwd(res, out_grads):
    weights, *_ = res
    in_grad_weights_m = ...  # per-device weight gradient
    in_grad_weights = jax.lax.pmean(in_grad_weights_m, axis_name='batch')
    updated_weights = weights - learn_rate * in_grad_weights
    in_grad_data = ...  # some function of the updated weights
    return in_grad_data, updated_weights  # I know, weird right?

f.defvjp(f_fwd, f_bwd)
```

So I'm returning the updated weights from `f_bwd`, rather than the weight gradients. But maybe it doesn't matter, because the weights are technically leaf nodes of the computation graph of the backward pass. It's abusing the notion of `bwd`; I'd rather it be auxiliary data, but `f_bwd` cannot return auxiliary data.
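As a concrete, hedged instantiation of the pattern above (not from the thread), here is a single `tanh` dense layer where the cotangent slot for `weights` carries the updated weights; it assumes array-valued `weights`, a `learn_rate` constant, and execution inside a `pmap` with `axis_name='batch'`:

```python
import jax
import jax.numpy as jnp

learn_rate = 1e-3  # assumed

@jax.custom_vjp
def dense(data, weights):
    return jnp.tanh(data @ weights)

def dense_fwd(data, weights):
    out = jnp.tanh(data @ weights)
    return out, (data, weights, out)

def dense_bwd(res, out_grad):
    data, weights, out = res
    pre_grad = out_grad * (1.0 - out ** 2)              # backprop through tanh
    w_grad_m = data.T @ pre_grad                        # per-device weight gradient
    w_grad = jax.lax.pmean(w_grad_m, axis_name='batch') # all-reduce it now
    updated_weights = weights - learn_rate * w_grad     # apply the update now
    data_grad = pre_grad @ updated_weights.T            # data gradient via the *updated* weights
    return data_grad, updated_weights                   # weights slot abused to carry the update

dense.defvjp(dense_fwd, dense_bwd)
```

With this, `jax.grad` of a scalar loss built on `dense(x, w)`, taken with respect to `w` inside the `pmap`, returns the updated weights rather than a gradient, which is exactly the "weird" part noted above (and it only behaves that way if `w` is used in a single `dense` call).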