[hybrid] optimizer sharding support optimize cast #35878

Merged
@@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import copy
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole, is_update_op
from paddle.fluid import core, unique_name
from .shard import Shard

__all__ = []

@@ -23,11 +25,8 @@ class OffloadHelper(object):
     cuda_place_type = 1
     cuda_pinned_place_type = 2
 
-    def __init__(self):
-        pass
-        "0: dst is on CPUPlace. "
-        "1: dst is on CUDAPlace. "
-        "2: dst is on CUDAPinnedPlace. "
+    def __init__(self, ring_id=None):
+        self.ring_id = ring_id
 
     def _insert_cast_op(self, block, idx, src_name, dst_name):
         src_var = block.var(src_name)
@@ -50,6 +49,21 @@ def _insert_cast_op(self, block, idx, src_name, dst_name):
OP_ROLE_KEY: OpRole.Optimize
})

def _insert_broadcast_op(self, block, idx, param):
if self.ring_id is None:
return
block._insert_op_without_sync(
idx,
type="c_broadcast",
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': self.ring_id,
'root': 0,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward,
})

def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
@@ -206,20 +220,25 @@ def remove_param(input_name):
 
         # step5: startup_block add offload
         visited_vars = set()
+        # FIXME(wangxi): should insert in idx, need move comm init to the head.
+        insert_idx = len(startup_block.ops)
         for idx, op in reversed(list(enumerate(startup_block.ops))):
             for out_name in op.output_arg_names:
                 if out_name in visited_vars:
                     continue
 
                 if out_name in param_name_to_offload_name:
                     var_name = out_name
-                    # FIXME(wangxi): offload should insert after broadcast param
                     if offload:
                         offload_var_name = param_name_to_offload_name[var_name]
-                        self._insert_offload_op(startup_block, idx + 1,
+                        self._insert_offload_op(startup_block, insert_idx,
                                                 var_name, offload_var_name)
-                    self._insert_cast_op(startup_block, idx + 1, var_name,
+                    self._insert_cast_op(startup_block, insert_idx, var_name,
                                          param_to_fp16[var_name])
+                    # NOTE(wangxi): cast and offload should insert after broadcast param.
+                    # the insert op order is: broadcast, cast, offload
+                    self._insert_broadcast_op(startup_block, insert_idx,
+                                              var_name)
 
                 visited_vars.add(out_name)
 
@@ -303,3 +322,181 @@ def offload(self, block, startup_block):

block._sync_with_cpp()
startup_block._sync_with_cpp()

def opt_sharding_cast_fp32param(self,
block,
startup_block,
params,
offload=False):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)

(pout,) = adam(p)
(p_fp16) = cast(p)
broadcast(p_fp16)
"""
global_params = set()
local_params = set()
param_to_fp16 = dict()
# recompute_var which need rename to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()

def remove_param(input_name):
global_params.remove(input_name)
if input_name in local_params:
local_params.remove(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)

# step1: record param
global_params = set(params)
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
local_params.add(param)

# step2: remove param which can't offload and
# record param->fp16param, fp16param->recompute_var
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
# TODO (Yuang Liu): tmp solution for fuse_grad_merge + optimize_cast
if op.type == 'coalesce_tensor':
continue
for input_name in op.desc.input_arg_names():
if input_name not in global_params:
continue

# param which will be used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue

# param is only used by cast op,
# which to cast fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
Contributor: Better to use a global variable to record the 'cast_fp16' rule; otherwise, if this pattern changes in AMP, we would have to change it everywhere in sharding.

Contributor Author: Got it, good idea.
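A minimal sketch of the reviewer's suggestion, assuming the naming rule is lifted into a shared module-level constant; `FP16_CAST_SUFFIX` and the small helper below are illustrative names, not code from this PR:

```python
# Hypothetical shared constant so the AMP fp16-cast naming rule is defined in one place.
FP16_CAST_SUFFIX = '.cast_fp16'


def is_fp16_cast_output(output_name):
    # Matches outputs such as 'fc_0.w_0.cast_fp16' produced by AMP's cast ops.
    return FP16_CAST_SUFFIX in output_name
```

The substring check above and the later `input_name + '.cast_fp16'` assertion could then both reference the same constant, so a naming change in AMP would only touch one line.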

remove_param(input_name)
continue

if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param

param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if is_update_op(op):
param = op.desc.input("Param")[0]
if param not in global_params:
continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
if offload:
self._create_offload_var(param, offload_var_name,
[block, startup_block])

# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param,
offload_var_name)

assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])

if offload:
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)

continue

# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in global_params:
block._remove_op(idx, sync=False)
continue

# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])

# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)

# step5: remove fp32 param which not need
for idx, op in enumerate(block.ops):
if op.type not in ['coalesce_tensor', 'c_broadcast']:
continue
for input_name in op.desc.input_arg_names():
if input_name in param_to_fp16:
op._rename_input(input_name, param_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in param_to_fp16:
op._rename_output(output_name, param_to_fp16[output_name])

for param in global_params:
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True

if param not in local_params:
block._remove_var(param, sync=False)

# step6: startup_block add offload
visited_vars = set()
insert_idx = len(startup_block.ops)
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars: continue

if out_name in param_to_fp16:
var_name = out_name
if offload:
self._insert_offload_op(
startup_block, idx + 1, var_name,
param_name_to_offload_name[var_name])

self._insert_cast_op(startup_block, insert_idx, var_name,
param_to_fp16[var_name])

self._insert_broadcast_op(startup_block, insert_idx,
var_name)

if var_name not in local_params:
param = startup_block.var(out_name)
param.persistable = False

visited_vars.add(out_name)

block._sync_with_cpp()
startup_block._sync_with_cpp()
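For orientation, a rough usage sketch of the updated `OffloadHelper` above; `dp_ring_id`, `main_block`, `startup_block`, and `fp32_params` are placeholders for objects the sharding pass already owns, not identifiers from this PR:

```python
# Hypothetical call site inside the sharding meta-optimizer (names are placeholders).
offload_helper = OffloadHelper(ring_id=dp_ring_id)

# Rewrite the fp32->fp16 cast pattern for optimizer sharding: adam keeps updating
# the fp32 master weight and a single cast regenerates the fp16 weight after the
# update (see the docstring of opt_sharding_cast_fp32param), while c_broadcast
# ops are appended to the startup program so every rank starts from the same
# parameter values. offload=True additionally inserts offload/fetch ops so the
# fp32 master weight is parked off the GPU between updates.
offload_helper.opt_sharding_cast_fp32param(
    main_block, startup_block, fp32_params, offload=False)
```

Since `ring_id` defaults to `None` and `_insert_broadcast_op` returns early in that case, existing callers of the cast/offload helpers keep their previous behavior.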
@@ -14,7 +14,7 @@
import paddle
from paddle.fluid import core, unique_name
from functools import reduce
-from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import is_loss_grad_op, is_backward_op, is_optimizer_op
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY

import re
@@ -366,6 +366,24 @@ def insert_allreduce_ops(block,


class FuseHelper(object):
@staticmethod
def sort_vars_by_dtype(block, vars_name):
fp32_vars = []
fp16_vars = []
other_vars = []
for var in vars_name:
dtype = block.var(var).dtype
if dtype == paddle.float32:
fp32_vars.append(var)
elif dtype == paddle.float16:
fp16_vars.append(var)
else:
other_vars.append(var)
assert len(other_vars) == 0, "only support fp32/fp16 vars for fuse"

fp32_vars.extend(fp16_vars)
return fp32_vars
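A small, hypothetical illustration of `sort_vars_by_dtype`: placing all fp32 names before all fp16 names keeps same-dtype variables adjacent, which presumably lets the size-based grouping below produce dtype-homogeneous fused buffers. The parameter names and dtypes are made up for the example:

```python
# Illustrative only: these variable names and their dtypes are assumptions.
# Suppose block holds fc_0.w_0 (fp32), fc_1.w_0 (fp32) and fc_0.w_0.cast_fp16 (fp16).
ordered = FuseHelper.sort_vars_by_dtype(
    block, ['fc_0.w_0.cast_fp16', 'fc_0.w_0', 'fc_1.w_0'])
# ordered == ['fc_0.w_0', 'fc_1.w_0', 'fc_0.w_0.cast_fp16']
```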

@staticmethod
def get_fused_groups(block, vars_name, fuse_size=32.):
""" coalesce tensor, get fused group """
@@ -639,6 +657,54 @@ def insert_broadcast_param_ops(block,
return param_in_this_device


def fuse_opt_broadcast_param_ops(block,
ring_id,
shard,
op_role=OpRole.Optimize,
strategy=None):
"""
fuse optimizer sharding broadcast param ops
"""
if strategy is None or not strategy.fuse_all_reduce_ops:
return

fuse_size = strategy.fuse_grad_size_in_MB

nranks = shard.worker_num
device_to_vars = [[] for _ in range(nranks)]

for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op) or op.type != 'c_broadcast':
break
var = op.input_arg_names[0]
root_id = op.attr('root')
device_to_vars[root_id].insert(0, var)
block._remove_op(idx, sync=False)

insert_idx = idx + 1
for root_id, vars_name in enumerate(device_to_vars):
vars_name = FuseHelper.sort_vars_by_dtype(block, vars_name)
groups = FuseHelper.get_fused_groups(block, vars_name, fuse_size)

fused_vars, insert_num = FuseHelper.insert_coalesce_tensor(
block, insert_idx, groups, op_role, prefix="Param")

for fused_var in fused_vars:
block._insert_op_without_sync(
insert_idx + insert_num,
type='c_broadcast',
inputs={'X': fused_var},
outputs={'Out': fused_var},
attrs={
'ring_id': ring_id,
'root': root_id,
'use_calc_stream': True,
OP_ROLE_KEY: op_role
})

block._sync_with_cpp()
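A hedged sketch of how the fused broadcast pass might be invoked once the per-parameter `c_broadcast` ops have been emitted; `main_block`, `sharding_ring_id`, `shard`, and `user_defined_strategy` stand in for objects the sharding optimizer already owns:

```python
# Hypothetical call site (placeholders, not code from this PR). The pass is a
# no-op unless the strategy enables fuse_all_reduce_ops.
fuse_opt_broadcast_param_ops(
    main_block,
    sharding_ring_id,
    shard,
    op_role=OpRole.Optimize,
    strategy=user_defined_strategy)
```

As written above, the pass walks backwards over the trailing optimizer-role `c_broadcast` ops, buckets their inputs by root rank, sorts each bucket fp32-first, coalesces it into fused buffers, and re-emits one `c_broadcast` per fused buffer on the calculation stream.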


def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)