Skip to content

Commit

Permalink
add index_select(5/5) and index_sample(5/5) (#59040)
Browse files Browse the repository at this point in the history
  • Loading branch information
zrr1999 authored Nov 17, 2023
1 parent 71ffae7 commit 6febbd0
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 37 deletions.
4 changes: 2 additions & 2 deletions python/paddle/tensor/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,7 @@ def index_select(x, index, axis=0, name=None):
[ 9. 10. 10.]]
"""

if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.index_select(x, index, axis)
else:
helper = LayerHelper("index_select", **locals())
Expand Down Expand Up @@ -849,7 +849,7 @@ def index_sample(x, index):
[1200 1100]]
"""
if in_dynamic_mode():
if in_dynamic_or_pir_mode():
return _C_ops.index_sample(x, index)
else:
helper = LayerHelper("index_sample", **locals())
Expand Down
56 changes: 30 additions & 26 deletions test/legacy_test/test_index_sample_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle
from paddle import base
from paddle.base import core
from paddle.pir_utils import test_with_pir_api


class TestIndexSampleOp(OpTest):
Expand All @@ -46,10 +47,10 @@ def setUp(self):
self.outputs = {'Out': out}

def test_check_output(self):
self.check_output()
self.check_output(check_pir=True)

def test_check_grad(self):
self.check_grad(['X'], 'Out')
self.check_grad(['X'], 'Out', check_pir=True)

def config(self):
"""
Expand Down Expand Up @@ -176,10 +177,10 @@ def setUp(self):
self.place = core.CUDAPlace(0)

def test_check_output(self):
self.check_output_with_place(self.place)
self.check_output_with_place(self.place, check_pir=True)

def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
self.check_grad_with_place(self.place, ['X'], 'Out', check_pir=True)

def config(self):
"""
Expand All @@ -193,30 +194,33 @@ def config(self):


class TestIndexSampleShape(unittest.TestCase):
@test_with_pir_api
def test_shape(self):
paddle.enable_static()
# create x value
x_shape = (2, 5)
x_type = "float64"
x_np = np.random.random(x_shape).astype(x_type)

# create index value
index_shape = (2, 3)
index_type = "int32"
index_np = np.random.randint(
low=0, high=x_shape[1], size=index_shape
).astype(index_type)

x = paddle.static.data(name='x', shape=[-1, 5], dtype='float64')
index = paddle.static.data(name='index', shape=[-1, 3], dtype='int32')
output = paddle.index_sample(x=x, index=index)

place = base.CPUPlace()
exe = base.Executor(place=place)
exe.run(base.default_startup_program())

feed = {'x': x_np, 'index': index_np}
res = exe.run(feed=feed, fetch_list=[output])
with paddle.static.program_guard(paddle.static.Program()):
# create x value
x_shape = (2, 5)
x_type = "float64"
x_np = np.random.random(x_shape).astype(x_type)

# create index value
index_shape = (2, 3)
index_type = "int32"
index_np = np.random.randint(
low=0, high=x_shape[1], size=index_shape
).astype(index_type)

x = paddle.static.data(name='x', shape=[-1, 5], dtype='float64')
index = paddle.static.data(
name='index', shape=[-1, 3], dtype='int32'
)
output = paddle.index_sample(x=x, index=index)

place = base.CPUPlace()
exe = base.Executor(place=place)

feed = {'x': x_np, 'index': index_np}
res = exe.run(feed=feed, fetch_list=[output])


class TestIndexSampleDynamic(unittest.TestCase):
Expand Down
24 changes: 15 additions & 9 deletions test/legacy_test/test_index_select_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import paddle
from paddle import base
from paddle.base import Program, core, program_guard
from paddle.pir_utils import test_with_pir_api

np.random.seed(1024)

Expand Down Expand Up @@ -66,15 +67,15 @@ def init_dtype_type(self):

def test_check_output(self):
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_output(check_prim=False)
self.check_output(check_prim=False, check_pir=True)
else:
self.check_output(check_prim=True)
self.check_output(check_prim=True, check_pir=True)

def test_check_grad_normal(self):
if self.x_type == np.complex64 or self.x_type == np.complex128:
self.check_grad(['X'], 'Out', check_prim=False)
self.check_grad(['X'], 'Out', check_prim=False, check_pir=True)
else:
self.check_grad(['X'], 'Out', check_prim=True)
self.check_grad(['X'], 'Out', check_prim=True, check_pir=True)


class TestIndexSelectOpCase2(TestIndexSelectOp):
Expand Down Expand Up @@ -150,11 +151,13 @@ def init_dtype_type(self):

def test_check_output(self):
place = core.CUDAPlace(0)
self.check_output_with_place(place)
self.check_output_with_place(place, check_pir=True)

def test_check_grad_normal(self):
place = core.CUDAPlace(0)
self.check_grad_with_place(place, ['X'], 'Out', check_prim=True)
self.check_grad_with_place(
place, ['X'], 'Out', check_prim=True, check_pir=True
)


class TestIndexSelectComplex64(TestIndexSelectOp):
Expand Down Expand Up @@ -186,6 +189,7 @@ def input_data(self):
).astype("float32")
self.data_index = np.array([0, 1, 1]).astype('int32')

@test_with_pir_api
def test_index_select_api(self):
paddle.enable_static()
self.input_data()
Expand All @@ -198,7 +202,7 @@ def test_index_select_api(self):
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x, 'index': self.data_index},
fetch_list=[z.name],
fetch_list=[z],
return_numpy=False,
)
expect_out = np.array(
Expand All @@ -207,14 +211,16 @@ def test_index_select_api(self):
np.testing.assert_allclose(expect_out, np.array(res), rtol=1e-05)

# case 2:
with program_guard(Program(), Program()):
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
x = paddle.static.data(name='x', shape=[-1, 4])
index = paddle.static.data(name='index', shape=[3], dtype='int32')
z = paddle.index_select(x, index)
exe = base.Executor(base.CPUPlace())
(res,) = exe.run(
feed={'x': self.data_x, 'index': self.data_index},
fetch_list=[z.name],
fetch_list=[z],
return_numpy=False,
)
expect_out = np.array(
Expand Down

0 comments on commit 6febbd0

Please sign in to comment.