Skip to content

Commit

Permalink
add OpResult.clone() (PaddlePaddle#59115)
Browse files Browse the repository at this point in the history
  • Loading branch information
gouzil authored and SecretXV committed Nov 28, 2023
1 parent c68c792 commit f32be6f
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 0 deletions.
26 changes: 26 additions & 0 deletions python/paddle/pir/math_op_patch.py
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,31 @@ def _size_(self):
"""
return paddle.numel(self)

def clone(self):
"""
Returns a new static OpResult, which is the clone of the original static
OpResult. It remains in the current graph, that is, the cloned OpResult
provides gradient propagation. Calling ``out = tensor.clone()`` is same
as ``out = assign(tensor)`` .
Returns:
OpResult, The cloned OpResult.
Examples:
.. code-block:: python
>>> import paddle
>>> paddle.enable_static()
>>> # create a static OpResult
>>> x = paddle.static.data(name='x', shape=[3, 2, 1])
>>> # create a cloned OpResult
>>> y = x.clone()
"""
return paddle.assign(self)

import paddle

opresult_methods = [
Expand All @@ -341,6 +366,7 @@ def _size_(self):
('ndim', _ndim),
('astype', astype),
('size', _size_),
('clone', clone),
(
'__add__',
_binary_creator_('__add__', paddle.tensor.add, False, _scalar_add_),
Expand Down
17 changes: 17 additions & 0 deletions test/legacy_test/test_math_op_patch_pir.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,23 @@ def test_size(self):
(output_x,) = exe.run(main_program, fetch_list=[x.size])
self.assertEqual(output_x, 24)

def test_clone(self):
x_np = np.random.random(size=[100, 10]).astype('float64')
with paddle.pir_utils.IrGuard():
main_program, exe, program_guard = new_program()
with program_guard:
x = paddle.static.data(
name='x', shape=[100, 10], dtype="float64"
)
a = x.clone()
(a_np,) = exe.run(
main_program,
feed={"x": x_np},
fetch_list=[a],
)
np.testing.assert_array_equal(x_np, a_np)
self.assertNotEqual(id(x), id(a))

def test_math_exists(self):
with paddle.pir_utils.IrGuard():
a = paddle.static.data(name='a', shape=[1], dtype='float32')
Expand Down

0 comments on commit f32be6f

Please sign in to comment.