Skip to content

Commit

Permalink
EHN : Adding searchsorted n jax devicearray.py (ivy-llc#18980)
Browse files Browse the repository at this point in the history
  • Loading branch information
MuhammadNizamani authored Aug 9, 2023
1 parent 61bced0 commit 0870da3
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 0 deletions.
3 changes: 3 additions & 0 deletions ivy/functional/frontends/jax/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,9 @@ def __iter__(self):
def round(self, decimals=0):
return jax_frontend.numpy.round(self, decimals)

def searchsorted(self, v, side="left", sorter=None, *, method="scan"):
return jax_frontend.numpy.searchsorted(self, v, side=side, sorter=sorter)

def ptp(self, *, axis=None, out=None, keepdims=False):
return jax_frontend.numpy.ptp(self, axis=axis, keepdims=keepdims)

Expand Down
60 changes: 60 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_jax/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2279,3 +2279,63 @@ def test_jax_array_ptp(
method_flags=method_flags,
on_device=on_device,
)


# searchsorted
@st.composite
def _searchsorted(draw):
dtype_x, x = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes(
"numeric", full=False, key="searchsorted"
),
shape=(draw(st.integers(min_value=1, max_value=10)),),
),
)
dtype_v, v = draw(
helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes(
"numeric", full=False, key="searchsorted"
),
min_num_dims=1,
)
)

input_dtypes = dtype_x + dtype_v
xs = x + v
side = draw(st.sampled_from(["left", "right"]))
sorter = None
xs[0] = np.sort(xs[0], axis=-1)
return input_dtypes, xs, side, sorter


@handle_frontend_method(
class_tree=CLASS_TREE,
init_tree="jax.numpy.array",
method_name="searchsorted",
dtype_x_v_side_sorter=_searchsorted(),
)
def test_jax_array_searchsorted(
dtype_x_v_side_sorter,
frontend,
frontend_method_data,
init_flags,
method_flags,
on_device,
backend_fw,
):
input_dtypes, xs, side, sorter = dtype_x_v_side_sorter
helpers.test_frontend_method(
init_input_dtypes=input_dtypes,
backend_to_test=backend_fw,
init_all_as_kwargs_np={
"object": xs[0],
},
method_input_dtypes=input_dtypes,
method_all_as_kwargs_np={"v": xs[0], "side": side, "sorter": sorter},
frontend=frontend,
frontend_method_data=frontend_method_data,
init_flags=init_flags,
method_flags=method_flags,
on_device=on_device,
)

0 comments on commit 0870da3

Please sign in to comment.