Skip to content

Commit

Permalink
update to new api
Browse files Browse the repository at this point in the history
  • Loading branch information
FeixLiu committed Nov 27, 2023
1 parent 1b6ed58 commit d63d4b8
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 33 deletions.
33 changes: 17 additions & 16 deletions python/paddle/distributed/auto_parallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,20 +484,15 @@ def shard_accumulator(param: paddle.Tensor) -> None:
key, param, accumulator, process_mesh
)
else:
sharding_specs = [None for _ in range(len(accumulator.shape))]
placements = [
dist.Replicate() for _ in range(len(process_mesh.shape))
]
if 'beta' not in key and param.is_dist():
# if param is a dist tensor, should keep the shard info
dims_mappings = param.dist_attr.dims_mapping
dim_names = param.dist_attr.process_mesh.dim_names
for i in range(len(param.shape)):
if dims_mappings[i] != -1:
sharding_specs[i] = dim_names[dims_mappings[i]]
dist_attr = dist.DistAttr(
mesh=process_mesh,
sharding_specs=sharding_specs,
)
# if param is a dist tensor, should keep the shard info for
# accumulators except beta
placements = param.placements
optimizer._accumulators[key][target_name] = shard_tensor(
accumulator, dist_attr=dist_attr
accumulator, mesh=process_mesh, placements=placements
)

if parameter_list is not None:
Expand Down Expand Up @@ -557,7 +552,11 @@ class ShardOptimizer:
"""

def __init__(
self, optimizer, process_mesh, shard_dims_name=None, gather_output=True
self,
optimizer,
process_mesh,
sharding_mesh_axis=None,
gather_output=True,
):
assert (
paddle.in_dynamic_mode()
Expand All @@ -576,7 +575,7 @@ def __init__(
), "The argument `process_mesh` is not `dist.ProcessMesh` type."

# TODO(Yuang Liu): support sharding parallel
assert shard_dims_name is None
assert sharding_mesh_axis is None
# if shard_dims_name is not None:
# assert isinstance(
# shard_dims_name, str
Expand All @@ -588,7 +587,7 @@ def __init__(
self.optimizer = optimizer
self.process_mesh = process_mesh
self.gather_output = gather_output
self.shard_dims_name = shard_dims_name
self.sharding_mesh_axis = sharding_mesh_axis

def clear_grad(self, set_to_zero=True):
self.optimizer.clear_grad(set_to_zero)
Expand All @@ -597,7 +596,9 @@ def _shard_fn(self, accumulator_name, param, accumulator, process_mesh):
pass

def step(self):
shard_fn = self._shard_fn if self.shard_dims_name is not None else None
shard_fn = (
self._shard_fn if self.sharding_mesh_axis is not None else None
)
if not isinstance(self.optimizer._parameter_list[0], dict):
params_grads = []
parameter_list = []
Expand Down
22 changes: 5 additions & 17 deletions test/auto_parallel/semi_auto_parallel_shard_optimizer_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,14 +44,10 @@ def get_single_card_rst(self):

def shard_layer_fn(self, layer_name, layer, process_mesh):
layer.weight = dist.shard_tensor(
layer.weight,
dist_attr=dist.DistAttr(
mesh=process_mesh, sharding_specs=[None, 'x']
),
layer.weight, process_mesh, [dist.Shard(1)]
)
layer.bias = dist.shard_tensor(
layer.bias,
dist_attr=dist.DistAttr(mesh=process_mesh, sharding_specs=['x']),
layer.bias, process_mesh, [dist.Shard(0)]
)

def test_opt(self, opt):
Expand Down Expand Up @@ -110,18 +106,10 @@ def test_shard_optimizer_master_params(self):
assert v.shape[-1] == v._local_shape[-1] * 2

def shard_opt_fn(self, accumulator_name, param, accumulator, process_mesh):
sharding_specs = [None for _ in range(len(accumulator.shape))]
placements = [dist.Replicate() for _ in range(len(process_mesh.shape))]
if 'beta' not in accumulator_name and param.is_dist():
dims_mappings = param.dist_attr.dims_mapping
dim_names = param.dist_attr.process_mesh.dim_names
for i in range(len(param.shape)):
if dims_mappings[i] != -1:
sharding_specs[i] = dim_names[dims_mappings[i]]
dist_attr = dist.DistAttr(
mesh=process_mesh,
sharding_specs=sharding_specs,
)
return dist.shard_tensor(accumulator, dist_attr=dist_attr)
placements = param.placements
return dist.shard_tensor(accumulator, process_mesh, placements)

def test_shard_optimizer_shard_fn(self):
paddle.seed(self._seed)
Expand Down

0 comments on commit d63d4b8

Please sign in to comment.