split optimization ops on pserver to independent blocks #10123

Merged
45 changes: 23 additions & 22 deletions python/paddle/fluid/distribute_transpiler.py
@@ -368,21 +368,19 @@ def get_pserver_program(self, endpoint):
else:
recv_inputs.append(single_trainer_var)

# step3
optimize_block = pserver_program.create_block(0)
# step 4
# step 3
# Create a union-find data structure from optimize ops,
# If two ops are connected, we could add these two ops
# into one set.
ufind = self._create_ufind(self.optimize_ops)
# step 4.2
# step 3.2
# Iterate through the ops and append optimize op which
# located on current pserver
opt_op_on_pserver = []
for _, op in enumerate(self.optimize_ops):
if self._is_opt_op(op) and self._is_opt_op_on_pserver(endpoint, op):
opt_op_on_pserver.append(op)
# step 4.3
# step 3.3
# Iterate through the ops, and if an op and the optimize ops
# which located on current pserver are in one set, then
# append it into the sub program.
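For context, `_create_ufind` groups the optimize ops into disjoint sets: roughly, two ops are treated as connected when they share a variable (one op's output feeds the other's input), so an optimizer op and its helper ops end up in the same set and, with this PR, in the same block. The sketch below is illustrative only; apart from `is_connected`, which mirrors the call made later in this diff, the class and method names are assumptions rather than Paddle's actual implementation.

# Illustrative union-find sketch, not Paddle's actual UnionFind.
class SimpleUnionFind(object):
    def __init__(self, elements):
        # every element starts as the root of its own set
        self._parent = {id(e): id(e) for e in elements}

    def _find(self, elem):
        key = id(elem)
        while self._parent[key] != key:
            key = self._parent[key]
        return key

    def union(self, a, b):
        # merge the sets containing a and b
        root_a, root_b = self._find(a), self._find(b)
        if root_a != root_b:
            self._parent[root_a] = root_b

    def is_connected(self, a, b):
        return self._find(a) == self._find(b)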
@@ -415,29 +413,30 @@ def __append_optimize_op__(op, block):
else:
self._append_pserver_non_opt_ops(block, op)

append_block = optimize_block
# append lr decay ops to the child block if exists
lr_ops = self._get_lr_ops()
if len(lr_ops) > 0:
lr_decay_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for _, op in enumerate(lr_ops):
self._append_pserver_non_opt_ops(append_block, op)

append_block = pserver_program.create_block(append_block.idx)
self._append_pserver_non_opt_ops(lr_decay_block, op)

# append op to the current block
per_opt_block = append_block
pre_block_idx = pserver_program.num_blocks - 1
for idx, opt_op in enumerate(opt_op_on_pserver):
per_opt_block = pserver_program.create_block(pre_block_idx)
for _, op in enumerate(self.optimize_ops):
# optimizer is connected to itself
if ufind.is_connected(op, opt_op) and \
op not in global_ops:
if ufind.is_connected(op, opt_op) and op not in global_ops:
__append_optimize_op__(op, per_opt_block)
if idx == len(opt_op_on_pserver) - 1 and global_ops:
per_opt_block = pserver_program.create_block(append_block.idx)

# append global ops
for glb_op in global_ops:
__append_optimize_op__(glb_op, per_opt_block)
opt_state_block = None
if global_ops:
opt_state_block = pserver_program.create_block(
pserver_program.num_blocks - 1)
for glb_op in global_ops:
__append_optimize_op__(glb_op, opt_state_block)

# NOT USED: single block version:
#
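Reading this hunk as a whole, the new code path gives the pserver program a layout of roughly: block 0 as the root, block 1 for the learning-rate decay ops (when present), one block per connected group of optimize ops placed on this pserver, and a final block for the global ops. A small hedged helper for checking that layout; it assumes the fluid APIs of this era (Program.block, Block.parent_idx, Block.ops, Operator.type) plus the num_blocks property added in framework.py below.

def dump_pserver_blocks(pserver_program):
    # print each block's index, parent index, and the op types it holds
    for idx in range(pserver_program.num_blocks):
        block = pserver_program.block(idx)
        op_types = [op.type for op in block.ops]
        print("block %d (parent %d): %s" % (idx, block.parent_idx, op_types))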
@@ -451,10 +450,10 @@ def __append_optimize_op__(op, block):
prefetch_block = None
if self.has_distributed_lookup_table:
pserver_index = self.pserver_endpoints.index(endpoint)
self._create_table_optimize_block(pserver_index, pserver_program,
append_block)
table_opt_block = self._create_table_optimize_block(
pserver_index, pserver_program, pre_block_idx)
prefetch_block = self._create_prefetch_block(
pserver_index, pserver_program, optimize_block)
pserver_index, pserver_program, table_opt_block)

# NOTE: if has_distributed_lookup_table is False, then prefetch_block will
# not be executed, so it's safe to use optimize_block to hold the place
@@ -470,7 +469,7 @@ def __append_optimize_op__(op, block):
inputs={'X': recv_inputs},
outputs={},
attrs={
"OptimizeBlock": optimize_block,
"OptimizeBlock": pserver_program.block(1),
"endpoint": endpoint,
"Fanin": self.trainer_num,
"PrefetchBlock": prefetch_block
@@ -663,7 +662,7 @@ def _create_prefetch_block(self, pserver_index, pserver_program,
return prefetch_block

def _create_table_optimize_block(self, pserver_index, pserver_program,
append_block):
pre_block_idx):
def _clone_var(block, var, persistable=True):
assert isinstance(var, Variable)
return block.create_var(
@@ -700,7 +699,7 @@ def _clone_var(block, var, persistable=True):
op for op in self.optimize_ops
if op.input("Param")[0] == self.table_name
][0]
table_opt_block = pserver_program.create_block(append_block.idx)
table_opt_block = pserver_program.create_block(pre_block_idx)
# only support sgd now
assert table_opt_op.type == "sgd"

@@ -724,6 +723,8 @@ def _clone_var(block, var, persistable=True):
outputs=outputs,
attrs=table_opt_op.attrs)

return table_opt_block

# ====================== private transpiler functions =====================
def _create_vars_from_blocklist(self,
program,
4 changes: 4 additions & 0 deletions python/paddle/fluid/framework.py
@@ -1107,6 +1107,10 @@ def parse_from_string(binary_str):
def random_seed(self):
return self._seed

@property
def num_blocks(self):
return self.desc.num_blocks()

@random_seed.setter
def random_seed(self, seed):
if not isinstance(seed, int):
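The new num_blocks property exposes the block count of the underlying program desc, which is what lets the transpiler parent each newly created block on the most recently created one. A minimal usage sketch, assuming the fluid API of this release:

import paddle.fluid as fluid

prog = fluid.Program()
# a fresh Program has only the root block
assert prog.num_blocks == 1
# parent the new block on the last existing block, as the transpiler does
child = prog.create_block(prog.num_blocks - 1)
assert prog.num_blocks == 2
assert child.idx == 1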