Skip to content

Commit

Permalink
Zero-dim support of histogram kernel, test=develop (#49884)
Browse files Browse the repository at this point in the history
  • Loading branch information
qili93 authored Jan 18, 2023
1 parent 5fca45e commit 6cd7fca
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 0 deletions.
8 changes: 8 additions & 0 deletions python/paddle/fluid/tests/unittests/test_histogram_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def test_check_output(self):
self.check_output(check_eager=True)


class TestHistogramOp_ZeroDim(TestHistogramOp):
def init_test_case(self):
self.in_shape = []
self.bins = 5
self.min = 1
self.max = 5


if __name__ == "__main__":
paddle.enable_static()
unittest.main()
15 changes: 15 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,11 @@ def test_flatten(self):
self.assertEqual(out.grad.shape, [1])
self.assertEqual(x.grad.shape, [])

def test_histogram(self):
x = paddle.rand([])
out = paddle.histogram(x, bins=5, min=1, max=5)
self.assertEqual(out.shape, [5])

def test_scale(self):
x = paddle.rand([])
x.stop_gradient = False
Expand Down Expand Up @@ -1658,6 +1663,16 @@ def test_flatten(self):
self.assertEqual(res[1].shape, ())
self.assertEqual(res[2].shape, (1,))

@prog_scope()
def test_histogram(self):
x = paddle.full([], 1, 'float32')
out = paddle.histogram(x, bins=5, min=1, max=5)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, feed={}, fetch_list=[out])

self.assertEqual(res[0].shape, (5,))

@prog_scope()
def test_scale(self):
x = paddle.rand([])
Expand Down

0 comments on commit 6cd7fca

Please sign in to comment.