Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added unsorted_segment_min to ivy functional API #17833

Merged
merged 3 commits into from
Jul 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,40 @@ def eye_like(
0., 1., 0.]])
"""
return ivy.eye_like(self._data, k=k, dtype=dtype, device=device, out=out)

def unsorted_segment_min(
self: ivy.Array,
segment_ids: ivy.Array,
num_segments: Union[int, ivy.Array],
) -> ivy.Array:
r"""
ivy.Array instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.

Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.

Parameters
----------
self
The array from which to gather values.

segment_ids
Must be in the same size with the first dimension of `self`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `self`.

num_segments
An integer or array representing the total number of distinct segment IDs.

Returns
-------
ret
The output array, representing the result of a segmented min operation.
For each segment, it computes the min value in `self` where `segment_ids`
equals to segment ID.
"""
return ivy.unsorted_segment_min(self._data, segment_ids, num_segments)
100 changes: 100 additions & 0 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,103 @@ def eye_like(
dtype=dtype,
device=device,
)

@staticmethod
def static_unsorted_segment_min(
data: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
) -> ivy.Container:
r"""
ivy.Container instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.

Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.

Parameters
----------
data
input array or container from which to gather the input.
segment_ids
Must be in the same size with the first dimension of `data`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `data`.
num_segments
An integer or array representing the total number of distinct segment IDs.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.

Returns
-------
ret
A container, representing the result of a segmented min operation.
For each segment, it computes the min value in `data` where `segment_ids`
equals to segment ID.
"""
return ContainerBase.cont_multi_map_in_function(
"unsorted_segment_min",
data,
segment_ids,
num_segments,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def unsorted_segment_min(
self: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
):
r"""
ivy.Container instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.

Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.

Parameters
----------
self
input array or container from which to gather the input.
segment_ids
Must be in the same size with the first dimension of `self`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `self`.
num_segments
An integer or array representing the total number of distinct segment IDs.

Returns
-------
ret
A container, representing the result of a segmented min operation.
For each segment, it computes the min value in `self` where `segment_ids`
equals to segment ID.
"""
return self.static_unsorted_segment_min(
self,
segment_ids,
num_segments,
)
13 changes: 13 additions & 0 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# global
from typing import Optional, Tuple
import math
import jax
import jax.numpy as jnp
import jaxlib.xla_extension

Expand Down Expand Up @@ -78,3 +79,15 @@ def tril_indices(
jnp.tril_indices(n=n_rows, k=k, m=n_cols),
device=device,
)


def unsorted_segment_min(
data: JaxArray,
segment_ids: JaxArray,
num_segments: int,
) -> JaxArray:
# added this check to keep the same behaviour as tensorflow
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_min(data, segment_ids, num_segments)
26 changes: 26 additions & 0 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,29 @@ def indices(
sparse: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
return np.indices(dimensions, dtype=dtype, sparse=sparse)


def unsorted_segment_min(
data: np.ndarray,
segment_ids: np.ndarray,
num_segments: int,
) -> np.ndarray:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)

if data.dtype in [np.float32, np.float64]:
init_val = np.finfo(data.dtype).max
elif data.dtype in [np.int32, np.int64, np.int8, np.int16, np.uint8]:
init_val = np.iinfo(data.dtype).max
else:
raise ValueError("Unsupported data type")

res = np.full((num_segments,) + data.shape[1:], init_val, dtype=data.dtype)

for i in range(num_segments):
mask_index = segment_ids == i
if np.any(mask_index):
res[i] = np.min(data[mask_index], axis=0)

return res
41 changes: 39 additions & 2 deletions ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# global
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import math
import paddle
import ivy.functional.backends.paddle as paddle_backend
from paddle.fluid.libpaddle import Place
from ivy.functional.backends.paddle.device import to_device
from ivy.func_wrapper import (
with_supported_dtypes,
)


# local
import ivy

from .. import backend_version

# noinspection PyProtectedMember
# Helpers for calculating Window Functions
Expand Down Expand Up @@ -92,3 +96,36 @@ def tril_indices(
paddle.tril_indices(n_rows, col=n_cols, offset=k, dtype="int64"), device
)
)


@with_supported_dtypes(
{"2.4.2 and below": ("float64", "float32", "int32", "int64")},
backend_version,
)
def unsorted_segment_min(
data: paddle.Tensor,
segment_ids: paddle.Tensor,
num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
if data.dtype == paddle.float32:
init_val = 3.4028234663852886e38 # float32 max
elif data.dtype == paddle.float64:
init_val = 1.7976931348623157e308 # float64 max
elif data.dtype == paddle.int32:
init_val = 2147483647
elif data.dtype == paddle.int64:
init_val = 9223372036854775807
else:
raise ValueError("Unsupported data type")
# Using paddle.full is causing interger overflow for int64
res = paddle.empty((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype)
res[:] = init_val
for i in range(num_segments):
mask_index = segment_ids == i
if paddle.any(mask_index):
res[i] = paddle.min(data[mask_index], 0)

return res
8 changes: 8 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ def tril_indices(
return tuple(tf.convert_to_tensor(ret, dtype=tf.int64))

return tuple(tf.convert_to_tensor(ret, dtype=tf.int64))


def unsorted_segment_min(
data: tf.Tensor,
segment_ids: tf.Tensor,
num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
return tf.math.unsorted_segment_min(data, segment_ids, num_segments)
28 changes: 27 additions & 1 deletion ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import math
import torch

Expand Down Expand Up @@ -124,3 +124,29 @@ def tril_indices(
row=n_rows, col=n_cols, offset=k, dtype=torch.int64, device=device
)
)


def unsorted_segment_min(
data: torch.Tensor,
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
init_val = torch.finfo(data.dtype).max
elif data.dtype in [torch.int32, torch.int64, torch.int8, torch.int16, torch.uint8]:
init_val = torch.iinfo(data.dtype).max
else:
raise ValueError("Unsupported data type")

res = torch.full(
(num_segments,) + data.shape[1:], init_val, dtype=data.dtype, device=data.device
)
for i in range(num_segments):
mask_index = segment_ids == i
if torch.any(mask_index):
res[i] = torch.min(data[mask_index], 0)[0]

return res
40 changes: 40 additions & 0 deletions ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,43 @@ def indices(
else:
grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij")
return ivy.stack(grid, axis=0).astype(dtype)


@handle_exceptions
@handle_nestable
@to_native_arrays_and_back
def unsorted_segment_min(
data: Union[ivy.Array, ivy.NativeArray],
segment_ids: Union[ivy.Array, ivy.NativeArray],
num_segments: Union[int, ivy.Array, ivy.NativeArray],
) -> ivy.Array:
r"""
Compute the minimum along segments of an array. Segments are defined by an integer
array of segment IDs.

Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.

Parameters
----------
data
The array from which to gather values.

segment_ids
Must be in the same size with the first dimension of `data`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `data`.

num_segments
An integer or array representing the total number of distinct segment IDs.

Returns
-------
ret
The output array, representing the result of a segmented min operation.
For each segment, it computes the min value in `data` where `segment_ids`
equals to segment ID.
"""
return ivy.current_backend().unsorted_segment_min(data, segment_ids, num_segments)
Loading