Skip to content

Commit

Permalink
add 0D support for trace (PaddlePaddle#53208)
Browse files Browse the repository at this point in the history
* add 0D support for trace, test=allcase

* fix trace gpu kernel 0d error, test=allcase

* fix windows error, test=allcase
  • Loading branch information
GGBond8488 authored and lijialin03 committed Apr 25, 2023
1 parent f69866f commit 8298146
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 3 deletions.
1 change: 0 additions & 1 deletion paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4405,7 +4405,6 @@ void TraceInferMeta(
auto sizes = vectorize(x_dims);
if (x_dims.size() == 2) {
sizes.clear();
sizes.push_back(1);
} else {
sizes.erase(sizes.begin() + std::max(dim1_, dim2_));
sizes.erase(sizes.begin() + std::min(dim1_, dim2_));
Expand Down
5 changes: 4 additions & 1 deletion paddle/phi/kernels/gpu/trace_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,10 @@ void TraceKernel(const Context& ctx,
auto diag = funcs::Diagonal<T, Context>(ctx, &x, offset, axis1, axis2);
if (diag.numel() > 0) {
std::vector<int> reduce_dims;
reduce_dims.push_back(out->dims().size());
// Adapt to 0D output
auto out_dim_size = out->dims().size();
if (out_dim_size == 0) out_dim_size = 1;
reduce_dims.push_back(out_dim_size);
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
ctx, diag, out, kps::IdentityFunctor<T>(), reduce_dims);
} else {
Expand Down
3 changes: 2 additions & 1 deletion paddle/phi/kernels/impl/trace_grad_kernel_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ void TraceGradKernel(const Context& ctx,
auto input_dims = in_grad->dims();
auto input_stride = phi::stride(input_dims);
auto output_dims = out_grad.dims();
auto output_stride = phi::stride(output_dims);
auto output_stride = output_dims.size() == 0 ? phi::DDim(output_dims)
: phi::stride(output_dims);

auto* out_data = out_grad.data<T>();
T* x_data = ctx.template Alloc<T>(in_grad);
Expand Down
24 changes: 24 additions & 0 deletions python/paddle/fluid/tests/unittests/test_zero_dim_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2303,6 +2303,16 @@ def test_multi_dot(self):
self.assertEqual(b.grad.shape, [4, 5])
self.assertEqual(c.grad.shape, [5])

def test_trace(self):
x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
x.stop_gradient = False
out = paddle.trace(x)
out.backward()

self.assertEqual(out.shape, [])
np.testing.assert_allclose(out, np.array(12))
self.assertEqual(x.grad.shape, [2, 2])


class TestSundryAPIStatic(unittest.TestCase):
def setUp(self):
Expand Down Expand Up @@ -4180,6 +4190,20 @@ def test_multi_dot(self):
self.assertEqual(res[2].shape, (4, 5))
self.assertEqual(res[3].shape, (5,))

@prog_scope()
def test_trace(self):
x = paddle.to_tensor([[3, 2], [1, 9]], dtype="float32")
x.stop_gradient = False
out = paddle.trace(x)
paddle.static.append_backward(out)

prog = paddle.static.default_main_program()
res = self.exe.run(prog, fetch_list=[out, x.grad_name])

self.assertEqual(res[0].shape, ())
self.assertEqual(res[1].shape, (2, 2))
np.testing.assert_allclose(res[0], np.array(12))


# Use to test API whose zero-dim input tensors don't have grad and not need to test backward in OpTest.
class TestNoBackwardAPI(unittest.TestCase):
Expand Down

0 comments on commit 8298146

Please sign in to comment.