Register ModelOutput subclasses as supported torch.utils._pytree nodes (#25358)

* Register ModelOutput subclasses as supported torch.utils._pytree nodes

Fixes #25357, where DDP with static_graph=True does not sync gradients when calling backward() over tensors contained in ModelOutput subclasses (a sketch of the pytree behavior behind this fix appears below the changed-files summary)

* Add test for torch pytree ModelOutput serialization and deserialization
ringohoffman authored Aug 8, 2023
1 parent a23ac36 commit d4bd33c
Showing 2 changed files with 38 additions and 0 deletions.
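The fix hinges on how torch.utils._pytree flattens containers. Below is a minimal sketch (not part of the commit) of the behavior in question, under the assumption that DDP's output traversal with static_graph=True relies on torch.utils._pytree.tree_flatten: a container that is not registered as a pytree node is treated as a single opaque leaf, so the tensors inside it are never discovered. OpaqueOutput is a hypothetical stand-in for an unregistered output class.

import torch
import torch.utils._pytree as pytree

# A dict is a registered pytree node, so its tensor values show up as leaves.
leaves, _ = pytree.tree_flatten({"logits": torch.ones(2, 3)})
print(len(leaves), leaves[0].shape)  # 1 torch.Size([2, 3]) -- the tensor is visible

# An unregistered container is itself the only "leaf": code that walks the
# flattened output never reaches the tensor it holds.
class OpaqueOutput:
    def __init__(self, logits):
        self.logits = logits

leaves, _ = pytree.tree_flatten(OpaqueOutput(torch.ones(2, 3)))
print(len(leaves), type(leaves[0]).__name__)  # 1 OpaqueOutput
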
15 changes: 15 additions & 0 deletions src/transformers/utils/generic.py
@@ -248,6 +248,21 @@ class ModelOutput(OrderedDict):
    </Tip>
    """

    def __init_subclass__(cls) -> None:
        """Register subclasses as pytree nodes.

        This is necessary to synchronize gradients when using `torch.nn.parallel.DistributedDataParallel` with
        `static_graph=True` with modules that output `ModelOutput` subclasses.
        """
        if is_torch_available():
            import torch.utils._pytree

            torch.utils._pytree._register_pytree_node(
                cls,
                torch.utils._pytree._dict_flatten,
                lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
            )

    def __post_init__(self):
        class_fields = fields(self)

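As a follow-up illustration (not part of the diff), once a ModelOutput subclass is registered as a pytree node, generic utilities such as torch.utils._pytree.tree_map can operate on its tensor leaves and rebuild the same subclass via the unflatten lambda above. BaseModelOutput is used here only as an arbitrary example subclass, and a standard transformers install is assumed.

import torch
import torch.utils._pytree as pytree
from transformers.modeling_outputs import BaseModelOutput

out = BaseModelOutput(last_hidden_state=torch.ones(1, 2))
# tree_map maps over the tensor leaves and reconstructs the same ModelOutput
# subclass through the registered unflatten function.
doubled = pytree.tree_map(lambda t: t * 2, out)
print(type(doubled).__name__)     # BaseModelOutput
print(doubled.last_hidden_state)  # tensor([[2., 2.]])
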
23 changes: 23 additions & 0 deletions tests/utils/test_model_output.py
@@ -17,6 +17,7 @@
from dataclasses import dataclass
from typing import Optional

from transformers.testing_utils import require_torch
from transformers.utils import ModelOutput


@@ -120,3 +121,25 @@ def test_instantiate_from_iterator(self):
        x = ModelOutputTest(a=(30, 30))
        self.assertEqual(list(x.keys()), ["a"])
        self.assertEqual(x.a, (30, 30))

    @require_torch
    def test_torch_pytree(self):
        # ensure torch.utils._pytree treats ModelOutput subclasses as nodes (and not leaves)
        # this is important for DistributedDataParallel gradient synchronization with static_graph=True
        import torch
        import torch.utils._pytree

        x = ModelOutputTest(a=1.0, c=2.0)
        self.assertFalse(torch.utils._pytree._is_leaf(x))

        expected_flat_outs = [1.0, 2.0]
        expected_tree_spec = torch.utils._pytree.TreeSpec(
            ModelOutputTest, ["a", "c"], [torch.utils._pytree.LeafSpec(), torch.utils._pytree.LeafSpec()]
        )

        actual_flat_outs, actual_tree_spec = torch.utils._pytree.tree_flatten(x)
        self.assertEqual(expected_flat_outs, actual_flat_outs)
        self.assertEqual(expected_tree_spec, actual_tree_spec)

        unflattened_x = torch.utils._pytree.tree_unflatten(actual_flat_outs, actual_tree_spec)
        self.assertEqual(x, unflattened_x)
