function setitem support with stride (#57023)
* function setitem support with stride

* fix unit tests

* remove redundant dygraph checks

* add unittest for basic slice output view

* add Ellipsis case
zoooo0820 authored Sep 12, 2023
1 parent 48494fc commit b9673f6
Showing 3 changed files with 111 additions and 42 deletions.
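In dygraph mode, basic-index assignment on a Tensor now dispatches to the in-place set_value_ kernel (including strided slices), and the new tests check that the result of basic-index __getitem__ writes through to the source tensor the way numpy views do. A minimal sketch of the user-visible behavior, assuming a build where this strided/view behavior is enabled; shapes and values are illustrative only:

    import numpy as np
    import paddle

    paddle.disable_static()  # the stride/view behavior is only expected in dygraph

    x = paddle.ones((4, 4), dtype='float32')
    x[0:4:2, 1:3] = 5.0      # strided basic-index assignment, applied in place

    expected = np.ones((4, 4), dtype='float32')
    expected[0:4:2, 1:3] = 5.0
    np.testing.assert_allclose(x.numpy(), expected)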
88 changes: 48 additions & 40 deletions python/paddle/base/variable_index.py
@@ -18,6 +18,7 @@
from . import core
import paddle
import warnings
import itertools


MAX_INTEGER = 2**31 - 1
@@ -854,16 +855,24 @@ def _setitem_static(x, indices, values):
'decrease_axes': decrease_axes,
'none_axes': none_axes,
}

value_tensor = None
StartsTensorList = None
EndsTensorList = None
StepsTensorList = None

if paddle.utils._contain_var(starts):
inputs['StartsTensorList'] = paddle.utils._convert_to_tensor_list(
starts
)
StartsTensorList = paddle.utils._convert_to_tensor_list(starts)
inputs['StartsTensorList'] = StartsTensorList
del attrs['starts']

if paddle.utils._contain_var(ends):
inputs['EndsTensorList'] = paddle.utils._convert_to_tensor_list(ends)
EndsTensorList = paddle.utils._convert_to_tensor_list(ends)
inputs['EndsTensorList'] = EndsTensorList
del attrs['ends']
if paddle.utils._contain_var(steps):
inputs['StepsTensorList'] = paddle.utils._convert_to_tensor_list(steps)
StepsTensorList = paddle.utils._convert_to_tensor_list(steps)
inputs['StepsTensorList'] = StepsTensorList
del attrs['steps']

if not has_advanced_index:
@@ -884,6 +893,8 @@ def _setitem_static(x, indices, values):

elif isinstance(values, Variable):
inputs["ValueTensor"] = values
value_tensor = values

else:
raise TypeError(
"Only support to assign an integer, float, numpy.ndarray or "
@@ -894,8 +905,14 @@ def _setitem_static(x, indices, values):

# step3.1: Only basic indexing, use OP set_value to set value.
if paddle.in_dynamic_mode():
x._bump_inplace_version()
output = x
return paddle._legacy_C_ops.set_value_(
x,
value_tensor,
StartsTensorList,
EndsTensorList,
StepsTensorList,
*itertools.chain.from_iterable(attrs.items())
)
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
@@ -909,21 +926,20 @@
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

if not paddle.in_dynamic_mode():
# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
return output
else:
# step3.2: Case for there are advanced indexing.
# 1. get __getitem__ result of basic indexing;
@@ -950,27 +966,19 @@
) = deal_advanced_index(sub_tensor, advanced_index, True)
if not isinstance(values, Variable):
values = paddle.assign(values).astype(transed_sub_tensor.dtype)
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
)

# NOTE(zoooo0820): now basic indexing of __getitem__ will return a new Tensor both in dynamic and static mode
# After strided support is ready and basic indexing returns a view of the Tensor in dynamic mode, the code should be changed
# for dynamic mode.
if paddle.in_dynamic_mode():
transed_sub_tensor.index_put_(adjusted_advanced_index, values)
return transed_sub_tensor.index_put_(
adjusted_advanced_index, values
)
else:
transed_sub_tensor = transed_sub_tensor.index_put(
adjusted_advanced_index, values
)

transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
transback_sub_tensor = transed_sub_tensor.transpose(transback_dim)
inputs["ValueTensor"] = transback_sub_tensor

inputs["ValueTensor"] = transback_sub_tensor
if paddle.in_dynamic_mode():
x._bump_inplace_version()
output = x
else:
helper = paddle.base.layer_helper.LayerHelper(
'set_value', **locals()
)
@@ -983,20 +991,20 @@
output = helper.create_variable_for_type_inference(
dtype=x.dtype
)
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)
if not paddle.in_dynamic_mode():
cur_block = default_main_program().current_block()
cur_block.append_op(
type="set_value",
inputs=inputs,
outputs={'Out': output},
attrs=attrs,
inplace_map={"Input": "Out"},
)

# map var to the new output
paddle.jit.api.ProgramTranslator.get_instance()._inplace_map.add(
cur_block.program, x.desc.id(), output
)
return output
return output


def get_tensor_with_basic_indexing(
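In the dygraph branch above, the SetValue attributes are no longer passed through an appended program op; instead the attrs dict is unpacked into the positional name/value argument list that paddle._legacy_C_ops.set_value_ expects, via itertools.chain.from_iterable(attrs.items()). The advanced-indexing path likewise switches to the in-place index_put_ in dygraph while keeping the out-of-place index_put plus an appended set_value op for static graphs. A small standalone sketch of the attribute flattening, with a hypothetical attrs dict standing in for the one built in _setitem_static:

    import itertools

    # Hypothetical attribute map, shaped like the `attrs` dict assembled above.
    attrs = {
        'axes': [0],
        'starts': [0],
        'ends': [4],
        'steps': [2],
        'decrease_axes': [],
        'none_axes': [],
    }

    # chain.from_iterable flattens the (name, value) pairs into a single
    # alternating sequence: 'axes', [0], 'starts', [0], 'ends', [4], ...
    flat_args = list(itertools.chain.from_iterable(attrs.items()))
    print(flat_args)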
61 changes: 61 additions & 0 deletions test/indexing/test_getitem.py
@@ -433,6 +433,67 @@ def test_indexing_is_multi_dim_list(self):
np.testing.assert_allclose(res[1], np_res)


class TestGetitemBasicIndexOutputView(unittest.TestCase):
def setUp(self):
# Stride is currently only supported in dygraph mode
paddle.disable_static()

def test_index_is_int(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[3, 2]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[3, 2]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_0dTensor(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[3, 2]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[paddle.to_tensor(3), paddle.to_tensor(2)]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_slice(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[::2, :, 0:4]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[::2, :, 0:4]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_None(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[None]
np_tmp[:, 2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[None]
x_tmp[:, 2] = 20

np.testing.assert_allclose(x.numpy(), np_data)

def test_index_is_ellipsis(self):
np_data = np.ones((5, 5, 5), dtype='float32')
np_tmp = np_data[...]
np_tmp[2] = 20

x = paddle.ones((5, 5, 5), dtype='float32')
x_tmp = x[...]
x_tmp[2] = 20

np.testing.assert_allclose(x.numpy(), np_data)


class TestGetItemErrorCase(unittest.TestCase):
def setUp(self):
paddle.disable_static()
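Each new test follows the same pattern: take the result of a basic-indexing expression (int, 0-D Tensor, slice, None, or Ellipsis index), write through it, and compare the original tensor against numpy, whose basic indexing returns views. A condensed version of that pattern, assuming the dygraph view semantics these tests target:

    import numpy as np
    import paddle

    paddle.disable_static()

    np_data = np.ones((5, 5, 5), dtype='float32')
    np_data[...][2] = 20          # numpy: Ellipsis indexing yields a view

    x = paddle.ones((5, 5, 5), dtype='float32')
    x[...][2] = 20                # expected to write through to x as well

    np.testing.assert_allclose(x.numpy(), np_data)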
4 changes: 2 additions & 2 deletions test/legacy_test/test_set_value_op.py
@@ -1878,13 +1878,13 @@ def test_inplace(self):
paddle.seed(100)
a = paddle.rand(shape=[1, 4])
a.stop_gradient = False
b = a[:]
b = a[:] * 1
c = b
b[paddle.zeros([], dtype='int32')] = 1.0

self.assertTrue(id(b) == id(c))
np.testing.assert_array_equal(b.numpy(), c.numpy())
self.assertEqual(b.inplace_version, 0)
self.assertEqual(b.inplace_version, 1)

paddle.enable_static()

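The inplace test is adjusted to match the new behavior: b is now built with a[:] * 1 so the write targets a tensor that no longer shares storage with a (presumably because a[:] is meant to become a view once stride support lands), and the expected inplace_version after one assignment is 1 rather than 0, since the assignment now runs through the in-place set_value_ kernel. A trimmed sketch of the updated expectation (seed and stop_gradient handling omitted):

    import paddle

    paddle.disable_static()

    a = paddle.rand(shape=[1, 4])
    b = a[:] * 1        # materialize a tensor independent of a's storage
    c = b               # c binds the same Python object as b
    b[paddle.zeros([], dtype='int32')] = 1.0   # in-place basic-index assignment

    assert id(b) == id(c)
    assert b.inplace_version == 1   # bumped once by the in-place write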
