
Commit 61b7449

isururanawaka authored and facebook-github-bot committed
Resharding API Performance Improvement (#3323)
Summary:
Pull Request resolved: #3323

**Resharding API Implementation** {F1981530544}

- Takes the difference between the current plan and the next plan as an argument
- Moves shards to match the new plan based on the diff plan
- Supports TW and CW sharding
- Uses async P2P (batch_isend_irecv) communication overlapped with data preparation to maximize latency hiding
- Uses hierarchical communication (inter-node and intra-node communication groups for transfers)

Reviewed By: aporialiao

Differential Revision: D80647762

fbshipit-source-id: 5d00f406387e5bbc57862ca3623829ba8051df0c
1 parent c6918e6 commit 61b7449
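For context on the communication pattern the summary describes, below is a minimal, hypothetical sketch (not the TorchRec implementation) of overlapping async P2P shard transfers with preparation of the next shard's send/receive operations, using `torch.distributed.batch_isend_irecv`. The helper `prepare_p2p_ops` and the `send_plan`/`recv_plan` structures are illustrative assumptions, and an already-initialized process group is assumed. In the commit itself this pipelining lives in `prepare_comm_ops`/`transfer_data` in `torchrec.distributed.sharding.dynamic_sharding`, which additionally splits transfers across inter-node and intra-node process groups.

```python
# Illustrative sketch only: overlap the transfer of shard i with the
# preparation of P2P ops for shard i + 1, as described in the summary.
from itertools import zip_longest
from typing import Dict, List, Tuple

import torch
import torch.distributed as dist


def prepare_p2p_ops(
    shard_name: str,
    send_plan: Dict[str, Tuple[int, torch.Tensor]],  # shard -> (dst rank, tensor)
    recv_plan: Dict[str, Tuple[int, torch.Tensor]],  # shard -> (src rank, buffer)
) -> List[dist.P2POp]:
    """Build isend/irecv ops for one shard (hypothetical data layout)."""
    ops: List[dist.P2POp] = []
    if shard_name in send_plan:
        dst, tensor = send_plan[shard_name]
        ops.append(dist.P2POp(dist.isend, tensor, dst))
    if shard_name in recv_plan:
        src, buffer = recv_plan[shard_name]
        ops.append(dist.P2POp(dist.irecv, buffer, src))
    return ops


def transfer_shards(
    shard_names: List[str],
    send_plan: Dict[str, Tuple[int, torch.Tensor]],
    recv_plan: Dict[str, Tuple[int, torch.Tensor]],
) -> None:
    """Move all shards, overlapping communication with preparation."""
    if not shard_names:
        return
    ops = prepare_p2p_ops(shard_names[0], send_plan, recv_plan)
    for shard_name, next_shard in zip_longest(shard_names, shard_names[1:]):
        # Kick off async sends/receives for the current shard.
        reqs = dist.batch_isend_irecv(ops) if ops else []
        # While they are in flight, prepare ops for the next shard.
        ops = (
            prepare_p2p_ops(next_shard, send_plan, recv_plan)
            if next_shard is not None
            else []
        )
        # Wait on the current shard before its buffers are reused downstream.
        for req in reqs:
            req.wait()
```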

File tree

4 files changed: +654 -647 lines changed

torchrec/distributed/benchmark/benchmark_resharding_handler.py

Lines changed: 15 additions & 17 deletions
```diff
@@ -76,6 +76,13 @@ def __init__(self, train_module: nn.Module, num_plans: int) -> None:
             rng = random.Random(index)
             ranks_per_tables_for_CW.append(rng.choice(valid_candidates))
 
+        lightweight_ebc = EmbeddingBagCollection(
+            tables=module._embedding_bag_configs,
+            device=torch.device(
+                "meta"
+            ),  # Use meta device to avoid actual memory allocation
+        )
+        meta_device = torch.device("meta")
         for i in range(num_plans):
             new_ranks = generate_rank_placements(
                 world_size, num_tables, ranks_per_tables, i
@@ -98,23 +105,14 @@ def __init__(self, train_module: nn.Module, num_plans: int) -> None:
                     )
                     new_per_param_sharding[talbe_id] = tw_gen
 
-                lightweight_ebc = EmbeddingBagCollection(
-                    tables=module._embedding_bag_configs,
-                    device=torch.device(
-                        "meta"
-                    ),  # Use meta device to avoid actual memory allocation
-                )
-
-                meta_device = torch.device("meta")
-                new_plan = construct_module_sharding_plan(
-                    lightweight_ebc,
-                    per_param_sharding=new_per_param_sharding,  # Pyre-ignore
-                    local_size=world_size,
-                    world_size=world_size,
-                    # Pyre-ignore
-                    device_type=meta_device,
-                )
-                self._resharding_plans.append(new_plan)
+                new_plan = construct_module_sharding_plan(
+                    lightweight_ebc,
+                    per_param_sharding=new_per_param_sharding,
+                    world_size=world_size,
+                    # Pyre-ignore
+                    device_type=meta_device,
+                )
+                self._resharding_plans.append(new_plan)
             else:
                 raise RuntimeError(f"Plan does not have key: {key}")
 
```
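The change above hoists the meta-device `EmbeddingBagCollection` out of the per-plan loop so that every candidate plan reuses one lightweight module. As a standalone illustration (not the benchmark code itself), a candidate sharding plan can be built against a meta-device EBC roughly as below; the table configs, world size, and rank placements are made-up values, and the sketch assumes the public `construct_module_sharding_plan`/`table_wise` helpers from `torchrec.distributed.sharding_plan`.

```python
# Standalone sketch with made-up table configs: build a sharding plan against a
# "meta"-device EmbeddingBagCollection so no real embedding memory is allocated.
import torch
from torchrec.distributed.sharding_plan import construct_module_sharding_plan, table_wise
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

tables = [
    EmbeddingBagConfig(
        name=f"table_{i}",
        embedding_dim=128,
        num_embeddings=10_000,
        feature_names=[f"feature_{i}"],
    )
    for i in range(2)
]

# Meta-device tensors carry only shape/dtype metadata, so this module is cheap
# to build once and reuse for every candidate plan.
lightweight_ebc = EmbeddingBagCollection(tables=tables, device=torch.device("meta"))

candidate_plan = construct_module_sharding_plan(
    lightweight_ebc,
    per_param_sharding={"table_0": table_wise(rank=0), "table_1": table_wise(rank=1)},
    world_size=2,
    device_type="cuda",
)
```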
torchrec/distributed/embeddingbag.py

Lines changed: 105 additions & 85 deletions
```diff
@@ -11,6 +11,7 @@
 from collections import defaultdict, OrderedDict
 from dataclasses import dataclass, field
 from functools import partial
+from itertools import zip_longest
 from typing import (
     Any,
     cast,
@@ -56,12 +57,12 @@
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
-    get_largest_dims_from_sharding_plan_updates,
-    move_sharded_tensors_to_cpu,
-    shards_all_to_all,
+    CommP2PMetadata,
+    CommStrategy,
+    prepare_comm_ops,
+    transfer_data,
     update_module_sharding_plan,
-    update_optimizer_state_post_resharding,
-    update_state_post_resharding,
+    update_state_dictionaries,
 )
 from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
 from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
@@ -1378,24 +1379,10 @@ def _init_mean_pooling_callback(
             device=self._device,
         )
 
-    def _purge_lookups(self) -> None:
-        # Purge old lookups
-        for lookup in self._lookups:
-            # Call purge method if available (for TBE modules)
-            if hasattr(lookup, "purge") and callable(lookup.purge):
-                # Pyre-ignore
-                lookup.purge()
-
-            # For DDP modules, get the underlying module
-            while isinstance(lookup, DistributedDataParallel):
-                lookup = lookup.module
-                if hasattr(lookup, "purge") and callable(lookup.purge):
-                    lookup.purge()
-
-        # Clear the lookups list
+    def _softcopy_lookups(self) -> List[nn.Module]:
+        old_modules: List[nn.Module] = [lookup for lookup in self._lookups]
         self._lookups.clear()
-        # Force garbage collection to free memory
-        torch.cuda.empty_cache()
+        return old_modules
 
     def _create_lookups(
         self,
@@ -1727,77 +1714,58 @@ def update_shards(
         device: Optional[torch.device],
     ) -> None:
         """
-        This is the main API used in sharder.reshard, currently only support redistribution
-        of existing shards (across different ranks, ideally from hot ranks to cold ranks)
-        Update shards for this module based on the changed_sharding_params. This will:
-        1. Move current lookup tensors to CPU
-        2. Purge lookups
-        3. Call shards_all_2_all containing collective to redistribute tensors
-        4. Update state_dict and other attributes to reflect new placements and shards
-        5. Create new lookups, and load in updated state_dict
+        Updates the sharded embedding module in place based on the changed_sharding_params,
+        which contains the new ParameterSharding with different shard placements.
+
+        This method handles resharding of embedding tables, optimizer state transfer,
+        and updates the internal lookup and distribution modules to reflect the new sharding.
 
         Args:
             changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
                 table names to their new parameter sharding configs. This should only
                 contain shards/table names that need to be moved.
-            env (ShardingEnv): The sharding environment for the module.
+            env (ShardingEnv): The sharding environment.
            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            None
+        Raises:
+            RuntimeError: If DTensor output is enabled, as resharding is not yet supported for DTensor.
         """
         if env.output_dtensor:
             raise RuntimeError("We do not yet support DTensor for resharding yet")
             return
 
         current_state = self.state_dict()
-        current_state = move_sharded_tensors_to_cpu(current_state)
-        # TODO: improve, checking one would be enough
+
+        # Check if local optimizer state exists and is non-empty for all optimizers.
         has_local_optimizer = len(self._optim._optims) > 0 and all(
             len(i) > 0 for i in self._optim.state_dict()["state"].values()
         )
 
-        # communicate optimizer state across all ranks, because if one rank owns all tables
-        # and other ranks does not own any table, and later transfer the weights to empty rank
-        # creates inconsistent state, because initally empty rank does not have optimizer state
-        # hence, incorrectly computes the tensor splits
-
+        # Communicate optimizer state across all ranks to ensure consistency.
         has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)
 
-        # TODO: make sure this is clearing all lookups
-        self._purge_lookups()
+        # Save old lookup modules for cleanup.
+        old_lookups: List[nn.Module] = self._softcopy_lookups()
 
-        # Get max dim size to enable padding for all_to_all
-        max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
-            changed_sharding_params
-        )
+        # Save old optimizer state if present.
         old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
-        if old_optimizer_state is not None:
-            move_sharded_tensors_to_cpu(old_optimizer_state)
 
-        local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all(
-            module=self,
-            state_dict=current_state,
-            device=device,  # pyre-ignore
-            changed_sharding_params=changed_sharding_params,
-            env=env,
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
-            max_dim_1=max_dim_1,
-            optimizer_state=old_optimizer_state,
-            has_optimizer=has_optimizer,
-        )
+        assert hasattr(self, "module_sharding_plan")
+        current_module_sharding_plan = copy.deepcopy(self.module_sharding_plan)
 
-        for name, param in changed_sharding_params.items():
-            self.module_sharding_plan[name] = param
-            # TODO: Support detecting old sharding type when sharding type is changing
-            for sharding_info in self.sharding_type_to_sharding_infos[
-                param.sharding_type
-            ]:
-                if sharding_info.embedding_config.name == name:
-                    sharding_info.param_sharding = param
+        # Update the module sharding plan with the changed sharding parameters.
+        update_module_sharding_plan(
+            self, changed_sharding_params, self.sharding_type_to_sharding_infos
+        )
 
         self._sharding_types: List[str] = list(
             self.sharding_type_to_sharding_infos.keys()
         )
         # TODO: Optimize to update only the changed embedding shardings
+
+        # Recreate embedding sharding modules based on the new sharding infos.
         self._embedding_shardings: List[
             EmbeddingSharding[
                 EmbeddingShardingContext,
@@ -1816,7 +1784,7 @@ def update_shards(
             for embedding_configs in self.sharding_type_to_sharding_infos.values()
         ]
 
-        # Reset input dists
+        # Reset input distribution and feature ordering.
         self._has_uninitialized_input_dist = True
         self._input_dists: List[nn.Module] = []
         self._features_order: List[int] = []
@@ -1825,15 +1793,16 @@ def update_shards(
         self._create_lookups()
         self._update_output_dist()
 
+        # Re-initialize torch state if in a distributed environment.
         if env.process_group and dist.get_backend(env.process_group) != "fake":
             self._initialize_torch_state(skip_registering=True)
 
-        # update optimizer
+        # Update optimizer to reflect new parameters.
         optims = []
         for lookup in self._lookups:
             for _, tbe_module in lookup.named_modules():
                 if isinstance(tbe_module, FusedOptimizerModule):
-                    # modify param keys to match EmbeddingBagCollection
+                    # Modify param keys to match EmbeddingBagCollection
                     params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
                     for (
                         param_key,
@@ -1845,31 +1814,82 @@ def update_shards(
                     optims.append(("", tbe_module.fused_optimizer))
 
         self._optim: CombinedOptimizer = CombinedOptimizer(optims)
+        new_state = self.state_dict()
 
-        if has_optimizer:
-            optimizer_state = update_optimizer_state_post_resharding(
-                old_opt_state=old_optimizer_state,  # pyre-ignore
-                new_opt_state=self._optim.state_dict(),
-                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_output_tensor_cpu,
-                max_dim_0=max_dim_0,
+        optimizer_state: Dict[str, Dict[str, Dict[str, Any]]] = self._optim.state_dict()
+
+        # Prepare and execute communication operations for state transfer.
+        shard_keys = list(changed_sharding_params.keys())
+        comms_op: Dict[CommStrategy, List[CommP2PMetadata]] = {}
+        reqs: List[Tuple[dist.Work, CommP2PMetadata]] = []
+        # Pipeline for communication and computation overlapping
+        # move shards of current table while loading next table shards for communiucation
+        for i, (shard_name, nxt_shard_name) in enumerate(
+            zip_longest(shard_keys, shard_keys[1:])
+        ):
+            if i == 0:
+                # Prepare communication P2P operations
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+
+            if comms_op:
+                # call underlying batch_isend_irecv primitives
+                reqs = transfer_data(comms_op=comms_op)
+
+            if nxt_shard_name:
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=nxt_shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+            else:
+                break
+            # Update state and optimizer states
+            update_state_dictionaries(
+                reqs=reqs,
+                old_optimizer_state=old_optimizer_state,
+                new_optimizer_state=optimizer_state,
+                old_state=current_state,
+                new_state=new_state,
+                changed_sharding_params=changed_sharding_params,
                 extend_shard_name=self.extend_shard_name,
             )
-            self._optim.load_state_dict(optimizer_state)
 
-        new_state = self.state_dict()
-        current_state = update_state_post_resharding(
+        update_state_dictionaries(
+            reqs=reqs,
+            old_optimizer_state=old_optimizer_state,
+            new_optimizer_state=optimizer_state,
            old_state=current_state,
            new_state=new_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor_cpu,
+            changed_sharding_params=changed_sharding_params,
            extend_shard_name=self.extend_shard_name,
-            has_optimizer=has_optimizer,
+            update_local=True,
        )
 
-        self.load_state_dict(current_state)
-
-        update_module_sharding_plan(self, changed_sharding_params)
+        # Clean up old lookup modules.
+        for lookup in old_lookups:
+            del lookup
+        old_lookups.clear()
+        self.load_state_dict(new_state, assign=True)
+        if has_optimizer:
+            self._optim.load_state_dict(optimizer_state)
         return
 
     def create_rocksdb_hard_link_snapshot(self) -> None:
```
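A note on the pipelining idiom in `update_shards` above: pairing each shard with its successor via `zip_longest(shard_keys, shard_keys[1:])` yields `None` as the successor on the last iteration, which stops the loop from preparing further communication ops; the final state-dict update then happens once after the loop, with `update_local=True`. A tiny standalone demonstration of the pairing (shard names are made up):

```python
# Demonstrates the (current, next) pairing used by the update_shards loop.
from itertools import zip_longest

shard_keys = ["table_a", "table_b", "table_c"]
for i, (shard_name, nxt_shard_name) in enumerate(
    zip_longest(shard_keys, shard_keys[1:])
):
    print(i, shard_name, nxt_shard_name)
# 0 table_a table_b
# 1 table_b table_c
# 2 table_c None
```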
