
Commit 9acc26c

[Auto Parallel] Improve the dist op interface and the compatible computation (#39014)

* Add the backward support for QR

* Remove unnecessary comments

* [Auto Parallel] Improve the dist op interface and compatible computation

* Remove unnecessary modification

* Recover some modifications

* Add lost files

* Fix a minor bug

* Fix the bug of the planner

* Fix the format problem
aoyulong authored Jan 20, 2022
1 parent 2a9c993 commit 9acc26c
Showing 19 changed files with 808 additions and 761 deletions.
72 changes: 24 additions & 48 deletions python/paddle/distributed/auto_parallel/completion.py
@@ -353,30 +353,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=True)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
else:
for tensor_node in op_node.outputs:
if tensor_node.var() is not None:
@@ -399,30 +387,18 @@ def update_op_node_dims_mapping(dist_context, op_node, fwd=True):
tensor_desc.name(), compatible_dims_mapping)
changed = True
# Find the most compatible implementations from the distributed operator
op_dist_impl, op_dist_impl_idx = find_best_compatible_distributed_operator_impl(
op_desc.type(), dist_op, fwd=False)
if op_dist_impl is not None:
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
# This statement will be replaced by a good way
if op_dist_impl.is_compatible(dist_op):
op_dist_attr.impl_type = op_desc.type()
op_dist_attr.impl_idx = op_dist_impl_idx
elif is_elementwise_like_op(op_desc.type()):
dim_changed = update_op_dims_mapping_by_elementwise_like_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "element-wise"
op_dist_attr.impl_idx = -1
else:
dim_changed = update_op_dims_mapping_by_default_dist_impl(
dist_context, op_node)
if dim_changed:
changed = True
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = -2
op_dist_impl = find_best_compatible_distributed_operator_impl(
dist_op, fwd=False)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
dim_changed = op_dist_impl.update_dims_mapping(dist_op)
if dim_changed:
changed = True
if op_dist_impl.is_auto_compatible(dist_op):
if op_dist_impl.type == "elementwise":
op_dist_attr.impl_type = "default"
else:
op_dist_attr.impl_type = op_dist_impl.type
op_dist_attr.impl_idx = op_dist_impl.idx
return changed
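
For readers skimming the diff, the shape of the interface change in both hunks above is: find_best_compatible_distributed_operator_impl now takes only the DistributedOperator (instead of the op type plus the DistributedOperator), returns a single impl object rather than an (impl, idx) pair, and the caller reads the impl's own type and idx. The assert also replaces the old elementwise/default fallback branches, which lines up with the new dist_eltwise module imported in operators/__init__.py further down. A minimal, runnable sketch of the new calling pattern, using toy stand-ins rather than Paddle's real classes, might look like this:

# Toy sketch: hypothetical ToyDistImpl / find_impl stand-ins only; Paddle's
# real classes and registry are not reproduced here.
class ToyDistImpl:
    def __init__(self, type_, idx):
        self.type = type_   # e.g. "matmul", "elementwise", "default"
        self.idx = idx      # index of this impl within its container

    def update_dims_mapping(self, dist_op):
        # The real impls may rewrite the op's dims mappings; the toy does not.
        return False

    def is_auto_compatible(self, dist_op):
        return True

def find_impl(dist_op, fwd=True):
    # Stand-in for find_best_compatible_distributed_operator_impl, which now
    # receives only the dist op and returns one impl object (or None).
    return ToyDistImpl("elementwise", 0)

dist_op = object()  # placeholder for a DistributedOperator
changed = False
op_dist_impl = find_impl(dist_op, fwd=True)
assert op_dist_impl is not None, "Cannot find the dist op implementation."
if op_dist_impl.update_dims_mapping(dist_op):
    changed = True
if op_dist_impl.is_auto_compatible(dist_op):
    # "elementwise" impls are recorded under the "default" impl type.
    impl_type = "default" if op_dist_impl.type == "elementwise" else op_dist_impl.type
    impl_idx = op_dist_impl.idx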


92 changes: 90 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_context.py
@@ -61,6 +61,8 @@ def __init__(self, program=None):
# Other data members
self._dist_op_context = DistributedOperatorContext()
self._process_meshes = []
self._serial_ordered_nodes = []
self._tensor_id_to_tensor_node_ids = {}

# Distributed programs
self._dist_main_programs = {}
@@ -80,6 +82,10 @@ def serial_program(self, program):
"This distributed context has already been realted to a serial program"
self._serial_program = program

@property
def serial_ordered_nodes(self):
return self._serial_ordered_nodes

@property
def process_meshes(self):
return self._process_meshes
@@ -186,6 +192,18 @@ def get_tensor_dist_attr_for_graph(self, serial_tensor_node):
else:
return None

# def set_tensor_dist_attr_for_graph(self, serial_tensor_node, dist_attr):
# assert serial_tensor_node.is_var() and \
# serial_tensor_node.var() is not None
# serial_tensor_id = serial_tensor_node.node.original_desc_id()
# dist_tensor = self._dist_tensors_for_program.get(serial_tensor_id, None)
# assert dist_tensor is not None, \
# "The distributed tensor of the program has not been added to this context."
# serial_tensor_node_id = serial_tensor_node.id()
# new_dist_tensor = DistributedTensor(dist_tensor.serial_tensor,
# dist_attr)
# self._dist_tensors_for_graph[serial_tensor_node_id] = new_dist_tensor

def get_op_dist_attr_for_program(self, serial_op):
serial_op_id = serial_op.desc.id()
dist_op = self._dist_ops_for_program.get(serial_op_id, None)
@@ -218,6 +236,35 @@ def get_op_dist_attr_for_graph(self, serial_op_node):
else:
return None

# def set_op_dist_attr_for_graph(self, serial_op_node, dist_attr):
# assert serial_op_node.is_op() and \
# serial_op_node.op() is not None
# serial_op_id = serial_op_node.node.original_desc_id()
# dist_op = self._dist_ops_for_program.get(serial_op_id, None)
# assert dist_op is not None, \
# "The distributed operator of the program has not been added to this context."
# serial_op_node_id = serial_op_node.id()
# new_dist_op = DistributedOperator(dist_op.serial_op, dist_attr)
# self._dist_ops_for_graph[serial_op_node_id] = new_dist_op

# def get_dist_attr_for_graph(self, serial_node):
# if serial_node.is_var() and serial_node.var() is not None:
# serial_tensor_node_id = serial_node.id()
# dist_tensor = self._dist_tensors_for_graph.get(
# serial_tensor_node_id, None)
# if dist_tensor:
# return dist_tensor.dist_attr
# else:
# return None
# if serial_node.is_op() and serial_node.op() is not None:
# serial_op_node_id = serial_node.id()
# dist_op = self._dist_ops_for_graph.get(serial_op_node_id, None)
# if dist_op:
# return dist_op.dist_attr
# else:
# return None
# return None

def init_dist_attr_for_program(self):
assert self._serial_program, \
"Please set the program of this context before initializing its distribute attributes."
Expand Down Expand Up @@ -248,6 +295,44 @@ def init_dist_attr_for_program(self):
self.add_dist_op_for_program(dist_op)
self._is_initialized_for_program = True

def order_nodes_by_program_order(self):
def _contains(nodes, target_node):
for node in nodes:
if node.id() == target_node.id():
return True
return False

ordered_tensor_nodes = []
ordered_op_nodes = []
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
if node.is_var() and node.var() is not None:
ordered_tensor_nodes.append(node)
if node.is_op() and node.op() is not None:
ordered_op_nodes.append(node)
ordered_tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
ordered_op_nodes.sort(key=lambda node: node.node.original_desc_id())
for op_node in ordered_op_nodes:
tensor_nodes = []
for tensor_node in op_node.inputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
tensor_nodes.sort(key=lambda node: node.node.original_desc_id())
self._serial_ordered_nodes.extend(tensor_nodes)
self._serial_ordered_nodes.append(op_node)
tensor_nodes = []
for tensor_node in op_node.outputs:
if tensor_node.is_var() \
and tensor_node.var() is not None \
and not _contains(self._serial_ordered_nodes, tensor_node):
tensor_nodes.append(tensor_node)
self._serial_ordered_nodes.extend(tensor_nodes)
num_nodes_before = len(ordered_tensor_nodes) + len(ordered_op_nodes)
assert len(self._serial_ordered_nodes) == num_nodes_before, \
"The number of nodes before ordering is not the same after ordering."

def init_dist_attr_for_graph(self):
assert self._is_initialized_for_program, \
"The program must be initialized before initializing the distributed attributes for its graph."
@@ -257,7 +342,8 @@ def init_dist_attr_for_graph(self):
self._serial_graph = framework.IrGraph(
core.Graph(self._serial_program.desc))
all_nodes = self._serial_graph.all_nodes()
for node in all_nodes:
self.order_nodes_by_program_order()
for node in self.serial_ordered_nodes:
if node.is_var() and node.var() is not None:
dist_tensor = None
tensor_id = node.node.original_desc_id()
@@ -397,7 +483,9 @@ def __deepcopy__(self, memo):
result = cls.__new__(cls)
memo[id(self)] = result
for k, v in self.__dict__.items():
if k == "_serial_program" or k == "_serial_graph" or k == "_dist_main_programs" or k == "_dist_startup_programs":
if k == "_serial_program" or k == "_serial_graph" \
or k == "_dist_main_programs" or k == "_dist_startup_programs" \
or k == "_serial_ordered_nodes":
setattr(result, k, v)
else:
setattr(result, k, copy.deepcopy(v, memo))
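
The __deepcopy__ hunk above adds _serial_ordered_nodes to the members that are shared by reference instead of deep-copied, alongside the serial program, the serial graph, and the dist programs; presumably this keeps cloning cheap and lets every copy point at the same underlying program and graph objects. A hedged, generic sketch of that selective-deepcopy idiom (toy class, not the real DistributedContext):

import copy

class ToyContext:
    # Illustrative only: deep-copy most state, but share a few heavyweight
    # members (programs, graph, ordered nodes) by reference.
    _SHARED = {"_serial_program", "_serial_graph",
               "_dist_main_programs", "_dist_startup_programs",
               "_serial_ordered_nodes"}

    def __init__(self, program=None):
        self._serial_program = program
        self._serial_graph = None
        self._dist_main_programs = {}
        self._dist_startup_programs = {}
        self._serial_ordered_nodes = []
        self._process_meshes = []

    def __deepcopy__(self, memo):
        cls = self.__class__
        result = cls.__new__(cls)
        memo[id(self)] = result
        for k, v in self.__dict__.items():
            if k in self._SHARED:
                setattr(result, k, v)                    # share by reference
            else:
                setattr(result, k, copy.deepcopy(v, memo))
        return result

ctx = ToyContext(program="serial-program-placeholder")
clone = copy.deepcopy(ctx)
assert clone._serial_program is ctx._serial_program      # shared
assert clone._process_meshes is not ctx._process_meshes  # deep-copied
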
5 changes: 3 additions & 2 deletions python/paddle/distributed/auto_parallel/dist_op.py
@@ -98,7 +98,7 @@ def _init_default_dist_attr(self):
if self._dist_attr.impl_type is None:
self._dist_attr.impl_type = "default"
if self._dist_attr.impl_idx is None:
self._dist_attr.impl_idx = -2
self._dist_attr.impl_idx = 0
if self._dist_attr.is_recompute is None:
self._dist_attr.is_recompute = False

@@ -217,7 +217,8 @@ def __str__(self):

str += ", pipeline stage: {}".format(None)

str += ", dist_impl idx: {} }}".format(self.dist_attr._impl_idx)
str += ", dist_impl idx: {} , dist_impl type {} }}".format(
self.dist_attr._impl_idx, self.dist_attr._impl_type)

return str

python/paddle/distributed/auto_parallel/operators/__init__.py
@@ -23,5 +23,6 @@
from . import dist_softmax
from . import dist_transpose
from . import dist_default
from . import dist_eltwise
from . import dist_check_finite_and_unscale
from . import dist_update_loss_scaling