Skip to content

Commit

Permalink
[HybridParallel]Add segment methods for pipelineparallel (#53344)
Browse files Browse the repository at this point in the history
  • Loading branch information
ForFishes authored Apr 26, 2023
1 parent 35f5c24 commit c59debe
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,37 @@ def __init__(
), "layer number should be greater than number of segments"

def do_segment(self):
if self.method == "uniform":

if isinstance(self.method, list):
seg_method = self.method[:]
source_num_parts = len(seg_method) - 1

def check_sanity():
assert seg_method[0] == 0, "seg_method[0] should be 0"
for part in seg_method:
assert isinstance(part, int), "part should be int"
assert part >= 0, f"part[{part}] should be greater than 0"
assert (
part <= self.num_items
), "part[{}] should be less than num_items[{}]".format(
part, self.num_items
)

check_sanity()

if self.num_parts == source_num_parts + 1:
seg_method.append(self.num_items)
return seg_method
elif self.num_parts == source_num_parts:
return seg_method
else:
raise ValueError(
"We set seg_method as {}, this length is {}, but the number of stages is {}".format(
seg_method, len(seg_method), self.num_parts
)
)

elif self.method == "uniform":
return self.uniform(self.num_items, self.num_parts)

elif self.method.startswith('layer:'):
Expand Down Expand Up @@ -144,6 +174,8 @@ def do_segment(self):
memory_counter = 0
result[actual_num_parts] = len(weights)
return result
else:
raise ValueError(f"method {self.method} is not supported")

def _gen_layer_weight(self, layername):
weight_idxs = []
Expand Down
14 changes: 14 additions & 0 deletions python/paddle/fluid/tests/unittests/hybrid_parallel_pp_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,20 @@ def test_pipelayer_sequential(self):
np.testing.assert_array_equal(param_a.name, param_b.name)
np.testing.assert_allclose(param_a.numpy(), param_b.numpy())

def test_pipelayer_segment_method(self):
init_net = AlexNetPipe()
pipe_model = PipelineLayer(
layers=init_net.to_layers(),
num_stages=self.pipeline_parallel_size,
seg_method=[0, 4],
loss_fn=nn.CrossEntropyLoss(),
)
stage_id = self.hcg.get_stage_id()
if stage_id == 0:
np.testing.assert_array_equal(len(pipe_model.parameters()), 4)
elif stage_id == 1:
np.testing.assert_array_equal(len(pipe_model.parameters()), 8)


if __name__ == '__main__':
unittest.main()

0 comments on commit c59debe

Please sign in to comment.