diff --git a/python/paddle/base/backward.py b/python/paddle/base/backward.py
index 6d30823d4bf4a..4c3c158924de2 100755
--- a/python/paddle/base/backward.py
+++ b/python/paddle/base/backward.py
@@ -2060,11 +2060,11 @@ def append_backward(
             block, [loss], [], block_no_grad_set, op_path_dict
         )
 
-        no_grad_vars = _find_no_grad_vars(
+        no_grad_set = _find_no_grad_vars(
             block, op_path, [loss], block_no_grad_set
         )
 
-        block_no_grad_set.update(no_grad_vars)
+        block_no_grad_set.update(no_grad_set)
         no_grad_dict[block_idx].update(
             list(map(_append_grad_suffix_, block_no_grad_set))
         )
@@ -2510,10 +2510,10 @@ def calc_gradient_helper(
     block.program._sync_with_cpp()
 
     # find no grad var by op_path
-    no_grad_vars = _find_no_grad_vars(
+    no_grad_set = _find_no_grad_vars(
         block, op_path, tmp_targets, block_no_grad_set
     )
-    block_no_grad_set.update(no_grad_vars)
+    block_no_grad_set.update(no_grad_set)
     no_grad_dict[0].update(list(map(_append_grad_suffix_, block_no_grad_set)))
 
     grad_to_var = dict()
@@ -2636,6 +2636,56 @@ def gradients(targets, inputs, target_gradients=None, no_grad_set=None):
             >>> print(z)
             [var x@GRAD : LOD_TENSOR.shape(-1, 2, 8, 8).dtype(float32).stop_gradient(False)]
     """
+    if framework.in_pir_mode():
+        check_type(
+            targets,
+            'targets',
+            ((paddle.pir.Value, paddle.pir.OpResult), list, tuple),
+            'paddle.autograd.ir_backward.grad',
+        )
+        check_type(
+            inputs,
+            'inputs',
+            ((paddle.pir.Value, paddle.pir.OpResult), list, tuple),
+            'paddle.autograd.ir_backward.grad',
+        )
+        check_type(
+            target_gradients,
+            'target_gradients',
+            ((paddle.pir.Value, paddle.pir.OpResult), list, tuple, type(None)),
+            'paddle.autograd.ir_backward.grad',
+        )
+
+        check_type(
+            no_grad_set,
+            'no_grad_set',
+            (
+                (paddle.pir.Value, paddle.pir.OpResult),
+                list,
+                tuple,
+                set,
+                type(None),
+            ),
+            'paddle.autograd.ir_backward.grad',
+        )
+        targets = _as_list(targets)
+        inputs = _as_list(inputs)
+        target_gradients = _as_list(target_gradients)
+        if no_grad_set is None:
+            no_grad_set = set()
+        elif no_grad_set is not set:
+            no_grad_set = set(no_grad_set)
+        else:
+            no_grad_set = no_grad_set
+        from paddle.autograd.ir_backward import (
+            calc_gradient as pir_calc_gradient,
+        )
+
+        input_grad = pir_calc_gradient(
+            targets, inputs, target_gradients, no_grad_set
+        )
+        return input_grad
+
     check_type(
         targets,
         'targets',
diff --git a/python/paddle/pir_utils.py b/python/paddle/pir_utils.py
index 28d261b0155fc..a2b5244cad7c5 100644
--- a/python/paddle/pir_utils.py
+++ b/python/paddle/pir_utils.py
@@ -13,6 +13,8 @@
 # limitations under the License.
 
 
+from functools import wraps
+
 import paddle
 
 
@@ -64,9 +66,16 @@ def _switch_to_pir(self):
                 {"FLAGS_enable_new_ir_in_executor": True}
             )
             paddle.pir.register_paddle_dialect()
-            paddle.static.Program = paddle.pir.Program
+            paddle.base.Program = paddle.pir.Program
             paddle.base.program_guard = paddle.pir.core.program_guard
+            # paddle.base.default_main_program = (
+            #     paddle.pir.core.default_main_program
+            # )
+            # paddle.base.default_startup_program = (
+            #     paddle.pir.core.default_startup_program
+            # )
 
+            paddle.static.Program = paddle.pir.Program
             paddle.static.program_guard = paddle.pir.core.program_guard
             paddle.static.default_main_program = (
                 paddle.pir.core.default_main_program
@@ -82,9 +91,14 @@ def _switch_to_old_ir(self):
             paddle.framework.set_flags(
                 {"FLAGS_enable_new_ir_in_executor": False}
             )
-            paddle.static.Program = self.old_Program
+            paddle.base.Program = self.old_Program
             paddle.base.program_guard = self.old_program_guard
+            # paddle.base.default_main_program = self.old_default_main_program
+            # paddle.base.default_startup_program = (
+            #     self.old_default_startup_program
+            # )
 
+            paddle.static.Program = self.old_Program
             paddle.static.program_guard = self.old_program_guard
             paddle.static.default_main_program = self.old_default_main_program
             paddle.static.default_startup_program = (
@@ -95,3 +109,13 @@ def _switch_to_old_ir(self):
                 "IrGuard._switch_to_old_ir only work when paddle.framework.in_pir_mode() is false, \
                 please set FLAGS_enable_pir_api = false"
             )
+
+
+def test_with_pir_api(func):
+    @wraps(func)
+    def impl(*args, **kwargs):
+        func(*args, **kwargs)
+        with IrGuard():
+            func(*args, **kwargs)
+
+    return impl
diff --git a/python/paddle/tensor/math.py b/python/paddle/tensor/math.py
index 5a60e6884b890..467c7f7ab88f1 100644
--- a/python/paddle/tensor/math.py
+++ b/python/paddle/tensor/math.py
@@ -1226,7 +1226,7 @@ def maximum(x, y, name=None):
             Tensor(shape=[3], dtype=float32, place=Place(cpu), stop_gradient=True,
             [5. , 3. , inf.])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         return _C_ops.maximum(x, y)
     else:
         return _elementwise_op(LayerHelper('elementwise_max', **locals()))
diff --git a/test/legacy_test/test_calc_gradient.py b/test/legacy_test/test_calc_gradient.py
index 945acf18bb932..41f3772260c77 100644
--- a/test/legacy_test/test_calc_gradient.py
+++ b/test/legacy_test/test_calc_gradient.py
@@ -85,7 +85,11 @@ def test2(self):
         self.assertEqual(12, out[0])
 
 
+from paddle.pir_utils import test_with_pir_api
+
+
 class TestGradientWithPrune(unittest.TestCase):
+    @test_with_pir_api
     def test_prune(self):
         with paddle.base.scope_guard(paddle.static.Scope()):
             x = paddle.static.data(name='x', shape=[3], dtype='float32')
@@ -95,8 +99,8 @@ def test_prune(self):
             x1_grad = base.gradients(y, x)
 
             exe = base.Executor(base.CPUPlace())
-            main = base.default_main_program()
-            exe.run(base.default_startup_program())
+            main = paddle.static.default_main_program()
+            exe.run(paddle.static.default_startup_program())
             out = exe.run(
                 main,
                 feed={'x': np.ones([3]).astype('float32')},
diff --git a/test/legacy_test/test_maximum_op.py b/test/legacy_test/test_maximum_op.py
index 818bdb65fee68..a0e660112bd03 100644
--- a/test/legacy_test/test_maximum_op.py
+++ b/test/legacy_test/test_maximum_op.py
@@ -18,6 +18,7 @@
 
 import paddle
 from paddle.base import core
+from paddle.pir_utils import test_with_pir_api
 
 
 class ApiMaximumTest(unittest.TestCase):
@@ -39,6 +40,7 @@ def setUp(self):
         self.np_expected3 = np.maximum(self.input_a, self.input_c)
         self.np_expected4 = np.maximum(self.input_b, self.input_c)
 
+    @test_with_pir_api
     def test_static_api(self):
         paddle.enable_static()
         with paddle.static.program_guard(
@@ -119,3 +121,7 @@ def test_dynamic_api(self):
         res = paddle.maximum(b, c)
         res = res.numpy()
         np.testing.assert_allclose(res, self.np_expected4, rtol=1e-05)
+
+
+if __name__ == '__main__':
+    unittest.main()
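
The test_with_pir_api decorator added in python/paddle/pir_utils.py runs a decorated test body twice: once under the legacy static-graph IR, and once inside an IrGuard scope with the PIR API enabled, which is how the updated tests above exercise both code paths. Below is a minimal usage sketch, assuming a Paddle build that contains this patch; the test class and method names are hypothetical and only mirror the pattern used in test_maximum_op.py and test_calc_gradient.py.

    import unittest

    import numpy as np

    import paddle
    from paddle.pir_utils import test_with_pir_api

    paddle.enable_static()


    class MaximumStaticSmokeTest(unittest.TestCase):  # hypothetical class, for illustration only
        @test_with_pir_api  # body runs twice: legacy IR first, then PIR inside IrGuard
        def test_maximum(self):
            with paddle.static.program_guard(
                paddle.static.Program(), paddle.static.Program()
            ):
                x = paddle.static.data(name='x', shape=[3], dtype='float32')
                y = paddle.static.data(name='y', shape=[3], dtype='float32')
                out = paddle.maximum(x, y)

                exe = paddle.static.Executor(paddle.CPUPlace())
                (res,) = exe.run(
                    paddle.static.default_main_program(),
                    feed={
                        'x': np.array([1.0, 5.0, 2.0], dtype='float32'),
                        'y': np.array([3.0, 4.0, 2.0], dtype='float32'),
                    },
                    fetch_list=[out],
                )
            np.testing.assert_allclose(res, np.array([3.0, 5.0, 2.0], dtype='float32'))


    if __name__ == '__main__':
        unittest.main()

Fetching through paddle.static.default_main_program() rather than base.default_main_program() matters here, since IrGuard only swaps the paddle.static entry points to their PIR counterparts, which is also why test_prune was adjusted in the diff above.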