Skip to content

Commit

Permalink
[AMP OP&Test] Tile OP (#51380)
Browse files Browse the repository at this point in the history
* tile_op

* fix bfloat16 x

* update review

* del out
  • Loading branch information
yangjianfengo1 authored Mar 9, 2023
1 parent d0d739c commit d7660a7
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 13 deletions.
33 changes: 21 additions & 12 deletions python/paddle/fluid/tests/unittests/test_tile_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,19 +196,23 @@ def test_check_output(self):
self.check_output()


class TestTileFP16OP(OpTest):
    """Float16 test for the `tile` op.

    Test data (dtype, input shape, repeat counts) is centralized in
    `init_data` so subclasses can override a single hook to cover other
    shapes/repeat patterns without duplicating `setUp`.
    """

    def setUp(self):
        self.op_type = "tile"
        # OpTest reads op_type from the class in some code paths, so mirror it.
        self.__class__.op_type = self.op_type
        self.python_api = paddle.tile
        self.init_data()
        # Build the reference output with NumPy: np.tile replicates `x`
        # along each axis by the corresponding repeat count.
        x = np.random.uniform(10, size=self.ori_shape).astype(self.dtype)
        output = np.tile(x, self.repeat_times)
        self.inputs = {'X': x}
        self.attrs = {'repeat_times': self.repeat_times}
        self.outputs = {'Out': output}

    def init_data(self):
        # Single source of truth for the test configuration; override in
        # subclasses to test other dtypes/shapes/repeat patterns.
        self.dtype = np.float16
        self.ori_shape = [100, 4, 5]
        self.repeat_times = [2, 1, 4]

    def test_check_output(self):
        self.check_output()

Expand All @@ -221,22 +225,27 @@ def test_check_grad(self):
or not core.is_bfloat16_supported(core.CUDAPlace(0)),
"core is not complied with CUDA and not support the bfloat16",
)
class TestTileBF16OP(OpTest):
    """Bfloat16 test for the `tile` op (CUDA only; see skipIf decorator).

    bfloat16 tensors are represented as uint16 bit patterns, so inputs and
    expected outputs are computed in float32 and converted with
    `convert_float_to_uint16`. Test data lives in `init_data` so subclasses
    can override shapes/repeats without duplicating `setUp`.
    """

    def setUp(self):
        self.op_type = 'tile'
        # OpTest reads op_type from the class in some code paths, so mirror it.
        self.__class__.op_type = self.op_type
        self.python_api = paddle.tile
        self.init_data()
        # Reference data is produced in float32, then bit-cast to the
        # uint16 container format that Paddle uses for bfloat16.
        x = np.random.uniform(10, size=self.ori_shape).astype(np.float32)
        output = np.tile(x, self.repeat_times)
        self.inputs = {'X': convert_float_to_uint16(x)}
        self.attrs = {'repeat_times': self.repeat_times}
        self.outputs = {'Out': convert_float_to_uint16(output)}

    def init_data(self):
        # Single source of truth for the test configuration; override in
        # subclasses to test other shapes/repeat patterns.
        self.dtype = np.uint16
        self.ori_shape = [100, 4, 5]
        self.repeat_times = [2, 1, 4]

    def test_check_output(self):
        place = core.CUDAPlace(0)
        self.check_output_with_place(place)

    def test_check_grad(self):
        place = core.CUDAPlace(0)
        self.check_grad_with_place(place, ['X'], 'Out')
Expand Down
2 changes: 1 addition & 1 deletion python/paddle/tensor/manipulation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3159,7 +3159,7 @@ def tile(x, repeat_times, name=None):
[
'bool',
'float16',
'bfloat16',
'uint16',
'float32',
'float64',
'int32',
Expand Down

0 comments on commit d7660a7

Please sign in to comment.