diff --git a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py index f6885d7371513..324c51e29c9e6 100644 --- a/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py +++ b/paddle/fluid/pir/dialect/op_generator/vjp_interface_gen_op_list.py @@ -67,6 +67,7 @@ 'unsqueeze', 'tril', 'triu', + 'tile', ] vjp_interface_implementation_gen_op_list = [ 'where', @@ -112,4 +113,5 @@ 'unsqueeze', 'tril', 'triu', + 'tile', ] diff --git a/paddle/fluid/primitive/codegen/gen.py b/paddle/fluid/primitive/codegen/gen.py index 0eef8470521ea..8f96cee858d06 100644 --- a/paddle/fluid/primitive/codegen/gen.py +++ b/paddle/fluid/primitive/codegen/gen.py @@ -41,6 +41,7 @@ 'where_grad', 'tril_grad', 'triu_grad', + 'tile_grad', 'tanh_grad', 'mean_grad', 'add_grad', @@ -156,6 +157,7 @@ 'where_grad', 'tril_grad', 'triu_grad', + 'tile_grad', 'add_n', 'mean', 'sum', diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 7cf0bb4084514..1e9cb7a5c58f1 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -3168,7 +3168,7 @@ def tile(x, repeat_times, name=None): # Tensor(shape=[1, 6], dtype=int32, place=Place(gpu:0), stop_gradient=True, # [[1, 2, 3, 1, 2, 3]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): if isinstance(repeat_times, core.eager.Tensor): assert ( repeat_times.ndim == 1 diff --git a/test/legacy_test/test_tile_op.py b/test/legacy_test/test_tile_op.py index 6e0cea75b0c3c..40dc04b053770 100644 --- a/test/legacy_test/test_tile_op.py +++ b/test/legacy_test/test_tile_op.py @@ -47,10 +47,10 @@ def init_data(self): self.repeat_times = [2] def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) class TestTileOpRank_ZeroDim1(TestTileOpRank1): @@ -165,7 +165,7 @@ def init_data(self): self.infer_repeat_times = [-1] def test_check_output(self): - self.check_output() + self.check_output(check_new_ir=True) def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -206,7 +206,7 @@ def init_data(self): self.repeat_times = [2] def test_check_output(self): - self.check_output() + self.check_output(check_new_ir=True) def test_check_grad(self): self.check_grad(['X'], 'Out') @@ -235,7 +235,7 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_new_ir=True) class TestTileFP16OP(OpTest): @@ -262,10 +262,10 @@ def init_data(self): self.repeat_times = [2, 1, 4] def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_new_ir=True) def test_check_grad(self): - self.check_grad(['X'], 'Out', check_prim=True) + self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True) @unittest.skipIf( @@ -293,7 +293,9 @@ def if_enable_cinn(self): def test_check_output(self): place = core.CUDAPlace(0) - self.check_output_with_place(place, check_cinn=self.check_cinn) + self.check_output_with_place( + place, check_cinn=self.check_cinn, check_new_ir=True + ) def init_data(self): self.dtype = np.uint16 @@ -302,7 +304,9 @@ def init_data(self): def test_check_grad(self): place = core.CUDAPlace(0) - self.check_grad_with_place(place, ['X'], 'Out', check_prim=True) + self.check_grad_with_place( + place, ['X'], 'Out', check_prim=True, check_new_ir=True + ) # Situation 5: input x is Bool @@ -320,7 +324,7 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_new_ir=True) # Situation 56: input x is Integer @@ -340,7 +344,7 @@ def if_enable_cinn(self): self.check_cinn = True def test_check_output(self): - self.check_output(check_cinn=self.check_cinn) + self.check_output(check_cinn=self.check_cinn, check_new_ir=True) class TestTileError(unittest.TestCase):