[AutoParallel] adapt gradient merge pass (PaddlePaddle#45915)
* adapt gradient merge

* fix op_role

* fix strategy
zhaoyinglia authored and aoyulong committed Sep 17, 2022
1 parent f6dd201 commit d694628
Showing 8 changed files with 75 additions and 34 deletions.
25 changes: 18 additions & 7 deletions python/paddle/distributed/auto_parallel/engine.py
@@ -60,14 +60,10 @@ def __init__(self,
strategy=None,
user_tuning_config=None):
self.model = model
self.strategy = strategy or fleet.DistributedStrategy()
self.inputs_spec = self._validate_spec(inputs_spec)
self.labels_spec = self._validate_spec(labels_spec)
self.cluster = cluster
if self.cluster is None:
self.cluster = get_default_cluster()
self.strategy = strategy
if self.strategy is None:
self.strategy = fleet.DistributedStrategy()
self.cluster = cluster or get_default_cluster()
self._user_tuning_config = user_tuning_config

self._executor = None
@@ -433,7 +429,7 @@ def fit(self,
break

train_logs["step: {:d} "] = step
if lr_scheduler is not None:
if lr_scheduler is not None and step % self.k_steps == 0:
lr_scheduler.step()
try:
train_logs["lr: {:5e} "] = self._lr_optimizer.get_lr()
@@ -551,6 +547,12 @@ def _create_dataloader(self,
epochs=1,
steps_per_epoch=None,
collate_fn=None):

if self.strategy.gradient_merge and batch_size is not None:
assert batch_size % self.k_steps == 0, \
"Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(batch_size, self.k_steps)
batch_size //= self.k_steps

dist_main_prog = self._dist_main_progs[self.mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self.mode][self._cur_rank]
dist_context = self._dist_contexts[self.mode]
@@ -612,13 +614,22 @@ def _create_dataloader(self,

def _validate_spec(self, specs):
specs = to_list(specs)
self.k_steps = 1
if self.strategy.gradient_merge:
self.k_steps = self.strategy.gradient_merge_configs['k_steps']
if specs is not None:
for i, spec in enumerate(specs):
assert isinstance(spec, InputSpec)
if spec.name is None:
raise ValueError(
"Requires Input[{}].name != None, but receive `None` with {}."
.format(i, spec))
if self.k_steps > 1:
shape = list(spec.shape)
assert shape[0] % self.k_steps == 0, \
"Requires batch_size[{}] to be divisible by k_steps[{}].".format(spec.shape[0], self.k_steps)
shape[0] //= self.k_steps
spec.shape = shape
return specs

def _is_local_var(self, var):
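The engine changes above tie gradient merge to the data pipeline: when `strategy.gradient_merge` is on, `k_steps` is read from `gradient_merge_configs`, both the dataloader batch size and the first dimension of each `InputSpec` must be divisible by it and are divided down, and `lr_scheduler.step()` is only called when `step % k_steps == 0`. Below is a minimal, framework-free sketch of that bookkeeping; `_FakeStrategy` and the helper name are hypothetical stand-ins, not Paddle API.

```python
# Hypothetical stand-in for the strategy fields referenced in the diff.
class _FakeStrategy:
    gradient_merge = True
    gradient_merge_configs = {"k_steps": 4}


def split_batch_for_gradient_merge(batch_size, strategy):
    """Return (micro_batch_size, k_steps) along the lines of the engine.py change."""
    k_steps = 1
    if strategy.gradient_merge:
        k_steps = strategy.gradient_merge_configs["k_steps"]
    if batch_size is not None and k_steps > 1:
        assert batch_size % k_steps == 0, \
            "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
                batch_size, k_steps)
        batch_size //= k_steps  # each micro step sees batch_size / k_steps samples
    return batch_size, k_steps


print(split_batch_for_gradient_merge(32, _FakeStrategy()))  # (8, 4)
```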
16 changes: 13 additions & 3 deletions python/paddle/distributed/auto_parallel/parallelizer.py
@@ -84,7 +84,7 @@ def __init__(self, fleet):
self._need_rank_mapping = os.getenv("PADDLE_NEED_RANK_MAPPING")
self._need_rank_mapping = True if self._need_rank_mapping and \
self._need_rank_mapping.lower() == 'true' else False
self._pass_context = None
# self._pass_context = None

def _remove_distributed_attrs(self, main_program):
suffix = core.kAutoParallelSuffix()
@@ -143,10 +143,11 @@ def _generate_backward(self, main_program, startup_program, loss,

def _apply_optimize(self, main_program, startup_program, params_grads):

optimizer = copy.deepcopy(self._optimizer)
with program_guard(main_program, startup_program):
optimize_ops = copy.deepcopy(
self._optimizer).apply_gradients(params_grads)
optimize_ops = optimizer.apply_gradients(params_grads)

self._dist_context._lr_optimizer = optimizer
# update completion
self._completer = Completer(self._dist_context)
self._completer.complete_update_annotation(main_program)
@@ -165,6 +166,15 @@ def _apply_post_optimization_passes(main_program, startup_program,
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")

config = copy.deepcopy(self._dist_strategy.sharding_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
config["rank_id"] = rank
auto_parallel_clip_pass = new_pass("auto_parallel_grad_clip", config)
auto_parallel_clip_pass.apply([main_program], [startup_program],
self._pass_context)

if self._dist_strategy.gradient_merge:
config = copy.deepcopy(self._dist_strategy.gradient_merge_configs)
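Both parallelizers now relay the sharded parameter list through the pass context: the sharding pass publishes `params_grads`, and the new `auto_parallel_grad_clip` pass consumes it instead of re-scanning the block. The sketch below mimics that handshake with hypothetical `_Mini*` classes; it is not the actual `new_pass` machinery.

```python
# Hypothetical miniature of the pass-context handshake; names are illustrative.
class _MiniPassContext:
    def __init__(self):
        self._attrs = {}

    def set_attr(self, key, value):
        self._attrs[key] = value

    def get_attr(self, key, default=None):
        return self._attrs.get(key, default)


class _MiniShardingPass:
    def apply(self, context, params_grads, local_params):
        # Keep only the (param, grad) pairs owned by this rank's shard,
        # then publish them for later passes.
        kept = [(p, g) for p, g in params_grads if p in local_params]
        context.set_attr("params_grads", kept)


class _MiniGradClipPass:
    def apply(self, context):
        params_grads = context.get_attr("params_grads")
        assert params_grads is not None, "sharding pass must run first"
        return ["clip {} by {}".format(p, g) for p, g in params_grads]


ctx = _MiniPassContext()
_MiniShardingPass().apply(ctx, [("w", "w@GRAD"), ("b", "b@GRAD")], {"w"})
print(_MiniGradClipPass().apply(ctx))  # ['clip w by w@GRAD']
```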
2 changes: 1 addition & 1 deletion python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -230,9 +230,9 @@ def _apply_post_optimization(self, main_program, startup_program, rank,
config)
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
params_grads = self._pass_context.get_attr("params_grads")

# GradClip is train-only optimization

if self._mode == "train":
config = copy.deepcopy(self._strategy.sharding_configs)
config["dist_context"] = self._dist_context
10 changes: 5 additions & 5 deletions python/paddle/distributed/passes/auto_parallel_fp16.py
@@ -442,7 +442,7 @@ def _check_and_update_gradient(grads, loss_scaling, name, dist_context):

inputs = {'X': grads, 'Scale': loss_scaling}
outputs = {'Out': grads, 'FoundInfinite': found_inf}
attrs = {'op_role': OpRole.Backward}
attrs = {'op_role': OpRole.Optimize}
new_op = main_block.append_op(type='check_finite_and_unscale',
inputs=inputs,
outputs=outputs,
@@ -575,18 +575,18 @@ def _apply_single_impl(self, main_program, startup_program, context):
) or self.get_attr("init_loss_scaling") != 1.0:
found_infs = []
if fp32_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp32 = _check_and_update_gradient(
fp32_grads, self._loss_scaling, "@fp32",
self.dist_context)
found_infs.append(found_inf_fp32)
if fp16_grads:
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
_, found_inf_fp16 = _check_and_update_gradient(
fp16_grads, self._loss_scaling, "@fp16",
self.dist_context)
found_infs.append(found_inf_fp16)
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
block = main_program.global_block()

all_infs = paddle.fluid.layers.concat(found_infs)
@@ -608,7 +608,7 @@ def _apply_single_impl(self, main_program, startup_program, context):
block, self.dist_context)

if self.get_attr("use_dynamic_loss_scaling"):
with main_program._backward_role_guard():
with main_program._optimized_guard([]):
if fp32_grads:
self._update_loss_scaling(fp32_grads, found_inf)
if fp16_grads:
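One possible reading of the op_role fix in this file (the commit message only says "fix op_role"): `check_finite_and_unscale` and the loss-scaling update are now emitted under `_optimized_guard` with `OpRole.Optimize`, so a pass that separates per-micro-batch backward work from once-per-merge optimizer work, such as gradient merge, treats them as optimizer-stage ops. The toy filter below is purely illustrative and is not Paddle code.

```python
# Toy illustration: ops tagged "Optimize" only execute on the merge step,
# while "Backward" ops run for every micro batch.
OPS = [
    ("matmul_grad", "Backward"),
    ("check_finite_and_unscale", "Optimize"),  # re-tagged by this commit
    ("update_loss_scaling", "Optimize"),
    ("adam", "Optimize"),
]


def ops_for_step(step, k_steps):
    on_merge_step = step % k_steps == 0
    return [name for name, role in OPS if role == "Backward" or on_merge_step]


print(ops_for_step(1, 4))  # ['matmul_grad']
print(ops_for_step(4, 4))  # all four ops
```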
6 changes: 5 additions & 1 deletion python/paddle/distributed/passes/auto_parallel_grad_clip.py
@@ -207,13 +207,16 @@ def __init__(self):
super(ClipGradByGloblNormPass, self).__init__()
self.set_attr("rank_id", None)
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)

def _check_self(self):
if self.get_attr("dist_context") is None:
return False
dist_context = self.get_attr("dist_context")
if dist_context._lr_optimizer._grad_clip is None:
return False
if self.get_attr("params_grads") is None:
return False
return True

def _check_conflict(self, other_pass):
@@ -223,7 +226,8 @@ def _apply_single_impl(self, main_program, startup_program, context):
dist_context = self.get_attr("dist_context", None)
rank_id = self.get_attr("rank_id", None)
block = main_program.global_block()
dist_params_grads = _get_params_grads(block)
dist_params_grads = self.get_attr("params_grads", None)
# dist_params_grads = _get_params_grads(block)

self.clip_helper = ClipHelper(dist_params_grads, rank_id, block,
dist_context)
@@ -55,13 +55,6 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
return optimize_ops_desc


def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())


def _get_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
@@ -147,8 +140,6 @@ def _append_gradient_merge_backward_op(
param.type != core.VarDesc.VarType.SELECTED_ROWS
), "SELECTED_ROWS is not supported in GradientMergeOptimizer for now"

_remove_op_role_var(param, grad)

# {grad.name: gradient_merge_var.name} to rename opt inputs
grad_to_gradient_merge = {}
# {param: gradient_merge_var} to insert scale op and fill_constant op
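For reference, the behaviour the gradient-merge pass builds in static-graph form (a step-counting cond variable, per-parameter accumulation buffers, and a scaled update) corresponds roughly to the eager-mode loop below. This is a simplified sketch only; whether the merged gradient is averaged or summed in practice depends on the pass configuration.

```python
# Simplified eager-mode equivalent of gradient merge (illustrative only).
def train_with_gradient_merge(params, micro_batch_grads, k_steps, lr=0.1):
    accum = {name: 0.0 for name in params}       # one merge buffer per param
    for step, grads in enumerate(micro_batch_grads, start=1):
        for name, g in grads.items():
            accum[name] += g                      # merge (accumulate) gradients
        if step % k_steps == 0:                   # the cond-var check in the pass
            for name in params:
                params[name] -= lr * accum[name] / k_steps  # averaged update
                accum[name] = 0.0                 # reset buffers after the update
    return params


print(train_with_gradient_merge({"w": 1.0},
                                [{"w": 0.2}, {"w": 0.4}, {"w": 0.2}, {"w": 0.4}],
                                k_steps=4))  # w ends near 0.97 after one merged update
```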
38 changes: 31 additions & 7 deletions python/paddle/distributed/passes/auto_parallel_sharding.py
@@ -59,6 +59,7 @@ def __init__(self):
self.varname_to_sharding_info = {}
self.partial_sharding = False
self.outer_dp_group = None
self.shared_params_grads = []

def _check_self(self):
if self.get_attr("dist_context") is None:
@@ -94,6 +95,8 @@ def _apply_single_impl(self, main_program, startup_program, context):
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)

context.set_attr("params_grads", self.shared_params_grads)

def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
@@ -148,13 +151,10 @@ def _build_sharding_infos(self, params_grads):

self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
params_in_group = [p for p, g in params_grads]
assert len(params_in_group) == len(
set(params_in_group)), "found duplicated param in params_grads"
sharding_info = ShardingInfo(sharding_group, self.global_rank,
params_in_group)
params_grads)
self.sharding_infos.append(sharding_info)
for param in params_in_group:
for param in sharding_info.params:
self.varname_to_sharding_info[param.name] = sharding_info

def _shard_optimizer(self, main_block, startup_block, params_grads,
Expand Down Expand Up @@ -201,6 +201,7 @@ def _shard_amp_related_op_and_vars(self, main_block, pass_context):
op.desc.set_output('Out', reversed_x)
else:
if op.type == "check_finite_and_unscale":
op_role = op.attr('op_role')
out_name = op.output_arg_names[0]
out_var = main_block.vars[out_name]
main_block._remove_op(idx, sync=False)
@@ -212,6 +213,7 @@ def _shard_amp_related_op_and_vars(self, main_block, pass_context):
"shape": out_var.shape,
"dtype": out_var.dtype,
"value": 0,
OP_ROLE_KEY: op_role,
})
else:
main_block._remove_op(idx, sync=False)
@@ -313,6 +315,9 @@ def _shard_optimizer_ops_and_states(self, main_block, startup_block):
if varname != param_name
])
main_block._remove_op(idx, sync=False)
else:
self.shared_params_grads.append(
self._get_param_grad(param_name))

for idx, op in reversed(list(enumerate(startup_block.ops))):
if len(op.output_arg_names) == 1 and op.output_arg_names[
@@ -365,6 +370,13 @@ def _is_parameter_in_local_shard(self, param_name):
sharding_info = self.varname_to_sharding_info[param_name]
return sharding_info.is_in_local_shard(param_name)

def _get_param_grad(self, param_name):
assert param_name in self.varname_to_sharding_info
sharding_info = self.varname_to_sharding_info[param_name]
p_g = sharding_info.get_param_grad(param_name)
assert p_g is not None
return p_g

def _shard_gradient_synchronization(self, main_block):

if self.stage < 2:
@@ -705,9 +717,13 @@ def shard_parameters(params, group_size):

class ShardingInfo(object):

def __init__(self, group, rank, params):
def __init__(self, group, rank, params_grads):
self.group = group
self.params = params
self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(set(
self.params_grads)), "found duplicated param in params_grads"

self.params = [p for p, _ in params_grads]
self.param_names = [p.name for p in self.params]
self.group_size = group.nranks
self.global_rank = rank
@@ -762,3 +778,11 @@ def get_broadcast_vars_and_param_usage(self, block):
if usage > 0:
broadcast_vars.add(param)
return broadcast_vars, param_usage

def get_param_grad(self, param_name):
if not self.is_in_local_shard(param_name):
raise ValueError(
"param[{}] not in current rank.".format(param_name))
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
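The `ShardingInfo` change above stores the full `(param, grad)` pairs keyed by parameter name, so `_shard_optimizer_ops_and_states` can collect the pairs kept on the local rank and hand them to the grad-clip pass. A stripped-down, hypothetical mirror of that lookup (not the real class):

```python
# Hypothetical mirror of the ShardingInfo bookkeeping added in this commit.
class MiniShardingInfo:
    def __init__(self, params_grads, local_param_names):
        self.params_grads = {p: (p, g) for p, g in params_grads}
        assert len(self.params_grads) == len(params_grads), \
            "found duplicated param in params_grads"
        self._local = set(local_param_names)

    def is_in_local_shard(self, param_name):
        return param_name in self._local

    def get_param_grad(self, param_name):
        if not self.is_in_local_shard(param_name):
            raise ValueError("param[{}] not in current rank.".format(param_name))
        if param_name not in self.params_grads:
            raise ValueError("param[{}] not in params_grads".format(param_name))
        return self.params_grads[param_name]


info = MiniShardingInfo([("w", "w@GRAD"), ("b", "b@GRAD")], ["w"])
print(info.get_param_grad("w"))  # ('w', 'w@GRAD')
```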
@@ -178,6 +178,7 @@ def get_gpt_model(self, strategy, place, batch_size, sequence_len,
preds = model(tokens, position_ids, attention_mask)
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)

if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
@@ -188,7 +189,7 @@ def get_gpt_model(self, strategy, place, batch_size, sequence_len,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=None)
grad_clip=clip)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
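The test change swaps `grad_clip=None` for a real clip object so the new `auto_parallel_grad_clip` pass has something to act on. For context, building the optimizer that way in user code looks roughly like the sketch below, assuming Paddle's public `paddle.optimizer.Adam` / `paddle.nn.ClipGradByNorm` API (the test itself goes through `fleet.distributed_optimizer` as shown above, and the learning rate here is illustrative; only the clip, beta and epsilon values come from the diff).

```python
import paddle

paddle.enable_static()  # the auto-parallel tests run in static-graph mode

# Adam with gradient clipping by norm, mirroring the updated test.
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
optimizer = paddle.optimizer.Adam(learning_rate=0.001,  # illustrative value
                                  beta1=0.9,
                                  beta2=0.999,
                                  epsilon=1e-08,
                                  grad_clip=clip)
```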
