from collections import defaultdict, OrderedDict
from dataclasses import dataclass, field
from functools import partial
+from itertools import zip_longest
from typing import (
    Any,
    cast,

from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
-    get_largest_dims_from_sharding_plan_updates,
-    move_sharded_tensors_to_cpu,
-    shards_all_to_all,
+    CommP2PMetadata,
+    CommStrategy,
+    prepare_comm_ops,
+    transfer_data,
    update_module_sharding_plan,
-    update_optimizer_state_post_resharding,
-    update_state_post_resharding,
+    update_state_dictionaries,
)
from torchrec.distributed.sharding.grid_sharding import GridPooledEmbeddingSharding
from torchrec.distributed.sharding.rw_sharding import RwPooledEmbeddingSharding
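Editor's note: the dynamic_sharding helpers swapped in above replace the padded all_to_all redistribution path with point-to-point transfers (prepare_comm_ops builds the ops, transfer_data launches them; a comment further down in the diff says this goes through batch_isend_irecv). As a rough, self-contained sketch of that underlying primitive only — not TorchRec code, and swap_shard is a made-up helper — a batched P2P exchange with plain torch.distributed looks like this:

# Minimal sketch (not TorchRec code): batched P2P exchange with torch.distributed.
# Assumes a process group is already initialized and `peer` is another rank.
import torch
import torch.distributed as dist


def swap_shard(shard: torch.Tensor, peer: int) -> torch.Tensor:
    """Send `shard` to `peer` and receive a same-shaped shard back in one batch."""
    received = torch.empty_like(shard)
    ops = [
        dist.P2POp(dist.isend, shard, peer),
        dist.P2POp(dist.irecv, received, peer),
    ]
    # batch_isend_irecv launches all ops asynchronously and returns one request per op.
    for req in dist.batch_isend_irecv(ops):
        req.wait()
    return received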
@@ -1378,24 +1379,10 @@ def _init_mean_pooling_callback(
            device=self._device,
        )

-    def _purge_lookups(self) -> None:
-        # Purge old lookups
-        for lookup in self._lookups:
-            # Call purge method if available (for TBE modules)
-            if hasattr(lookup, "purge") and callable(lookup.purge):
-                # Pyre-ignore
-                lookup.purge()
-
-            # For DDP modules, get the underlying module
-            while isinstance(lookup, DistributedDataParallel):
-                lookup = lookup.module
-            if hasattr(lookup, "purge") and callable(lookup.purge):
-                lookup.purge()
-
-        # Clear the lookups list
+    def _softcopy_lookups(self) -> List[nn.Module]:
+        old_modules: List[nn.Module] = [lookup for lookup in self._lookups]
        self._lookups.clear()
-        # Force garbage collection to free memory
-        torch.cuda.empty_cache()
+        return old_modules

    def _create_lookups(
        self,
@@ -1727,77 +1714,58 @@ def update_shards(
        device: Optional[torch.device],
    ) -> None:
        """
-        This is the main API used in sharder.reshard, currently only support redistribution
-        of existing shards (across different ranks, ideally from hot ranks to cold ranks)
-        Update shards for this module based on the changed_sharding_params. This will:
-        1. Move current lookup tensors to CPU
-        2. Purge lookups
-        3. Call shards_all_2_all containing collective to redistribute tensors
-        4. Update state_dict and other attributes to reflect new placements and shards
-        5. Create new lookups, and load in updated state_dict
+        Updates the sharded embedding module in place based on changed_sharding_params,
+        which contains the new ParameterSharding configs with the updated shard placements.
+
+        This method reshards the affected embedding tables, transfers optimizer state, and
+        updates the internal lookup and distribution modules to reflect the new sharding.

        Args:
            changed_sharding_params (Dict[str, ParameterSharding]): A dictionary mapping
                table names to their new parameter sharding configs. This should only
                contain shards/table names that need to be moved.
-            env (ShardingEnv): The sharding environment for the module.
+            env (ShardingEnv): The sharding environment.
            device (Optional[torch.device]): The device to place the updated module on.
+
+        Returns:
+            None
+        Raises:
+            RuntimeError: If DTensor output is enabled, as resharding is not yet supported for DTensor.
        """
        if env.output_dtensor:
            raise RuntimeError("We do not yet support DTensor for resharding yet")
            return

        current_state = self.state_dict()
-        current_state = move_sharded_tensors_to_cpu(current_state)
-        # TODO: improve, checking one would be enough
+
+        # Check if local optimizer state exists and is non-empty for all optimizers.
        has_local_optimizer = len(self._optim._optims) > 0 and all(
            len(i) > 0 for i in self._optim.state_dict()["state"].values()
        )

-        # communicate optimizer state across all ranks, because if one rank owns all tables
-        # and other ranks does not own any table, and later transfer the weights to empty rank
-        # creates inconsistent state, because initally empty rank does not have optimizer state
-        # hence, incorrectly computes the tensor splits
-
+        # Communicate optimizer state across all ranks to ensure consistency.
        has_optimizer = self._is_optimizer_enabled(has_local_optimizer, env, device)

-        # TODO: make sure this is clearing all lookups
-        self._purge_lookups()
+        # Save old lookup modules for cleanup.
+        old_lookups: List[nn.Module] = self._softcopy_lookups()

-        # Get max dim size to enable padding for all_to_all
-        max_dim_0, max_dim_1 = get_largest_dims_from_sharding_plan_updates(
-            changed_sharding_params
-        )
+        # Save old optimizer state if present.
        old_optimizer_state = self._optim.state_dict() if has_local_optimizer else None
-        if old_optimizer_state is not None:
-            move_sharded_tensors_to_cpu(old_optimizer_state)

-        local_shard_names_by_src_rank, local_output_tensor_cpu = shards_all_to_all(
-            module=self,
-            state_dict=current_state,
-            device=device,  # pyre-ignore
-            changed_sharding_params=changed_sharding_params,
-            env=env,
-            extend_shard_name=self.extend_shard_name,
-            max_dim_0=max_dim_0,
-            max_dim_1=max_dim_1,
-            optimizer_state=old_optimizer_state,
-            has_optimizer=has_optimizer,
-        )
+        assert hasattr(self, "module_sharding_plan")
+        current_module_sharding_plan = copy.deepcopy(self.module_sharding_plan)

-        for name, param in changed_sharding_params.items():
-            self.module_sharding_plan[name] = param
-            # TODO: Support detecting old sharding type when sharding type is changing
-            for sharding_info in self.sharding_type_to_sharding_infos[
-                param.sharding_type
-            ]:
-                if sharding_info.embedding_config.name == name:
-                    sharding_info.param_sharding = param
+        # Update the module sharding plan with the changed sharding parameters.
+        update_module_sharding_plan(
+            self, changed_sharding_params, self.sharding_type_to_sharding_infos
+        )

        self._sharding_types: List[str] = list(
            self.sharding_type_to_sharding_infos.keys()
        )
        # TODO: Optimize to update only the changed embedding shardings
+
+        # Recreate embedding sharding modules based on the new sharding infos.
        self._embedding_shardings: List[
            EmbeddingSharding[
                EmbeddingShardingContext,
@@ -1816,7 +1784,7 @@ def update_shards(
            for embedding_configs in self.sharding_type_to_sharding_infos.values()
        ]

-        # Reset input dists
+        # Reset input distribution and feature ordering.
        self._has_uninitialized_input_dist = True
        self._input_dists: List[nn.Module] = []
        self._features_order: List[int] = []
@@ -1825,15 +1793,16 @@ def update_shards(
        self._create_lookups()
        self._update_output_dist()

+        # Re-initialize torch state if in a distributed environment.
        if env.process_group and dist.get_backend(env.process_group) != "fake":
            self._initialize_torch_state(skip_registering=True)

-        # update optimizer
+        # Update optimizer to reflect new parameters.
        optims = []
        for lookup in self._lookups:
            for _, tbe_module in lookup.named_modules():
                if isinstance(tbe_module, FusedOptimizerModule):
-                    # modify param keys to match EmbeddingBagCollection
+                    # Modify param keys to match EmbeddingBagCollection
                    params: Mapping[str, Union[torch.Tensor, ShardedTensor]] = {}
                    for (
                        param_key,
@@ -1845,31 +1814,82 @@ def update_shards(
                    optims.append(("", tbe_module.fused_optimizer))

        self._optim: CombinedOptimizer = CombinedOptimizer(optims)
+        new_state = self.state_dict()

-        if has_optimizer:
-            optimizer_state = update_optimizer_state_post_resharding(
-                old_opt_state=old_optimizer_state,  # pyre-ignore
-                new_opt_state=self._optim.state_dict(),
-                ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-                output_tensor=local_output_tensor_cpu,
-                max_dim_0=max_dim_0,
+        optimizer_state: Dict[str, Dict[str, Dict[str, Any]]] = self._optim.state_dict()
+
+        # Prepare and execute communication operations for state transfer.
+        shard_keys = list(changed_sharding_params.keys())
+        comms_op: Dict[CommStrategy, List[CommP2PMetadata]] = {}
+        reqs: List[Tuple[dist.Work, CommP2PMetadata]] = []
+        # Pipeline to overlap communication and computation:
+        # move shards of the current table while loading the next table's shards for communication.
+        for i, (shard_name, nxt_shard_name) in enumerate(
+            zip_longest(shard_keys, shard_keys[1:])
+        ):
+            if i == 0:
+                # Prepare communication P2P operations
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+
+            if comms_op:
+                # call underlying batch_isend_irecv primitives
+                reqs = transfer_data(comms_op=comms_op)
+
+            if nxt_shard_name:
+                comms_op = prepare_comm_ops(
+                    module_sharding_plan=current_module_sharding_plan,
+                    current_state_dict=current_state,
+                    new_state_dict=new_state,
+                    changed_sharding_params=changed_sharding_params,
+                    shard_name=nxt_shard_name,
+                    env=env,
+                    current_opt_state=old_optimizer_state,
+                    new_opt_state=optimizer_state,
+                    extend_shard_name=self.extend_shard_name,
+                    has_optimizer=has_optimizer,
+                )
+            else:
+                break
+            # Update state and optimizer states
+            update_state_dictionaries(
+                reqs=reqs,
+                old_optimizer_state=old_optimizer_state,
+                new_optimizer_state=optimizer_state,
+                old_state=current_state,
+                new_state=new_state,
+                changed_sharding_params=changed_sharding_params,
                extend_shard_name=self.extend_shard_name,
            )
-            self._optim.load_state_dict(optimizer_state)

-        new_state = self.state_dict()
-        current_state = update_state_post_resharding(
+        update_state_dictionaries(
+            reqs=reqs,
+            old_optimizer_state=old_optimizer_state,
+            new_optimizer_state=optimizer_state,
            old_state=current_state,
            new_state=new_state,
-            ordered_shard_names_and_lengths=local_shard_names_by_src_rank,
-            output_tensor=local_output_tensor_cpu,
+            changed_sharding_params=changed_sharding_params,
            extend_shard_name=self.extend_shard_name,
-            has_optimizer=has_optimizer,
+            update_local=True,
        )

-        self.load_state_dict(current_state)
-
-        update_module_sharding_plan(self, changed_sharding_params)
+        # Clean up old lookup modules.
+        for lookup in old_lookups:
+            del lookup
+        old_lookups.clear()
+        self.load_state_dict(new_state, assign=True)
+        if has_optimizer:
+            self._optim.load_state_dict(optimizer_state)

        return

    def create_rocksdb_hard_link_snapshot(self) -> None:
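Editor's note on the pipelined loop in update_shards above: it overlaps communication with preparation by building the next table's P2P ops while the current table's transfers are in flight, and drains the final table after the loop. A simplified control-flow sketch under that reading, using hypothetical stand-ins (prepare_ops, launch, apply_results) rather than the actual prepare_comm_ops / transfer_data / update_state_dictionaries signatures:

# Simplified control-flow sketch of the prepare/transfer pipeline; the three
# callables are hypothetical stand-ins, not TorchRec APIs.
from itertools import zip_longest
from typing import Callable, List, Sequence


def pipelined_transfer(
    table_names: Sequence[str],
    prepare_ops: Callable[[str], List],
    launch: Callable[[List], List],
    apply_results: Callable[[List], None],
) -> None:
    reqs: List = []
    pending: List = []
    for i, (cur, nxt) in enumerate(zip_longest(table_names, table_names[1:])):
        if i == 0:
            pending = prepare_ops(cur)  # build P2P ops for the first table
        if pending:
            reqs = launch(pending)  # kick off async sends/recvs for the current table
        if nxt is None:
            break  # last table: its transfers are drained after the loop
        pending = prepare_ops(nxt)  # overlap: prepare the next table while transfers run
        apply_results(reqs)  # wait on and consume the current table's transfers
    apply_results(reqs)  # drain the final table's transfers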