add group operators (#48208)
Caozhou1995 authored Nov 22, 2022
1 parent df4dfda commit 48d5c36
Showing 3 changed files with 396 additions and 0 deletions.
262 changes: 262 additions & 0 deletions python/paddle/distributed/auto_parallel/tuner/rule_based_tuner.py
@@ -0,0 +1,262 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


class OperatorGroupUtil:
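# Operator types that commonly start a layer; get_longest_repeated_sub_seq
# prefers a repeated sub sequence that begins with one of these types.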
common_starts = ["layer_norm", "matmul_v2", "matmul"]

@staticmethod
def get_ranks(seq):
"""Get rank array of the given seq by doubled algorithm."""
ordered_seq = sorted(list(set(seq)))
item_to_rank = {item: idx for idx, item in enumerate(ordered_seq)}
inter_ranks = [item_to_rank[item] for item in seq]

length = len(inter_ranks)
power = 0
interval = 2**power
while interval < length:
for idx, item in enumerate(inter_ranks):
if idx + interval >= length:
inter_ranks[idx] = [item, -1]
else:
inter_ranks[idx] = [item, inter_ranks[idx + interval]]

tmp = []
for item in inter_ranks:
if item not in tmp:
tmp.append(item)
tmp.sort(key=lambda x: (x[0], x[1]))
item_to_rank = {}
for idx, val in enumerate(tmp):
key = ",".join(str(item) for item in val)
item_to_rank[key] = idx

inter_ranks = [
item_to_rank[",".join(str(val) for val in item)]
for item in inter_ranks
]
power += 1
interval = 2**power

return inter_ranks
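# A small worked example: for seq = ["matmul", "relu", "matmul", "relu"] the
# doubling passes produce get_ranks(seq) == [1, 3, 0, 2], i.e. the suffix
# starting at index 2 is lexicographically smallest, then index 0, 3, 1.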

@staticmethod
def get_suffixes(ranks):
"""Get suffix array by the given rank array."""
suffixes = [0 for idx in range(len(ranks))]
for idx, item in enumerate(ranks):
suffixes[item] = idx
return suffixes
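# Continuing the example: get_suffixes([1, 3, 0, 2]) == [2, 0, 3, 1]; entry i is
# the start index of the suffix with rank i, so it inverts the rank array.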

@staticmethod
def get_heights(suffixes, seq):
"""Get height array by the suffix array and seq"""
heights = [-1 for i in range(len(suffixes))]
for i in range(1, len(seq)):
x = seq[suffixes[i - 1] :]
y = seq[suffixes[i] :]
max_len = len(x) if len(x) > len(y) else len(y)
same_count = 0
for j in range(max_len):
if j >= len(x) or j >= len(y):
break
else:
if x[j] == y[j]:
same_count += 1
else:
break
heights[i] = same_count

return heights
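# Continuing the example, with suffixes == [2, 0, 3, 1] and
# seq == ["matmul", "relu", "matmul", "relu"]:
# get_heights(suffixes, seq) == [-1, 2, 0, 1]; heights[i] is the length of the
# common prefix of the suffixes ranked i - 1 and i (an LCP array).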

@staticmethod
def get_longest_repeated_sub_seq(suffixes, heights, seq):
"""Get longest repeated sub sequence by suffix array algorithm."""
length = len(seq)
if length <= 1:
return None
k = length // 2
height_groups = []
longest_sub_seq = None
longest_sub_seqs = []

while k >= 2:
height_group = []
for i in range(1, len(heights)):
if heights[i] >= k:
if i == 1:
height_group.append(0)
height_group.append(i)
else:
if i == 1:
height_groups.append([0])
height_group = [i]
else:
height_groups.append(height_group)
height_group = [i]

if height_group:
height_groups.append(height_group)

for height_group in height_groups:
suffix_group = []
index_group = []
for idx in height_group:
suffix_group.append(idx)
index_group.append(suffixes[idx])

max_index = max(index_group)
min_index = min(index_group)
if max_index - min_index >= k:
longest_sub_seq = seq[min_index : min_index + k]
if longest_sub_seq[0] in OperatorGroupUtil.common_starts:
return longest_sub_seq
if longest_sub_seq is not None:
return longest_sub_seq

k -= 1
height_groups = []

return longest_sub_seq
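# In the running example the call returns ["matmul", "relu"]: with k == 2, the
# suffixes ranked 0 and 1 (starting at indices 2 and 0) share a length-2 prefix
# that begins with "matmul", which is in common_starts.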

@staticmethod
def get_decomposed_sub_seq(seq):
"""Get decomposed sub seq s by seq S such as s * R = S."""
if not seq:
return seq

decomposed_sub_seq = seq
seq_len = len(seq)
if seq_len == 1:
return decomposed_sub_seq
else:
for interval in range(2, seq_len + 1):
if seq_len % interval == 0:
repeated_times = seq_len // interval
decomposed_sub_seq = seq[0:interval]
decomposed = True
for j in range(1, repeated_times + 1):
sub_seq = seq[interval * (j - 1) : interval * j]
if sub_seq != decomposed_sub_seq:
decomposed = False
break
if decomposed:
return decomposed_sub_seq

return decomposed_sub_seq
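# For example, get_decomposed_sub_seq(["matmul", "relu", "matmul", "relu"])
# returns ["matmul", "relu"], the smallest unit that repeats to rebuild the
# whole sequence; a sequence with no such repetition is returned unchanged.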

@staticmethod
def replace_by_decomposed_seq(sub_seq, seq):
"""Replace seq by sub seq."""
if not sub_seq:
return seq

result = []
sub_seq_len = len(sub_seq)
i = 0
while i < len(seq):
if seq[i : i + sub_seq_len] == sub_seq:
result.append(seq[i : i + sub_seq_len])
i += sub_seq_len
else:
result.append(seq[i])
i += 1

return result
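# For example, replace_by_decomposed_seq(["matmul", "relu"],
# ["matmul", "relu", "matmul", "relu", "softmax"]) returns
# [["matmul", "relu"], ["matmul", "relu"], "softmax"]: each occurrence of the
# sub seq is folded into its own list and other items are kept as-is.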

@staticmethod
def stop_replace(seq):
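"""Return True once every element of seq has been grouped into a list."""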
for item in seq:
if not isinstance(item, list):
return False
return True


class RuleBasedTuner:
def __init__(self, dist_context, mode="train"):
self._dist_context = dist_context
self._mode = mode

def group_operators(self, ops):
"""
Group operators into layers.
Args:
ops (list): An operator list.
Returns:
List: A list of lists; each inner list contains the operators that belong to the same layer.
"""
seq = [op.type for op in ops]

while not OperatorGroupUtil.stop_replace(seq):
to_replace_seq = []
to_replace_idxes = []
has_append = False
for idx, item in enumerate(seq):
if not isinstance(item, list):
has_append = True
to_replace_seq.append(item)
to_replace_idxes.append(idx)
elif isinstance(item, list) and not has_append:
continue
elif isinstance(item, list) and has_append:
break

ranks = OperatorGroupUtil.get_ranks(to_replace_seq)
suffixes = OperatorGroupUtil.get_suffixes(ranks)
heights = OperatorGroupUtil.get_heights(suffixes, to_replace_seq)
longest_sub_seq = OperatorGroupUtil.get_longest_repeated_sub_seq(
suffixes, heights, to_replace_seq
)
has_merged = False
if longest_sub_seq is None:
for i in range(to_replace_idxes[-1] + 1, len(seq)):
if isinstance(seq[i], list):
seq[i] = to_replace_seq + seq[i]
has_merged = True
break
if not has_merged:
for i in range(to_replace_idxes[0] - 1, -1, -1):
if isinstance(seq[i], list):
seq[i].extend(to_replace_seq)
has_merged = True
break
if not has_merged:
seq = [to_replace_seq]
break

decomposed_sub_seq = OperatorGroupUtil.get_decomposed_sub_seq(
longest_sub_seq
)
to_replace_seq = OperatorGroupUtil.replace_by_decomposed_seq(
decomposed_sub_seq, to_replace_seq
)
result = seq[: to_replace_idxes[0]]
if not has_merged:
result.extend(to_replace_seq)
result.extend(seq[to_replace_idxes[-1] + 1 :])
seq = result

layers = []
idx = 0
for groups in seq:
layer = []
for op in groups:
layer.append(ops[idx])
idx += 1
layers.append(layer)

return layers
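# A minimal usage sketch (the program and block names below are assumptions for
# illustration, not part of this change):
#
#   tuner = RuleBasedTuner(dist_context)
#   layers = tuner.group_operators(serial_main_program.global_block().ops)
#   # layers[i] holds the operators that were grouped into the i-th layer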
@@ -118,5 +118,6 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
test_conditional_block_reshard)
py_test_modules(test_engine_api_error MODULES test_engine_api_error)
py_test_modules(test_fp16_assign MODULES test_fp16_assign)
py_test_modules(test_group_operators MODULES test_group_operators)

endif()