[AutoParallel] fix amp o1 (#46391) (#46481)
zhaoyinglia authored Sep 27, 2022
1 parent 5711bbe commit 5dab0b0
Showing 2 changed files with 27 additions and 16 deletions.
40 changes: 24 additions & 16 deletions python/paddle/distributed/passes/auto_parallel_amp.py
@@ -38,14 +38,18 @@ def __init__(self, block):
self._op_fp16_dict = {
} # op_id --> True/False. 'True' means that the current op is in fp16 mode.
self._var_name_dict = {} # fwd_op_id --> {old_name: cast_name}
self.is_train = False

def _is_fp16_op(self, op_id):
return self._op_fp16_dict.get(op_id, None)

def _build_stats(self, amp_lists, dist_context):
def _build_state(self, amp_lists, dist_context):
ops = self._block.ops
dist_op_context = dist_context.dist_op_context
for op in ops:
if int(op.attr('op_role')) == 257:
self.is_train = True

if int(op.attr('op_role')) == int(OpRole.Forward):
self._mark_black_white_ops(amp_lists)
elif int(op.attr('op_role')) == int(OpRole.Backward):
@@ -59,6 +63,8 @@ def _build_stats(self, amp_lists, dist_context):
elif int(op.attr('op_role')) == int(OpRole.Optimize):
break

return self.is_train

def _mark_black_white_ops(self, amp_lists):
"""
this function is modified from paddle.fluid.contrib.mixed_precision
@@ -546,23 +552,25 @@ def _apply_single_impl(self, main_program, startup_program, context):
set(self.get_attr("custom_black_list")),
set(self.get_attr("custom_black_varnames")))

amp_state = AMPState(main_program.global_block())
amp_state._build_stats(amp_lists, self.dist_context)

with paddle.static.program_guard(main_program, startup_program):
amp_state = AMPState(main_program.global_block())
is_train = amp_state._build_state(amp_lists, self.dist_context)

amp_state.cast_forward_program(self.dist_context)
amp_state.cast_backward_program(params_grads, self.dist_context)
# TODO (JZ-LIANG)support cast forward program only when inference
self._init_amp_var()
self._scale_loss()

if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context)

if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)

if is_train:
with paddle.static.program_guard(main_program, startup_program):
amp_state.cast_backward_program(params_grads, self.dist_context)
self._init_amp_var()
self._scale_loss()

if self.get_attr("use_dynamic_loss_scaling"
) or self.get_attr("init_loss_scaling") != 1.0:
grads, found_inf = _check_and_update_gradient(
params_grads, self._loss_scaling, self.dist_context)

if self.get_attr("use_dynamic_loss_scaling"):
self._update_loss_scaling(grads, found_inf)

def _init_amp_var(self):
self._loss_scaling = paddle.static.create_global_var(
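Editor's note: the sketch below restates the control flow the hunks above introduce, for readability. It is a standalone illustration, not Paddle's actual code: FakeOp, FakeAMPState, apply_amp_pass, and the role constants are invented stand-ins; only the `op_role == 257` check (presumably marking the loss-grad op of a training program) and the `is_train` gating mirror the real change.

OP_ROLE_FORWARD = 0
OP_ROLE_BACKWARD = 1
OP_ROLE_OPTIMIZE = 2
OP_ROLE_LOSS = 257  # the diff detects a training program via `op_role == 257`


class FakeOp:
    """Stand-in for an operator; only carries the op_role we need."""

    def __init__(self, role):
        self.role = role


class FakeAMPState:
    """Mirrors the shape of AMPState._build_state: scan the ops once, remember is_train."""

    def __init__(self, ops):
        self.ops = ops
        self.is_train = False

    def build_state(self):
        for op in self.ops:
            if op.role == OP_ROLE_LOSS:
                self.is_train = True
            if op.role == OP_ROLE_OPTIMIZE:
                break
        return self.is_train


def apply_amp_pass(ops):
    """Always cast the forward program; do backward-side AMP work only when training."""
    is_train = FakeAMPState(ops).build_state()
    steps = ["cast_forward_program"]
    if is_train:
        steps += ["cast_backward_program", "init_amp_var", "scale_loss"]
    return steps


print(apply_amp_pass([FakeOp(OP_ROLE_FORWARD), FakeOp(OP_ROLE_LOSS)]))  # training: full pipeline
print(apply_amp_pass([FakeOp(OP_ROLE_FORWARD)]))                        # inference: forward cast only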
@@ -97,6 +97,7 @@ def test_amp_pass(self):
3,
batch_size=self.batch_size)
amp_o1_losses = np.array(amp_o1_losses["loss"])
amp_o1_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o1_losses)

# mp2 amp-o2 training
@@ -105,6 +106,7 @@ def test_amp_pass(self):
3,
batch_size=self.batch_size)
amp_o2_losses = np.array(amp_o2_losses["loss"])
amp_o2_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o2_losses)

# mp2 amp-o3 training
@@ -113,6 +115,7 @@ def test_amp_pass(self):
3,
batch_size=self.batch_size)
amp_o3_losses = np.array(amp_o3_losses["loss"])
amp_o3_engine.evaluate(self.dataset, 3, batch_size=self.batch_size)
# self.check_results(mp_losses, amp_o3_losses)
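Editor's note: the added evaluate() calls exercise the fix, since evaluation builds a program with no backward or optimizer ops, exactly the case the is_train gate in the first file now handles. Below is a minimal sketch of that fit-then-evaluate pattern; `get_engine` is a hypothetical helper, and the fit/evaluate arguments simply follow the diff above.

import numpy as np


def run_amp_level(get_engine, dataset, batch_size, level):
    # Assumption: get_engine(level) returns an auto-parallel engine configured
    # with the given AMP level ("o1", "o2", or "o3"); it is not a real Paddle API.
    engine = get_engine(level)
    history = engine.fit(dataset, 3, batch_size=batch_size)
    losses = np.array(history["loss"])                   # training losses, as in the test
    engine.evaluate(dataset, 3, batch_size=batch_size)   # inference program: no backward ops,
                                                         # so the AMP pass skips loss scaling
    return losses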


