From 47cc9bc8d995805669212f333f6cd110e7bdc481 Mon Sep 17 00:00:00 2001 From: 0x45f Date: Fri, 25 Mar 2022 06:29:10 +0000 Subject: [PATCH] Fix param@grad type error for amp in run_program --- .../fluid/dygraph/dygraph_to_static/partial_program.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py index b8f8de67cc4a2..90f960798ef2c 100644 --- a/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py +++ b/python/paddle/fluid/dygraph/dygraph_to_static/partial_program.py @@ -204,7 +204,9 @@ def _train_amp_program(self): """ Lazy initialized property of train_amp_program. """ - return self._append_backward_desc(self._infer_amp_program) + train_amp_program = self._append_backward_desc(self._infer_amp_program) + self._set_grad_type(self._params, train_amp_program) + return train_amp_program @LazyInitialized @switch_to_static_graph @@ -224,7 +226,10 @@ def _train_pure_fp16_program(self): """ Lazy initialized property of _train_pure_fp16_program. """ - return self._append_backward_desc(self._infer_pure_fp16_program) + train_pure_fp16_program = self._append_backward_desc( + self._infer_pure_fp16_program) + self._set_grad_type(self._params, train_pure_fp16_program) + return train_pure_fp16_program @LazyInitialized def _infer_program_id(self):