Skip to content

Commit

Permalink
Fix param@grad type error for amp in run_program
Browse files Browse the repository at this point in the history
  • Loading branch information
0x45f committed Mar 25, 2022
1 parent b79c6a9 commit 47cc9bc
Showing 1 changed file with 7 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,9 @@ def _train_amp_program(self):
"""
Lazy initialized property of train_amp_program.
"""
return self._append_backward_desc(self._infer_amp_program)
train_amp_program = self._append_backward_desc(self._infer_amp_program)
self._set_grad_type(self._params, train_amp_program)
return train_amp_program

@LazyInitialized
@switch_to_static_graph
Expand All @@ -224,7 +226,10 @@ def _train_pure_fp16_program(self):
"""
Lazy initialized property of _train_pure_fp16_program.
"""
return self._append_backward_desc(self._infer_pure_fp16_program)
train_pure_fp16_program = self._append_backward_desc(
self._infer_pure_fp16_program)
self._set_grad_type(self._params, train_pure_fp16_program)
return train_pure_fp16_program

@LazyInitialized
def _infer_program_id(self):
Expand Down

1 comment on commit 47cc9bc

@paddle-bot-old
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Congratulation! Your pull request passed all required CI. You could ask reviewer(s) to approve and merge. 🎉

Please sign in to comment.