[SPMD] Add apply_backward_optimization_barrier (#6097)
Summary:
This pull request adds a new API to xla_sharding.py called apply_backward_optimization_barrier, which registers a full backward hook that applies an optimization barrier to the given module. The barrier prevents the XLA compiler from fusing the module's backward pass with others, which is useful for avoiding gigantic buffers being allocated to synchronize the gradients.

It's also used in pytorch-tpu/transformers#50.

Test Plan:
python test/spmd/test_xla_sharding.py -v -k test_backward_optimization_barrier
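
For illustration, a minimal usage sketch of the new API outside the test (the Linear model, shapes, and training-step details below are placeholders, not part of this commit):

# Minimal usage sketch (illustrative; the model and input shapes are placeholders).
import torch
import torch_xla.core.xla_model as xm
from torch_xla.distributed.spmd.xla_sharding import apply_backward_optimization_barrier

device = xm.xla_device()
model = torch.nn.Linear(128, 64).to(device)

# Register the full backward hook so this module's backward pass is wrapped
# in an optimization barrier and not fused with other modules' backward HLO.
apply_backward_optimization_barrier(model)

loss = model(torch.randn(2, 128).to(device)).sum()
loss.backward()
xm.mark_step()  # materialize the lazily traced graph, including the opt-barrier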
alanwaketan authored Dec 12, 2023
1 parent eb3c446 commit 07540f2
Showing 2 changed files with 35 additions and 0 deletions.
15 changes: 15 additions & 0 deletions test/spmd/test_xla_sharding.py
@@ -1047,6 +1047,21 @@ def test_from_cpu_shards_global_shape(self):
    with self.assertRaises(RuntimeError):
      from_cpu_shards(shards, op_sharding, torch.Size((1,)))

  def test_backward_optimization_barrier(self):
    model = self.SimpleLinear().to(xm.xla_device())
    # The first layer won't have gradients in the hook. Not sure why.
    xs.xla_sharding.apply_backward_optimization_barrier(model.fc2)

    x = torch.randn(2, 128).to(xm.xla_device())
    y = model(x)
    loss = y.sum()
    loss.backward()

    hlo = torch_xla._XLAC._get_xla_tensors_hlo([model.fc2.weight.grad])
    self.assertIn(
        '%opt-barrier.37 = (f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) opt-barrier((f32[1,64]{0,1}, f32[1]{0}, f32[2,64]{1,0}) %tuple.36)',
        hlo)


if __name__ == '__main__':
  test = unittest.main()
20 changes: 20 additions & 0 deletions torch_xla/distributed/spmd/xla_sharding.py
@@ -640,3 +640,23 @@ def backward(ctx, grad_output):

def xla_patched_nn_linear_forward(m, input):
  return XLAPatchedLinear.apply(input, m.weight, m.bias)


def apply_backward_optimization_barrier(m: torch.nn.Module):
  """
  Register a full backward hook that applies an optimization barrier to the given module.
  This will prevent the XLA compiler from fusing the module's backward pass with others.
  It's useful for preventing gigantic buffers from being allocated to synchronize the gradients.
  """

  def optimization_barrier(module, grad_input, grad_output):
    from torch_xla.utils.checkpoint import CheckpointFunction
    gradients = []
    for param in module.parameters():
      if param.grad is not None:
        gradients.append(param.grad)
    xm.optimization_barrier_(
        CheckpointFunction._extract_tensors_from_list(gradients +
                                                      list(grad_input)))

  m.register_full_backward_hook(optimization_barrier)
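
As a usage note, the barrier can also be applied to several submodules of a larger model so that each one's backward stays in its own fusion region; a hedged sketch of that pattern follows (`model.layers` is an illustrative attribute name, not something this commit defines):

# Hypothetical per-layer application: keep each layer's backward in its own
# fusion region. `model.layers` is an assumed attribute of the user's model.
for layer in model.layers:
  apply_backward_optimization_barrier(layer)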
