From c7b1c47250abbbea088abd576fab2efb84420eb2 Mon Sep 17 00:00:00 2001 From: akshatvishu <33392262+akshatvishu@users.noreply.github.com> Date: Mon, 3 Jul 2023 15:49:48 +0530 Subject: [PATCH] added `unsorted_segment_min` to ivy functional API (#17833) --- .../array/experimental/creation.py | 37 +++++++ .../container/experimental/creation.py | 100 ++++++++++++++++++ .../backends/jax/experimental/creation.py | 13 +++ .../backends/numpy/experimental/creation.py | 26 +++++ .../backends/paddle/experimental/creation.py | 41 ++++++- .../tensorflow/experimental/creation.py | 8 ++ .../backends/torch/experimental/creation.py | 28 ++++- ivy/functional/ivy/experimental/creation.py | 40 +++++++ ivy/utils/assertions.py | 39 ++++++- .../test_core/test_creation.py | 70 ++++++++++++ 10 files changed, 398 insertions(+), 4 deletions(-) diff --git a/ivy/data_classes/array/experimental/creation.py b/ivy/data_classes/array/experimental/creation.py index c3acbf4089351..a9a158421c647 100644 --- a/ivy/data_classes/array/experimental/creation.py +++ b/ivy/data_classes/array/experimental/creation.py @@ -54,3 +54,40 @@ def eye_like( 0., 1., 0.]]) """ return ivy.eye_like(self._data, k=k, dtype=dtype, device=device, out=out) + + def unsorted_segment_min( + self: ivy.Array, + segment_ids: ivy.Array, + num_segments: Union[int, ivy.Array], + ) -> ivy.Array: + r""" + ivy.Array instance method variant of ivy.unsorted_segment_min. This method + simply wraps the function, and so the docstring for ivy.unsorted_segment_min + also applies to this method with minimal changes. + + Note + ---- + If the given segment ID `i` is negative, then the corresponding + value is dropped, and will not be included in the result. + + Parameters + ---------- + self + The array from which to gather values. + + segment_ids + Must be in the same size with the first dimension of `self`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `self`. 
+ + num_segments + An integer or array representing the total number of distinct segment IDs. + + Returns + ------- + ret + The output array, representing the result of a segmented min operation. + For each segment, it computes the min value in `self` where `segment_ids` + equals to segment ID. + """ + return ivy.unsorted_segment_min(self._data, segment_ids, num_segments) diff --git a/ivy/data_classes/container/experimental/creation.py b/ivy/data_classes/container/experimental/creation.py index 887f42b6beb0d..2ceadfcaa63ad 100644 --- a/ivy/data_classes/container/experimental/creation.py +++ b/ivy/data_classes/container/experimental/creation.py @@ -762,3 +762,103 @@ def eye_like( dtype=dtype, device=device, ) + + @staticmethod + def static_unsorted_segment_min( + data: ivy.Container, + segment_ids: ivy.Container, + num_segments: Union[int, ivy.Container], + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + ) -> ivy.Container: + r""" + ivy.Container static method variant of ivy.unsorted_segment_min. This method + simply wraps the function, and so the docstring for ivy.unsorted_segment_min + also applies to this method with minimal changes. + + Note + ---- + If the given segment ID `i` is negative, then the corresponding + value is dropped, and will not be included in the result. + + Parameters + ---------- + data + input array or container from which to gather the input. + segment_ids + Must be in the same size with the first dimension of `data`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `data`. + num_segments + An integer or array representing the total number of distinct segment IDs. + key_chains + The key-chains to apply or not apply the method to. Default is ``None``. 
+ to_apply + If True, the method will be applied to key_chains, otherwise key_chains + will be skipped. Default is ``True``. + prune_unapplied + Whether to prune key_chains for which the function was not applied. + Default is ``False``. + map_sequences + Whether to also map method to sequences (lists, tuples). + Default is ``False``. + + Returns + ------- + ret + A container, representing the result of a segmented min operation. + For each segment, it computes the min value in `data` where `segment_ids` + equals to segment ID. + """ + return ContainerBase.cont_multi_map_in_function( + "unsorted_segment_min", + data, + segment_ids, + num_segments, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + ) + + def unsorted_segment_min( + self: ivy.Container, + segment_ids: ivy.Container, + num_segments: Union[int, ivy.Container], + ): + r""" + ivy.Container instance method variant of ivy.unsorted_segment_min. This method + simply wraps the function, and so the docstring for ivy.unsorted_segment_min + also applies to this method with minimal changes. + + Note + ---- + If the given segment ID `i` is negative, then the corresponding + value is dropped, and will not be included in the result. + + Parameters + ---------- + self + input array or container from which to gather the input. + segment_ids + Must be in the same size with the first dimension of `self`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `self`. + num_segments + An integer or array representing the total number of distinct segment IDs. + + Returns + ------- + ret + A container, representing the result of a segmented min operation. + For each segment, it computes the min value in `self` where `segment_ids` + equals to segment ID. 
+ """ + return self.static_unsorted_segment_min( + self, + segment_ids, + num_segments, + ) diff --git a/ivy/functional/backends/jax/experimental/creation.py b/ivy/functional/backends/jax/experimental/creation.py index a4f94da4fbd5b..df22779e6d27c 100644 --- a/ivy/functional/backends/jax/experimental/creation.py +++ b/ivy/functional/backends/jax/experimental/creation.py @@ -1,6 +1,7 @@ # global from typing import Optional, Tuple import math +import jax import jax.numpy as jnp import jaxlib.xla_extension @@ -78,3 +79,15 @@ def tril_indices( jnp.tril_indices(n=n_rows, k=k, m=n_cols), device=device, ) + + +def unsorted_segment_min( + data: JaxArray, + segment_ids: JaxArray, + num_segments: int, +) -> JaxArray: + # added this check to keep the same behaviour as tensorflow + ivy.utils.assertions.check_unsorted_segment_min_valid_params( + data, segment_ids, num_segments + ) + return jax.ops.segment_min(data, segment_ids, num_segments) diff --git a/ivy/functional/backends/numpy/experimental/creation.py b/ivy/functional/backends/numpy/experimental/creation.py index 28fe4684f3fe4..fd67888a80b4c 100644 --- a/ivy/functional/backends/numpy/experimental/creation.py +++ b/ivy/functional/backends/numpy/experimental/creation.py @@ -85,3 +85,29 @@ def indices( sparse: bool = False, ) -> Union[np.ndarray, Tuple[np.ndarray, ...]]: return np.indices(dimensions, dtype=dtype, sparse=sparse) + + +def unsorted_segment_min( + data: np.ndarray, + segment_ids: np.ndarray, + num_segments: int, +) -> np.ndarray: + ivy.utils.assertions.check_unsorted_segment_min_valid_params( + data, segment_ids, num_segments + ) + + if data.dtype in [np.float32, np.float64]: + init_val = np.finfo(data.dtype).max + elif data.dtype in [np.int32, np.int64, np.int8, np.int16, np.uint8]: + init_val = np.iinfo(data.dtype).max + else: + raise ValueError("Unsupported data type") + + res = np.full((num_segments,) + data.shape[1:], init_val, dtype=data.dtype) + + for i in range(num_segments): + mask_index = segment_ids 
== i + if np.any(mask_index): + res[i] = np.min(data[mask_index], axis=0) + + return res diff --git a/ivy/functional/backends/paddle/experimental/creation.py b/ivy/functional/backends/paddle/experimental/creation.py index f7d3f84d85e2f..ba394ae181916 100644 --- a/ivy/functional/backends/paddle/experimental/creation.py +++ b/ivy/functional/backends/paddle/experimental/creation.py @@ -1,14 +1,18 @@ # global -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import math import paddle import ivy.functional.backends.paddle as paddle_backend from paddle.fluid.libpaddle import Place from ivy.functional.backends.paddle.device import to_device +from ivy.func_wrapper import ( + with_supported_dtypes, +) + # local import ivy - +from .. import backend_version # noinspection PyProtectedMember # Helpers for calculating Window Functions @@ -92,3 +96,36 @@ def tril_indices( paddle.tril_indices(n_rows, col=n_cols, offset=k, dtype="int64"), device ) ) + + +@with_supported_dtypes( + {"2.4.2 and below": ("float64", "float32", "int32", "int64")}, + backend_version, +) +def unsorted_segment_min( + data: paddle.Tensor, + segment_ids: paddle.Tensor, + num_segments: Union[int, paddle.Tensor], +) -> paddle.Tensor: + ivy.utils.assertions.check_unsorted_segment_min_valid_params( + data, segment_ids, num_segments + ) + if data.dtype == paddle.float32: + init_val = 3.4028234663852886e38 # float32 max + elif data.dtype == paddle.float64: + init_val = 1.7976931348623157e308 # float64 max + elif data.dtype == paddle.int32: + init_val = 2147483647 + elif data.dtype == paddle.int64: + init_val = 9223372036854775807 + else: + raise ValueError("Unsupported data type") + # Using paddle.full is causing integer overflow for int64 + res = paddle.empty((num_segments,) + tuple(data.shape[1:]), dtype=data.dtype) + res[:] = init_val + for i in range(num_segments): + mask_index = segment_ids == i + if paddle.any(mask_index): + res[i] = paddle.min(data[mask_index], 0) + + return res 
diff --git a/ivy/functional/backends/tensorflow/experimental/creation.py b/ivy/functional/backends/tensorflow/experimental/creation.py index b9592e7fdebca..a1659c243b3c1 100644 --- a/ivy/functional/backends/tensorflow/experimental/creation.py +++ b/ivy/functional/backends/tensorflow/experimental/creation.py @@ -93,3 +93,11 @@ def tril_indices( return tuple(tf.convert_to_tensor(ret, dtype=tf.int64)) return tuple(tf.convert_to_tensor(ret, dtype=tf.int64)) + + +def unsorted_segment_min( + data: tf.Tensor, + segment_ids: tf.Tensor, + num_segments: Union[int, tf.Tensor], +) -> tf.Tensor: + return tf.math.unsorted_segment_min(data, segment_ids, num_segments) diff --git a/ivy/functional/backends/torch/experimental/creation.py b/ivy/functional/backends/torch/experimental/creation.py index f6a6a2a0b858c..96961dd03cbbd 100644 --- a/ivy/functional/backends/torch/experimental/creation.py +++ b/ivy/functional/backends/torch/experimental/creation.py @@ -1,5 +1,5 @@ # global -from typing import Optional, Tuple +from typing import Optional, Tuple, Union import math import torch @@ -124,3 +124,29 @@ def tril_indices( row=n_rows, col=n_cols, offset=k, dtype=torch.int64, device=device ) ) + + +def unsorted_segment_min( + data: torch.Tensor, + segment_ids: torch.Tensor, + num_segments: Union[int, torch.Tensor], +) -> torch.Tensor: + ivy.utils.assertions.check_unsorted_segment_min_valid_params( + data, segment_ids, num_segments + ) + if data.dtype in [torch.float32, torch.float64, torch.float16, torch.bfloat16]: + init_val = torch.finfo(data.dtype).max + elif data.dtype in [torch.int32, torch.int64, torch.int8, torch.int16, torch.uint8]: + init_val = torch.iinfo(data.dtype).max + else: + raise ValueError("Unsupported data type") + + res = torch.full( + (num_segments,) + data.shape[1:], init_val, dtype=data.dtype, device=data.device + ) + for i in range(num_segments): + mask_index = segment_ids == i + if torch.any(mask_index): + res[i] = torch.min(data[mask_index], 0)[0] + + return res 
diff --git a/ivy/functional/ivy/experimental/creation.py b/ivy/functional/ivy/experimental/creation.py index 843c95477873f..0745f8d93b818 100644 --- a/ivy/functional/ivy/experimental/creation.py +++ b/ivy/functional/ivy/experimental/creation.py @@ -583,3 +583,43 @@ def indices( else: grid = ivy.meshgrid(*[ivy.arange(dim) for dim in dimensions], indexing="ij") return ivy.stack(grid, axis=0).astype(dtype) + + +@handle_exceptions +@handle_nestable +@to_native_arrays_and_back +def unsorted_segment_min( + data: Union[ivy.Array, ivy.NativeArray], + segment_ids: Union[ivy.Array, ivy.NativeArray], + num_segments: Union[int, ivy.Array, ivy.NativeArray], +) -> ivy.Array: + r""" + Compute the minimum along segments of an array. Segments are defined by an integer + array of segment IDs. + + Note + ---- + If the given segment ID `i` is negative, then the corresponding + value is dropped, and will not be included in the result. + + Parameters + ---------- + data + The array from which to gather values. + + segment_ids + Must be in the same size with the first dimension of `data`. Has to be + of integer data type. The index-th element of `segment_ids` array is + the segment identifier for the index-th element of `data`. + + num_segments + An integer or array representing the total number of distinct segment IDs. + + Returns + ------- + ret + The output array, representing the result of a segmented min operation. + For each segment, it computes the min value in `data` where `segment_ids` + equals to segment ID. 
+ """ + return ivy.current_backend().unsorted_segment_min(data, segment_ids, num_segments) diff --git a/ivy/utils/assertions.py b/ivy/utils/assertions.py index b76f1ec03dac4..a747c71c1c961 100644 --- a/ivy/utils/assertions.py +++ b/ivy/utils/assertions.py @@ -1,5 +1,6 @@ import numpy as np - +import torch +import paddle import ivy @@ -213,6 +214,42 @@ def check_fill_value_and_dtype_are_compatible(fill_value, dtype): ) +def check_unsorted_segment_min_valid_params(data, segment_ids, num_segments): + if not (isinstance(num_segments, int)): + raise ValueError("num_segments must be of integer type") + + valid_dtypes = [ + ivy.int32, + ivy.int64, + torch.int32, + torch.int64, + paddle.int32, + paddle.int64, + ] + + if segment_ids.dtype not in valid_dtypes: + raise ValueError("segment_ids must have an integer dtype") + + if data.shape[0] != segment_ids.shape[0]: + raise ValueError("The length of segment_ids should be equal to data.shape[0].") + + if ivy.backend == "torch": + if isinstance(num_segments, torch.Tensor): + num_segments = num_segments.item() + elif ivy.backend == "paddle": + if isinstance(num_segments, paddle.Tensor): + num_segments = num_segments.item() + + if ivy.max(segment_ids) >= num_segments: + error_message = ( + f"segment_ids[{ivy.argmax(segment_ids)}] = " + f"{ivy.max(segment_ids)} is out of range [0, {num_segments})" + ) + raise ValueError(error_message) + if num_segments <= 0: + raise ValueError("num_segments must be positive") + + # General # # ------- # diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py index 146731435764b..730361103a01c 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_core/test_creation.py @@ -358,3 +358,73 @@ def test_indices( dtype=dtype[0], sparse=sparse, ) + + +@st.composite +def 
valid_unsorted_segment_min_inputs(draw): + while True: + dtype = draw(st.sampled_from([ivy.int32, ivy.int64, ivy.float32, ivy.float64])) + segment_ids_dim = draw(st.integers(min_value=3, max_value=10)) + num_segments = draw(st.integers(min_value=2, max_value=segment_ids_dim)) + + data_dim = draw( + helpers.get_shape( + min_dim_size=segment_ids_dim, + max_dim_size=segment_ids_dim, + min_num_dims=1, + max_num_dims=4, + ) + ) + data_dim = (segment_ids_dim,) + data_dim[1:] + + data = draw( + helpers.array_values( + dtype=dtype, + shape=data_dim, + min_value=1, + max_value=10, + ) + ) + + segment_ids = draw( + helpers.array_values( + dtype=ivy.int32, + shape=(segment_ids_dim,), + min_value=0, + max_value=num_segments + 1, + ) + ) + if data.shape[0] == segment_ids.shape[0]: + if np.max(segment_ids) < num_segments: + return (dtype, ivy.int32), data, num_segments, segment_ids + + +# unsorted_segment_min +@handle_test( + fn_tree="functional.ivy.experimental.unsorted_segment_min", + ground_truth_backend="tensorflow", + d_x_n_s=valid_unsorted_segment_min_inputs(), + test_with_out=st.just(False), + test_gradients=st.just(False), +) +def test_unsorted_segment_min( + *, + d_x_n_s, + test_flags, + backend_fw, + fn_name, + on_device, + ground_truth_backend, +): + dtypes, data, num_segments, segment_ids = d_x_n_s + helpers.test_function( + input_dtypes=dtypes, + test_flags=test_flags, + ground_truth_backend=ground_truth_backend, + on_device=on_device, + fw=backend_fw, + fn_name=fn_name, + data=data, + segment_ids=segment_ids, + num_segments=num_segments, + )