[PIR] No.46 Migrate paddle.tile into pir (PaddlePaddle#57700)
zrr1999 authored and jiahy0825 committed Oct 16, 2023
1 parent d01b52a commit 5a3987b
Showing 4 changed files with 20 additions and 12 deletions.
@@ -67,6 +67,7 @@
     'unsqueeze',
     'tril',
     'triu',
+    'tile',
 ]
 vjp_interface_implementation_gen_op_list = [
     'where',
@@ -112,4 +113,5 @@
     'unsqueeze',
     'tril',
     'triu',
+    'tile',
 ]
2 changes: 2 additions & 0 deletions paddle/fluid/primitive/codegen/gen.py
@@ -41,6 +41,7 @@
     'where_grad',
     'tril_grad',
     'triu_grad',
+    'tile_grad',
     'tanh_grad',
     'mean_grad',
     'add_grad',
@@ -160,6 +161,7 @@
     'where_grad',
     'tril_grad',
     'triu_grad',
+    'tile_grad',
     'add_n',
     'mean',
     'sum',
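The code-generation lists above register `tile` and its backward op `tile_grad` with the auto-generated VJP (backward) support used by the new IR. As a reminder of the gradient rule that this backward has to realize, here is a minimal dygraph sketch (assuming an installed `paddle` build; this is not the generated code):

```python
import paddle

# The upstream gradient is summed over every repeated copy,
# so x.grad keeps x's original shape.
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)
y = paddle.tile(x, repeat_times=[1, 2])  # shape [1, 6]
y.sum().backward()
print(x.grad)  # [2., 2., 2.] -- each element of x appears twice in y
```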
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
@@ -3168,7 +3168,7 @@ def tile(x, repeat_times, name=None):
             # Tensor(shape=[1, 6], dtype=int32, place=Place(gpu:0), stop_gradient=True,
             # [[1, 2, 3, 1, 2, 3]])
     """
-    if in_dynamic_mode():
+    if in_dynamic_or_pir_mode():
         if isinstance(repeat_times, core.eager.Tensor):
             assert (
                 repeat_times.ndim == 1
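The only change to the Python API is the mode check: the fast-path branch that previously ran only in dynamic mode is now also taken while a PIR program is being built. For reference, the docstring example visible in the hunk's context corresponds to the following usage (a sketch assuming an ordinary dynamic-mode session; output metadata such as `place` depends on the device):

```python
import paddle

x = paddle.to_tensor([1, 2, 3], dtype="int32")
out = paddle.tile(x, repeat_times=[1, 2])
print(out)
# Tensor(shape=[1, 6], dtype=int32, ...,
#        [[1, 2, 3, 1, 2, 3]])
```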
26 changes: 15 additions & 11 deletions test/legacy_test/test_tile_op.py
@@ -47,10 +47,10 @@ def init_data(self):
         self.repeat_times = [2]

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


 class TestTileOpRank_ZeroDim1(TestTileOpRank1):
@@ -165,7 +165,7 @@ def init_data(self):
         self.infer_repeat_times = [-1]

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_new_ir=True)

     def test_check_grad(self):
         self.check_grad(['X'], 'Out')
@@ -206,7 +206,7 @@ def init_data(self):
         self.repeat_times = [2]

     def test_check_output(self):
-        self.check_output()
+        self.check_output(check_new_ir=True)

     def test_check_grad(self):
         self.check_grad(['X'], 'Out')
@@ -235,7 +235,7 @@ def if_enable_cinn(self):
         self.check_cinn = True

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)


 class TestTileFP16OP(OpTest):
@@ -262,10 +262,10 @@ def init_data(self):
         self.repeat_times = [2, 1, 4]

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)

     def test_check_grad(self):
-        self.check_grad(['X'], 'Out', check_prim=True)
+        self.check_grad(['X'], 'Out', check_prim=True, check_new_ir=True)


 @unittest.skipIf(
@@ -293,7 +293,9 @@ def if_enable_cinn(self):

     def test_check_output(self):
         place = core.CUDAPlace(0)
-        self.check_output_with_place(place, check_cinn=self.check_cinn)
+        self.check_output_with_place(
+            place, check_cinn=self.check_cinn, check_new_ir=True
+        )

     def init_data(self):
         self.dtype = np.uint16
@@ -302,7 +304,9 @@ def init_data(self):

     def test_check_grad(self):
         place = core.CUDAPlace(0)
-        self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', check_prim=True, check_new_ir=True
+        )


 # Situation 5: input x is Bool
@@ -320,7 +324,7 @@ def if_enable_cinn(self):
         self.check_cinn = True

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)


 # Situation 56: input x is Integer
@@ -340,7 +344,7 @@ def if_enable_cinn(self):
         self.check_cinn = True

     def test_check_output(self):
-        self.check_output(check_cinn=self.check_cinn)
+        self.check_output(check_cinn=self.check_cinn, check_new_ir=True)


 class TestTileError(unittest.TestCase):
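Every test change follows the same pattern: passing `check_new_ir=True` asks the OpTest harness to additionally execute the op under the new IR (PIR) and compare outputs and gradients with the legacy path. A condensed sketch of one of the cases above (it assumes Paddle's `test/legacy_test` OpTest tooling and is not a standalone script):

```python
import unittest

import numpy as np
import paddle
from op_test import OpTest  # helper provided in Paddle's test/legacy_test directory


class TestTileSketch(OpTest):
    def setUp(self):
        self.op_type = "tile"
        self.python_api = paddle.tile
        self.inputs = {'X': np.random.random([100]).astype("float64")}
        self.attrs = {'repeat_times': [2]}
        self.outputs = {'Out': np.tile(self.inputs['X'], [2])}

    def test_check_output(self):
        # Also run under the new IR (PIR) and compare with the legacy executor.
        self.check_output(check_new_ir=True)

    def test_check_grad(self):
        self.check_grad(['X'], 'Out', check_new_ir=True)


if __name__ == "__main__":
    unittest.main()
```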
