Skip to content

Commit

Permalink
update xpu zero dim tensor ut (#50289)
Browse files Browse the repository at this point in the history
* xpu scatter ut no backward

* update gather xpu ut
  • Loading branch information
FeixLiu authored Feb 7, 2023
1 parent 84fe2de commit 7e4b432
Showing 1 changed file with 3 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -583,42 +583,35 @@ def test_gather_xD_axis_0(self):
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [3])

def _test_gather_xD_axis_1(self):
def test_gather_xD_axis_1(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index = paddle.full([], 1, 'int64')
out = paddle.gather(x, index, axis=1)
out.backward()

self.assertEqual(out.shape, [2])
np.testing.assert_array_equal(out.numpy(), [2.0, 5.0])
self.assertEqual(x.grad.shape, [2, 3])
self.assertEqual(out.grad.shape, [2])

def _test_scatter_1D(self):
def test_scatter_1D(self):
x = paddle.to_tensor([1.0, 3.0, 5.0, 7.0, 9.0], stop_gradient=False)
index = paddle.full([], 2, 'int64')
updates = paddle.full([], 4.0)
out = paddle.scatter(x, index, updates)
out.backward()

self.assertEqual(out.shape, [5])
self.assertEqual(out.numpy()[2], 4)
self.assertEqual(out.grad.shape, [5])

def _test_scatter_XD(self):
def test_scatter_XD(self):
x = paddle.to_tensor(
[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]], stop_gradient=False
)
index = paddle.full([], 1, 'int64')
updates = paddle.to_tensor([1.0, 2.0, 3.0])
out = paddle.scatter(x, index, updates)
out.backward()

self.assertEqual(out.shape, [2, 3])
np.testing.assert_array_equal(out.numpy()[1], [1.0, 2.0, 3.0])
self.assertEqual(out.grad.shape, [2, 3])

def test_diagflat(self):
x1 = paddle.rand([])
Expand Down

0 comments on commit 7e4b432

Please sign in to comment.