Skip to content

Commit

Permalink
added unsorted_segment_min to ivy functional API (#17833)
Browse files Browse the repository at this point in the history
  • Loading branch information
akshatvishu authored Jul 3, 2023
1 parent 570531b commit c7b1c47
Show file tree
Hide file tree
Showing 10 changed files with 398 additions and 4 deletions.
37 changes: 37 additions & 0 deletions ivy/data_classes/array/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,3 +54,40 @@ def eye_like(
0., 1., 0.]])
"""
return ivy.eye_like(self._data, k=k, dtype=dtype, device=device, out=out)

def unsorted_segment_min(
self: ivy.Array,
segment_ids: ivy.Array,
num_segments: Union[int, ivy.Array],
) -> ivy.Array:
r"""
ivy.Array instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.
Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.
Parameters
----------
self
The array from which to gather values.
segment_ids
Must be in the same size with the first dimension of `self`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `self`.
num_segments
An integer or array representing the total number of distinct segment IDs.
Returns
-------
ret
The output array, representing the result of a segmented min operation.
For each segment, it computes the min value in `self` where `segment_ids`
equals to segment ID.
"""
return ivy.unsorted_segment_min(self._data, segment_ids, num_segments)
100 changes: 100 additions & 0 deletions ivy/data_classes/container/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -762,3 +762,103 @@ def eye_like(
dtype=dtype,
device=device,
)

@staticmethod
def static_unsorted_segment_min(
data: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
) -> ivy.Container:
r"""
ivy.Container instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.
Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.
Parameters
----------
data
input array or container from which to gather the input.
segment_ids
Must be in the same size with the first dimension of `data`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `data`.
num_segments
An integer or array representing the total number of distinct segment IDs.
key_chains
The key-chains to apply or not apply the method to. Default is ``None``.
to_apply
If True, the method will be applied to key_chains, otherwise key_chains
will be skipped. Default is ``True``.
prune_unapplied
Whether to prune key_chains for which the function was not applied.
Default is ``False``.
map_sequences
Whether to also map method to sequences (lists, tuples).
Default is ``False``.
Returns
-------
ret
A container, representing the result of a segmented min operation.
For each segment, it computes the min value in `data` where `segment_ids`
equals to segment ID.
"""
return ContainerBase.cont_multi_map_in_function(
"unsorted_segment_min",
data,
segment_ids,
num_segments,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
)

def unsorted_segment_min(
self: ivy.Container,
segment_ids: ivy.Container,
num_segments: Union[int, ivy.Container],
):
r"""
ivy.Container instance method variant of ivy.unsorted_segment_min. This method
simply wraps the function, and so the docstring for ivy.unsorted_segment_min
also applies to this method with minimal changes.
Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.
Parameters
----------
self
input array or container from which to gather the input.
segment_ids
Must be in the same size with the first dimension of `self`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `self`.
num_segments
An integer or array representing the total number of distinct segment IDs.
Returns
-------
ret
A container, representing the result of a segmented min operation.
For each segment, it computes the min value in `self` where `segment_ids`
equals to segment ID.
"""
return self.static_unsorted_segment_min(
self,
segment_ids,
num_segments,
)
13 changes: 13 additions & 0 deletions ivy/functional/backends/jax/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# global
from typing import Optional, Tuple
import math
import jax
import jax.numpy as jnp
import jaxlib.xla_extension

Expand Down Expand Up @@ -78,3 +79,15 @@ def tril_indices(
jnp.tril_indices(n=n_rows, k=k, m=n_cols),
device=device,
)


def unsorted_segment_min(
data: JaxArray,
segment_ids: JaxArray,
num_segments: int,
) -> JaxArray:
# added this check to keep the same behaviour as tensorflow
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
return jax.ops.segment_min(data, segment_ids, num_segments)
26 changes: 26 additions & 0 deletions ivy/functional/backends/numpy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,29 @@ def indices(
sparse: bool = False,
) -> Union[np.ndarray, Tuple[np.ndarray, ...]]:
return np.indices(dimensions, dtype=dtype, sparse=sparse)


def unsorted_segment_min(
data: np.ndarray,
segment_ids: np.ndarray,
num_segments: int,
) -> np.ndarray:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)

if data.dtype in [np.float32, np.float64]:
init_val = np.finfo(data.dtype).max
elif data.dtype in [np.int32, np.int64, np.int8, np.int16, np.uint8]:
init_val = np.iinfo(data.dtype).max
else:
raise ValueError("Unsupported data type")

res = np.full((num_segments,) + data.shape[1:], init_val, dtype=data.dtype)

for i in range(num_segments):
mask_index = segment_ids == i
if np.any(mask_index):
res[i] = np.min(data[mask_index], axis=0)

return res
41 changes: 39 additions & 2 deletions ivy/functional/backends/paddle/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,18 @@
# global
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import math
import paddle
import ivy.functional.backends.paddle as paddle_backend
from paddle.fluid.libpaddle import Place
from ivy.functional.backends.paddle.device import to_device
from ivy.func_wrapper import (
with_supported_dtypes,
)


# local
import ivy

from .. import backend_version

# noinspection PyProtectedMember
# Helpers for calculating Window Functions
Expand Down Expand Up @@ -92,3 +96,36 @@ def tril_indices(
paddle.tril_indices(n_rows, col=n_cols, offset=k, dtype="int64"), device
)
)


@with_supported_dtypes(
{"2.4.2 and below": ("float64", "float32", "int32", "int64")},
backend_version,
)
def unsorted_segment_min(
data: paddle.Tensor,
segment_ids: paddle.Tensor,
num_segments: Union[int, paddle.Tensor],
) -> paddle.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
if data.dtype == paddle.float32:
init_val = 3.4028234663852886e38 # float32 max
elif data.dtype == paddle.float64:
init_val = 1.7976931348623157e308 # float64 max
elif data.dtype == paddle.int32:
init_val = 2147483647
elif data.dtype == paddle.int64:
init_val = 9223372036854775807
else:
raise ValueError("Unsupported data type")
# Using paddle.full is causing interger overflow for int64
res = paddle.empty((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype)
res[:] = init_val
for i in range(num_segments):
mask_index = segment_ids == i
if paddle.any(mask_index):
res[i] = paddle.min(data[mask_index], 0)

return res
8 changes: 8 additions & 0 deletions ivy/functional/backends/tensorflow/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,11 @@ def tril_indices(
return tuple(tf.convert_to_tensor(ret, dtype=tf.int64))

return tuple(tf.convert_to_tensor(ret, dtype=tf.int64))


def unsorted_segment_min(
data: tf.Tensor,
segment_ids: tf.Tensor,
num_segments: Union[int, tf.Tensor],
) -> tf.Tensor:
return tf.math.unsorted_segment_min(data, segment_ids, num_segments)
28 changes: 27 additions & 1 deletion ivy/functional/backends/torch/experimental/creation.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Tuple
from typing import Optional, Tuple, Union
import math
import torch

Expand Down Expand Up @@ -124,3 +124,29 @@ def tril_indices(
row=n_rows, col=n_cols, offset=k, dtype=torch.int64, device=device
)
)


def unsorted_segment_min(
data: torch.Tensor,
segment_ids: torch.Tensor,
num_segments: Union[int, torch.Tensor],
) -> torch.Tensor:
ivy.utils.assertions.check_unsorted_segment_min_valid_params(
data, segment_ids, num_segments
)
if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]:
init_val = torch.finfo(data.dtype).max
elif data.dtype in [torch.int32, torch.int64, torch.int8, torch.int16, torch.uint8]:
init_val = torch.iinfo(data.dtype).max
else:
raise ValueError("Unsupported data type")

res = torch.full(
(num_segments,) + data.shape[1:], init_val, dtype=data.dtype, device=data.device
)
for i in range(num_segments):
mask_index = segment_ids == i
if torch.any(mask_index):
res[i] = torch.min(data[mask_index], 0)[0]

return res
40 changes: 40 additions & 0 deletions ivy/functional/ivy/experimental/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,3 +583,43 @@ def indices(
else:
grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij")
return ivy.stack(grid, axis=0).astype(dtype)


@handle_exceptions
@handle_nestable
@to_native_arrays_and_back
def unsorted_segment_min(
data: Union[ivy.Array, ivy.NativeArray],
segment_ids: Union[ivy.Array, ivy.NativeArray],
num_segments: Union[int, ivy.Array, ivy.NativeArray],
) -> ivy.Array:
r"""
Compute the minimum along segments of an array. Segments are defined by an integer
array of segment IDs.
Note
----
If the given segment ID `i` is negative, then the corresponding
value is dropped, and will not be included in the result.
Parameters
----------
data
The array from which to gather values.
segment_ids
Must be in the same size with the first dimension of `data`. Has to be
of integer data type. The index-th element of `segment_ids` array is
the segment identifier for the index-th element of `data`.
num_segments
An integer or array representing the total number of distinct segment IDs.
Returns
-------
ret
The output array, representing the result of a segmented min operation.
For each segment, it computes the min value in `data` where `segment_ids`
equals to segment ID.
"""
return ivy.current_backend().unsorted_segment_min(data, segment_ids, num_segments)
Loading

0 comments on commit c7b1c47

Please sign in to comment.