# Post-patch form of `uniform` from
# python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py
def uniform(self, num_items, num_parts):
    """Split ``num_items`` into ``num_parts`` contiguous, near-equal segments.

    Every segment gets ``num_items // num_parts`` items; the remainder
    (``num_items % num_parts``) is handed out one item each to the *last*
    segments, so segment sizes never differ by more than one.

    Args:
        num_items: Total number of items to distribute (an integer).
        num_parts: Number of segments to produce (an integer).

    Returns:
        A list of ``num_parts + 1`` boundaries. Segment ``k`` covers the
        half-open range ``[result[k], result[k + 1])``; ``result[0]`` is 0
        and ``result[num_parts]`` is ``num_items``.

    Raises:
        ZeroDivisionError: If ``num_parts`` is 0.

    Example:
        ``uniform(10, 4)`` yields ``[0, 2, 4, 7, 10]`` (sizes 2, 2, 3, 3).
    """
    result = [0] * (num_parts + 1)
    # Pure integer arithmetic: float division followed by floor() loses
    # precision once the operands no longer fit exactly in a double.
    part_size, extra_layers = divmod(num_items, num_parts)
    # Boundaries with an index above this threshold close a segment that
    # carries one extra item.
    threshold = num_parts - extra_layers
    for i in range(1, num_parts):
        offset = 1 if i > threshold else 0
        # The clamp is purely defensive: for non-negative ``num_items`` the
        # running boundary stays below ``num_items``.
        result[i] = min(result[i - 1] + part_size + offset, num_items)
    # The final boundary is pinned explicitly, so the last segment absorbs
    # whatever is left (including its own extra item, if any).
    result[num_parts] = num_items
    return result