[fix] Fix activation checkpointing of SwiGLU when AMP is enabled. (#1152)

Without this fix, the number of tensors saved during recomputation is zero.
The at::AutoDispatchBelowADInplaceOrView guard is moved after ctx->get_saved_variables(), since ctx->get_saved_variables() is the call that triggers recomputation of the forward pass.
warpuv authored Nov 14, 2024
1 parent 210e32a commit a561291
Showing 1 changed file with 1 addition and 1 deletion.

xformers/csrc/swiglu/swiglu_packedw.cpp
@@ -111,11 +111,11 @@ class SwiGLUPackedWeights
   static torch::autograd::variable_list backward(
       torch::autograd::AutogradContext* ctx,
       torch::autograd::variable_list grad_outputs) {
+    auto saved = ctx->get_saved_variables();
     at::AutoDispatchBelowADInplaceOrView g;

     // Unpack variables
     auto dx5 = grad_outputs[0];
-    auto saved = ctx->get_saved_variables();
     auto x = saved[0];
     auto w1w2 = saved[1];
     auto w3 = saved[2];
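The ordering bug behind this one-line move can be sketched without libtorch. The following is a minimal, hypothetical Python model (the names CheckpointedContext, get_saved_variables, and the _below_autograd flag are illustrative stand-ins, not the real xformers or PyTorch API): under activation checkpointing, accessing the saved variables lazily re-runs the forward pass, and if the dispatcher has already been forced below autograd when that recomputation happens, no tensors are recovered.

```python
# Illustrative stand-in for at::AutoDispatchBelowADInplaceOrView being active.
_below_autograd = False

class CheckpointedContext:
    """Models a ctx that saved no tensors at forward time and must
    recompute them lazily when its saved variables are accessed."""
    def __init__(self, recompute_fn):
        self._recompute = recompute_fn

    def get_saved_variables(self):
        # Recomputation re-runs the forward pass. If the dispatcher is
        # already below autograd, the recomputation saves nothing.
        if _below_autograd:
            return []              # the bug: zero tensors recovered
        return self._recompute()   # correct: forward re-runs normally

def forward():
    return ["x", "w1w2", "w3"]

def backward_buggy(ctx):
    """Guard entered before reading saved variables (pre-fix order)."""
    global _below_autograd
    _below_autograd = True             # guard entered too early
    saved = ctx.get_saved_variables()  # recomputation sees the guard
    _below_autograd = False
    return saved

def backward_fixed(ctx):
    """Saved variables read first, then the guard (post-fix order)."""
    global _below_autograd
    saved = ctx.get_saved_variables()  # recompute first...
    _below_autograd = True             # ...then drop below autograd for
    _below_autograd = False            # the rest of the backward pass
    return saved
```

Running backward_buggy on a checkpointed context yields an empty list (the "zero tensors saved" symptom from the commit message), while backward_fixed recovers all three saved variables.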
