[hybrid] optimizer sharding support optimize cast (PaddlePaddle#35878)
wangxicoding authored and AnnaTrainingG committed Sep 29, 2021
1 parent 8abb977 commit ad33a69
Showing 5 changed files with 440 additions and 54 deletions.
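
The commit extends the sharding pass's `optimize_cast` optimization to optimizer sharding: instead of every rank keeping a full fp32 master copy and casting it to fp16 in the forward pass, the rank that owns a parameter's optimizer state updates the fp32 copy, casts it to fp16 after the update, and broadcasts the fp16 parameter to the other ranks. Below is a rough sketch of how the feature might be switched on from user code; the two `sharding_configs` keys flagged in the comments are assumptions about the fleet sharding pass of this era, not something this diff shows.

```python
import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.amp = True        # fp16 training, so parameters get fp32->fp16 cast ops
strategy.sharding = True
strategy.sharding_configs = {
    "dp_degree": 8,
    "sharding_degree": 1,
    "_dp_as_optimizer_sharding": True,  # assumed key: shard optimizer states over DP ranks
    "optimize_cast": True,              # assumed key: move the fp32->fp16 cast after the update
}
strategy.fuse_all_reduce_ops = True     # required for fuse_opt_broadcast_param_ops to kick in
strategy.fuse_grad_size_in_MB = 32

optimizer = paddle.optimizer.Adam(learning_rate=1e-3)
optimizer = fleet.distributed_optimizer(optimizer, strategy)
# ... build the model and call optimizer.minimize(loss) as usual ...
```
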
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name
from .shard import Shard

__all__ = []

@@ -23,11 +25,8 @@ class OffloadHelper(object):
cuda_place_type = 1
cuda_pinned_place_type = 2

def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def __init__(self, ring_id=None):
self.ring_id = ring_id

def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
@@ -50,6 +49,21 @@ def _insert_cast_op(self, block, idx, src_name, dst_name):
OP_ROLE_KEY: OpRole.Optimize
})

def _insert_broadcast_op(self, block, idx, param):
if self.ring_id is None:
return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})

def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
@@ -206,20 +220,25 @@ def remove_param(input_name):

# step5: add offload ops to startup_block
visited_vars = set()
# FIXME(wangxi): should insert at idx; the comm init ops need to be moved to the head first.
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue

if out_name in param_name_to_offload_name:
var_name = out_name
# FIXME(wangxi): offload should insert after broadcast param
if offload:
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1,
self._insert_offload_op(startup_block, insert_idx,
var_name, offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])
# NOTE(wangxi): cast and offload should be inserted after the param broadcast.
# The resulting op order is: broadcast, cast, offload.
self._insert_broadcast_op(startup_block, insert_idx,
var_name)

visited_vars.add(out_name)

@@ -303,3 +322,181 @@ def offload(self, block, startup_block):

block._sync_with_cpp()
startup_block._sync_with_cpp()

def opt_sharding_cast_fp32param(self,
block,
startup_block,
params,
offload=False):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(pout,) = adam(p)
(p_fp16) = cast(p)
broadcast(p_fp16)
"""
global_params = set()
local_params = set()
param_to_fp16 = dict()
# recompute vars which need to be renamed to their fp16 params
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()

def remove_param(input_name):
global_params.remove(input_name)
if input_name in local_params:
local_params.remove(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)

# step1: record param
global_params = set(params)
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
local_params.add(param)

# step2: remove params which can't be offloaded, and
# record param->fp16_param, fp16_param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
# TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
if op.type == 'coalesce_tensor':
continue
for input_name in op.desc.input_arg_names():
if input_name not in global_params:
continue

# param which will be used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue

# param is only used by a cast op,
# which casts the fp32 param to an fp16 param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue

if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param

param_name_to_offload_name = dict()
# step3: add offload and cast ops to main_block;
# change recompute vars to fp16 and remove the cast(param)-to-fp16 ops
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in global_params:
continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])

# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)

assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])

if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)

continue

# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in global_params:
block._remove_op(idx, sync=False)
continue

# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])

# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)

# step5: remove fp32 params which are no longer needed
for idx, op in enumerate(block.ops):
if op.type not in ['coalesce_tensor', 'c_broadcast']:
continue
for input_name in op.desc.input_arg_names():
if input_name in param_to_fp16:
op._rename_input(input_name, param_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in param_to_fp16:
op._rename_output(output_name, param_to_fp16[output_name])

for param in global_params:
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True

if param not in local_params:
block._remove_var(param, sync=False)

# step6: add offload ops to startup_block
visited_vars = set()
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars: continue

if out_name in param_to_fp16:
var_name = out_name
if offload:
self._insert_offload_op(
startup_block, idx + 1, var_name,
param_name_to_offload_name[var_name])

self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])

self._insert_broadcast_op(startup_block, insert_idx,
var_name)

if var_name not in local_params:
param = startup_block.var(out_name)
param.persistable = False

visited_vars.add(out_name)

block._sync_with_cpp()
startup_block._sync_with_cpp()
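
The new `opt_sharding_cast_fp32param` pass is meant to be driven from the sharding meta-optimizer. A minimal sketch of the assumed call site is below; `sharding_ring_id` and `fp32_params` are illustrative names for values the pass already tracks (the communication ring to broadcast over and the full fp32 parameter list), not identifiers taken from this diff.

```python
# Hedged sketch of the assumed call site inside the sharding meta-optimizer.
helper = OffloadHelper(ring_id=sharding_ring_id)

# Rewrites main_block so that only the rank holding a parameter's optimizer
# state keeps the fp32 master copy and re-casts it to fp16 right after the
# update; rewrites startup_block to broadcast each initialized fp32 parameter
# from root rank 0 of the ring and then cast it to fp16, so all ranks start
# from identical weights.
helper.opt_sharding_cast_fp32param(
    main_block, startup_block, fp32_params, offload=False)
```
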
@@ -14,7 +14,7 @@
import paddle
from paddle.fluid import core, unique_name
from functools import reduce
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re
@@ -366,6 +384,24 @@ def insert_allreduce_ops(block,


class FuseHelper(object):
@staticmethod
def sort_vars_by_dtype(block, vars_name):
fp32_vars = []
fp16_vars = []
other_vars = []
for var in vars_name:
dtype = block.var(var).dtype
if dtype == paddle.float32:
fp32_vars.append(var)
elif dtype == paddle.float16:
fp16_vars.append(var)
else:
other_vars.append(var)
assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse"

fp32_vars.extend(fp16_vars)
return fp32_vars

@staticmethod
def get_fused_groups(block, vars_name, fuse_size=32.):
""" coalesce tensor, get fused group """
@@ -639,6 +657,54 @@ def insert_broadcast_param_ops(block,
return param_in_this_device


def fuse_opt_broadcast_param_ops(block,
ring_id,
shard,
op_role=OpRole.Optimize,
strategy=None):
"""
fuse optimizer sharding broadcast param ops
"""
if strategy is None or not strategy.fuse_all_reduce_ops:
return

fuse_size = strategy.fuse_grad_size_in_MB

nranks = shard.worker_num
device_to_vars = [[] for _ in range(nranks)]

for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op) or op.type != 'c_broadcast':
break
var = op.input_arg_names[0]
root_id = op.attr('root')
device_to_vars[root_id].insert(0, var)
block._remove_op(idx, sync=False)

insert_idx = idx + 1
for root_id, vars_name in enumerate(device_to_vars):
vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Param")

for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_broadcast',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'root': root_id,
'use_calc_stream': True,
OP_ROLE_KEY: op_role
})

block._sync_with_cpp()


def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
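
`fuse_opt_broadcast_param_ops` then collapses the per-parameter `c_broadcast` ops emitted by optimizer sharding into a few fused broadcasts per root rank: it pops the trailing broadcasts off the block, sorts each root's parameters so fp32 vars come before fp16 vars (e.g. `['w0.cast_fp16', 'b0', 'w1.cast_fp16']` becomes `['b0', 'w0.cast_fp16', 'w1.cast_fp16']`, presumably so a fused buffer holds a single dtype), coalesces them into `fuse_grad_size_in_MB`-sized groups, and re-emits one `c_broadcast` per fused buffer. A hedged usage sketch, with illustrative names standing in for state the sharding meta-optimizer already holds:

```python
# Hedged sketch of the assumed call site; `main_block`, `dp_ring_id`, `shard`
# and `user_defined_strategy` are illustrative stand-ins, not names from this diff.
fuse_opt_broadcast_param_ops(
    main_block,          # block whose trailing ops are the per-param c_broadcasts
    dp_ring_id,          # communication ring shared by the optimizer-sharding ranks
    shard,               # Shard object; shard.worker_num gives the number of ranks
    op_role=OpRole.Optimize,
    strategy=user_defined_strategy)  # no-op unless strategy.fuse_all_reduce_ops is set
```
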