I am not sure if this is an XLA bug or our bug. Consider this example in Axon, which implements gradient accumulation:
```elixir
defnp accumulate_gradients(
        gradients,
        model_state,
        new_state,
        optimizer_state,
        gradient_state,
        gradient_step,
        update_optimizer_fn,
        opts \\ []
      ) do
  opts = keyword!(opts, [:steps])
  steps = opts[:steps]

  # TODO: this explodes the graph
  if Nx.greater_equal(gradient_step, steps - 1) do
    {updates, new_optimizer_state} =
      update_optimizer_fn.(gradients, optimizer_state, model_state)

    new_gradient_state = zeros_like(model_state)
    new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
    {new_model_state, new_optimizer_state, new_gradient_state, 0}
  else
    acc_gradients = deep_merge(gradient_state, gradients, fn x, y -> x + y end)
    {model_state, optimizer_state, acc_gradients, gradient_step + 1}
  end
end
```
Leaving this as-is causes an Axon training loop to OOM with batch size 4 and sequence length 16 (maybe even lower than that), whereas if I remove the conditional logic altogether and just do the update on every step, I can run with 4x longer sequences or batch sizes.
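For reference, a minimal sketch of the conditional-free variant used for comparison (the function name is hypothetical; it assumes the same `update_optimizer_fn` and `Axon.Updates.apply_updates` helpers as above):

```elixir
# Hypothetical conditional-free variant: skips accumulation entirely and
# applies the optimizer update on every step.
defnp update_every_step(gradients, model_state, new_state, optimizer_state, update_optimizer_fn) do
  {updates, new_optimizer_state} =
    update_optimizer_fn.(gradients, optimizer_state, model_state)

  new_model_state = Axon.Updates.apply_updates(model_state, updates, new_state)
  {new_model_state, new_optimizer_state}
end
```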
I looked at the generated expressions, and they are definitely much larger with this `if`, but this might also be specific to my implementation.
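To reproduce the comparison, here is a minimal sketch of how the built expression can be dumped for inspection (assuming a recent Nx where `print_expr/2` is available via `Nx.Defn.Kernel`; the module and function names here are made up for illustration):

```elixir
defmodule ExprDebug do
  import Nx.Defn

  # Builds a tensor `if` similar to the accumulation step above and prints
  # the resulting expression graph. Both branches are traced and embedded
  # in the graph, which is one reason the expression grows with the `if`.
  defn branchy(step, steps, x) do
    result =
      if Nx.greater_equal(step, steps - 1) do
        x * 2
      else
        x + 1
      end

    print_expr(result)
  end
end
```

Calling e.g. `ExprDebug.branchy(Nx.tensor(0), Nx.tensor(4), Nx.iota({4}))` prints the full expression, which makes it easy to compare sizes with and without the conditional.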