diff --git a/python/paddle/distributed/auto_parallel/api.py b/python/paddle/distributed/auto_parallel/api.py
index e10e0cee0ceb99..70171edf51ed74 100644
--- a/python/paddle/distributed/auto_parallel/api.py
+++ b/python/paddle/distributed/auto_parallel/api.py
@@ -484,20 +484,15 @@ def shard_accumulator(param: paddle.Tensor) -> None:
                     key, param, accumulator, process_mesh
                 )
             else:
-                sharding_specs = [None for _ in range(len(accumulator.shape))]
+                placements = [
+                    dist.Replicate() for _ in range(len(process_mesh.shape))
+                ]
                 if 'beta' not in key and param.is_dist():
-                    # if param is a dist tensor, should keep the shard info
-                    dims_mappings = param.dist_attr.dims_mapping
-                    dim_names = param.dist_attr.process_mesh.dim_names
-                    for i in range(len(param.shape)):
-                        if dims_mappings[i] != -1:
-                            sharding_specs[i] = dim_names[dims_mappings[i]]
-                dist_attr = dist.DistAttr(
-                    mesh=process_mesh,
-                    sharding_specs=sharding_specs,
-                )
+                    # if param is a dist tensor, should keep the shard info for
+                    # accumulators except beta
+                    placements = param.placements
                 optimizer._accumulators[key][target_name] = shard_tensor(
-                    accumulator, dist_attr=dist_attr
+                    accumulator, mesh=process_mesh, placements=placements
                 )
 
     if parameter_list is not None:
@@ -557,7 +552,11 @@ class ShardOptimizer:
     """
 
     def __init__(
-        self, optimizer, process_mesh, shard_dims_name=None, gather_output=True
+        self,
+        optimizer,
+        process_mesh,
+        sharding_mesh_axis=None,
+        gather_output=True,
     ):
         assert (
             paddle.in_dynamic_mode()
@@ -576,7 +575,7 @@ def __init__(
         ), "The argument `process_mesh` is not `dist.ProcessMesh` type."
 
         # TODO(Yuang Liu): support sharding parallel
-        assert shard_dims_name is None
+        assert sharding_mesh_axis is None
        # if shard_dims_name is not None:
        #     assert isinstance(
        #         shard_dims_name, str
@@ -588,7 +587,7 @@
         self.optimizer = optimizer
         self.process_mesh = process_mesh
         self.gather_output = gather_output
-        self.shard_dims_name = shard_dims_name
+        self.sharding_mesh_axis = sharding_mesh_axis
 
     def clear_grad(self, set_to_zero=True):
         self.optimizer.clear_grad(set_to_zero)
@@ -597,7 +596,9 @@ def _shard_fn(self, accumulator_name, param, accumulator, process_mesh):
         pass
 
     def step(self):
-        shard_fn = self._shard_fn if self.shard_dims_name is not None else None
+        shard_fn = (
+            self._shard_fn if self.sharding_mesh_axis is not None else None
+        )
         if not isinstance(self.optimizer._parameter_list[0], dict):
             params_grads = []
             parameter_list = []
diff --git a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
index 3b01221311f8f3..daaa3d8bdfd0ca 100644
--- a/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
+++ b/test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
@@ -44,14 +44,10 @@ def get_single_card_rst(self):
 
     def shard_layer_fn(self, layer_name, layer, process_mesh):
         layer.weight = dist.shard_tensor(
-            layer.weight,
-            dist_attr=dist.DistAttr(
-                mesh=process_mesh, sharding_specs=[None, 'x']
-            ),
+            layer.weight, process_mesh, [dist.Shard(1)]
         )
         layer.bias = dist.shard_tensor(
-            layer.bias,
-            dist_attr=dist.DistAttr(mesh=process_mesh, sharding_specs=['x']),
+            layer.bias, process_mesh, [dist.Shard(0)]
         )
 
     def test_opt(self, opt):
@@ -110,18 +106,10 @@ def test_shard_optimizer_master_params(self):
             assert v.shape[-1] == v._local_shape[-1] * 2
 
     def shard_opt_fn(self, accumulator_name, param, accumulator, process_mesh):
-        sharding_specs = [None for _ in range(len(accumulator.shape))]
+        placements = [dist.Replicate() for _ in range(len(process_mesh.shape))]
         if 'beta' not in accumulator_name and param.is_dist():
-            dims_mappings = param.dist_attr.dims_mapping
-            dim_names = param.dist_attr.process_mesh.dim_names
-            for i in range(len(param.shape)):
-                if dims_mappings[i] != -1:
-                    sharding_specs[i] = dim_names[dims_mappings[i]]
-        dist_attr = dist.DistAttr(
-            mesh=process_mesh,
-            sharding_specs=sharding_specs,
-        )
-        return dist.shard_tensor(accumulator, dist_attr=dist_attr)
+            placements = param.placements
+        return dist.shard_tensor(accumulator, process_mesh, placements)
 
     def test_shard_optimizer_shard_fn(self):
        paddle.seed(self._seed)
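For readers unfamiliar with the placements-based calling convention this patch migrates to, the sketch below shows the new `dist.shard_tensor(tensor, mesh, placements)` style next to what the old `dist.DistAttr(mesh=..., sharding_specs=...)` expressed. It is not part of the patch: the mesh layout and tensor shapes are assumptions, and it presumes a two-card run launched via `python -m paddle.distributed.launch`.

```python
# Hypothetical usage sketch of the placements-based API seen in the diff above.
# Assumes a two-card job launched with `python -m paddle.distributed.launch`;
# the mesh layout and shapes are illustrative only.
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

weight = paddle.randn([8, 16])
# Shard the second dim across the mesh; the new-style counterpart of the old
# dist.DistAttr(mesh=mesh, sharding_specs=[None, 'x']).
weight = dist.shard_tensor(weight, mesh, [dist.Shard(1)])

moment = paddle.zeros([8, 16])
# Fully replicated placement: one dist.Replicate() per mesh dimension,
# as the patched accumulator path does by default.
moment = dist.shard_tensor(
    moment, mesh, [dist.Replicate() for _ in range(len(mesh.shape))]
)

# A dist tensor exposes its placements, which the patch reuses so that
# non-beta accumulators follow their parameter's sharding.
print(weight.placements)
```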