Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

function setitem support with stride #57023

Merged
merged 6 commits into from
Sep 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 48 additions & 40 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from . import core
import paddle
import warnings
import itertools


MAX_INTEGER = 2**31 - 1
Expand Down Expand Up @@ -869,16 +870,24 @@ def _setitem_static(x, indices, values):
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

value_tensor = None
StartsTensorList = None
EndsTensorList = None
StepsTensorList = None

if paddle.utils._contain_var(starts):
inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
starts
)
StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
inputs['StartsTensorList'] = StartsTensorList
del attrs['starts']

if paddle.utils._contain_var(ends):
inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
inputs['EndsTensorList'] = EndsTensorList
del attrs['ends']
if paddle.utils._contain_var(steps):
inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
inputs['StepsTensorList'] = StepsTensorList
del attrs['steps']

if not has_advanced_index:
Expand All @@ -899,6 +908,8 @@ def _setitem_static(x, indices, values):

elif isinstance(values, Variable):
inputs["ValueTensor"] = values
value_tensor = values

else:
raise TypeError(
"Only support to assign an integer, float, numpy.ndarray or "
Expand All @@ -909,8 +920,14 @@ def _setitem_static(x, indices, values):

# step3.1: Only basic indexing, use OP set_value to set value.
if paddle.in_dynamic_mode():
x._bump_inplace_version()
output = x
return paddle._legacy_C_ops.set_value_(
x,
value_tensor,
StartsTensorList,
EndsTensorList,
StepsTensorList,
*itertools.chain.from_iterable(attrs.items())
)
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
Expand All @@ -924,21 +941,20 @@ def _setitem_static(x, indices, values):
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
return output
else:
# step3.2: Case for there are advanced indexing.
# 1. get __getitem__ result of basic indexing;
Expand All @@ -965,27 +981,19 @@ def _setitem_static(x, indices, values):
) = deal_advanced_index(sub_tensor, advanced_index, True)
if not isinstance(values, Variable):
values = paddle.assign(values).astype(transed_sub_tensor.dtype)
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
)

# NOTE(zoooo0820): now basic indexing of __getitem__ will return a new Tensor both in dynamic and static mode
# After strided is ready and basic indexing returns view of Tensor in dynamic mode. The code shoule be changed
# for dynamic mode.
if paddle.in_dynamic_mode():
transed_sub_tensor.index_put_(adjusted_advanced_index, values)
return transed_sub_tensor.index_put_(
adjusted_advanced_index, values
)
else:
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
)

transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
inputs["ValueTensor"] = transback_sub_tensor

inputs["ValueTensor"] = transback_sub_tensor
if paddle.in_dynamic_mode():
x._bump_inplace_version()
output = x
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
)
Expand All @@ -998,20 +1006,20 @@ def _setitem_static(x, indices, values):
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
if not paddle.in_dynamic_mode():
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
return output


def get_tensor_with_basic_indexing(
Expand Down
61 changes: 61 additions & 0 deletions test/indexing/test_getitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,67 @@ def test_index_has_range(self):
np.testing.assert_allclose(res[0], np_res)


class TestGetitemBasicIndexOutputView(unittest.TestCase):
def setUp(self):
# Stride now only supports in dygraph mode
paddle.disable_static()

def test_index_is_int(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[3, 2]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[3, 2]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_0dTensor(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[3, 2]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[paddle.to_tensor(3), paddle.to_tensor(2)]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_slice(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[::2, :, 0:4]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[::2, :, 0:4]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_None(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[None]
np_tmp[:, 2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[None]
x_tmp[:, 2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shall we add test case of ellipsis index

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done, add a case to test ellipsis index.

def test_index_is_ellipsis(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[...]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[...]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)


class TestGetItemErrorCase(unittest.TestCase):
def setUp(self):
paddle.disable_static()
Expand Down
4 changes: 2 additions & 2 deletions test/legacy_test/test_set_value_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -1878,13 +1878,13 @@ def test_inplace(self):
paddle.seed(100)
a = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b = a[:]
b = a[:] * 1
c = b
b[paddle.zeros([], dtype='int32')] = 1.0

self.assertTrue(id(b) == id(c))
np.testing.assert_array_equal(b.numpy(), c.numpy())
self.assertEqual(b.inplace_version, 0)
self.assertEqual(b.inplace_version, 1)

paddle.enable_static()

Expand Down