Skip to content

Commit

Permalink
Add fill_constant_batch_size_like tests (#53736)
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc authored May 16, 2023
1 parent 51ecd93 commit 98100fd
Showing 1 changed file with 11 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,11 +38,12 @@ def fill_constant_batch_size_like(
)


class TestFillConstatnBatchSizeLike1(OpTest):
class TestFillConstantBatchSizeLike1(OpTest):
# test basic
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
self.python_api = fill_constant_batch_size_like
self.init_dtype()
self.init_data()

input = np.zeros(self.shape)
Expand All @@ -59,9 +60,11 @@ def setUp(self):
'force_cpu': self.force_cpu,
}

def init_dtype(self):
self.dtype = np.float32

def init_data(self):
self.shape = [10, 10]
self.dtype = np.float32
self.value = 100
self.input_dim_idx = 0
self.output_dim_idx = 0
Expand All @@ -71,11 +74,16 @@ def test_check_output(self):
self.check_output()


class TestFillConstantBatchSizeLikeFP16Op(TestFillConstantBatchSizeLike1):
def init_dtype(self):
self.dtype = np.float16


@unittest.skipIf(
not core.is_compiled_with_cuda() or not core.supports_bfloat16(),
"core is not compiled with CUDA or place do not support bfloat16",
)
class TestFillConstatnBatchSizeLikeBf16(OpTest):
class TestFillConstantBatchSizeLikeBF16Op(OpTest):
# test bf16
def setUp(self):
self.op_type = "fill_constant_batch_size_like"
Expand Down

0 comments on commit 98100fd

Please sign in to comment.