From cdd53f32b08f90baa6575095116360e10d0becb5 Mon Sep 17 00:00:00 2001 From: nassimberrada Date: Tue, 11 Oct 2022 19:32:49 +0100 Subject: [PATCH 1/7] moveaxis_extension --- ivy/functional/backends/jax/extensions.py | 13 ++- ivy/functional/backends/numpy/extensions.py | 16 +++- .../backends/tensorflow/extensions.py | 13 ++- ivy/functional/backends/torch/extensions.py | 16 +++- ivy/functional/ivy/extensions.py | 47 +++++++++- .../test_functional/test_extensions.py | 86 +++++++++++++++++++ 6 files changed, 186 insertions(+), 5 deletions(-) diff --git a/ivy/functional/backends/jax/extensions.py b/ivy/functional/backends/jax/extensions.py index 703621ac09ec2..cbf9e1ae72ebf 100644 --- a/ivy/functional/backends/jax/extensions.py +++ b/ivy/functional/backends/jax/extensions.py @@ -1,5 +1,5 @@ import logging -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Sequence import ivy from ivy.functional.ivy.extensions import ( _verify_coo_components, @@ -156,3 +156,14 @@ def max_pool2d( return jnp.transpose(res, (0, 3, 1, 2)) return res + + +def moveaxis( + a: JaxArray, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[JaxArray] = None, +) -> JaxArray: + return jnp.moveaxis(a, source, destination) diff --git a/ivy/functional/backends/numpy/extensions.py b/ivy/functional/backends/numpy/extensions.py index 8930ce00205ca..23b1ee0e0b908 100644 --- a/ivy/functional/backends/numpy/extensions.py +++ b/ivy/functional/backends/numpy/extensions.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Sequence import logging import ivy import numpy as np @@ -176,3 +176,17 @@ def max_pool2d( if data_format == "NCHW": return np.transpose(res, (0, 3, 1, 2)) return res + + +def moveaxis( + a: np.ndarray, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[np.ndarray] = None, +) -> np.ndarray: + return np.moveaxis(a, source, destination) + + +moveaxis.support_native_out = False diff --git a/ivy/functional/backends/tensorflow/extensions.py b/ivy/functional/backends/tensorflow/extensions.py index 5287b7f46a081..08adf68c6a95d 100644 --- a/ivy/functional/backends/tensorflow/extensions.py +++ b/ivy/functional/backends/tensorflow/extensions.py @@ -1,4 +1,4 @@ -from typing import Union, Optional, Tuple +from typing import Union, Optional, Tuple, Sequence import ivy from ivy.functional.ivy.extensions import ( _verify_coo_components, @@ -138,3 +138,14 @@ def max_pool2d( if data_format == "NCHW": return tf.transpose(res, (0, 3, 1, 2)) return res + + +def moveaxis( + a: Union[tf.Tensor, tf.Variable], + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + return tf.experimental.numpy.moveaxis(a, source, destination) diff --git a/ivy/functional/backends/torch/extensions.py b/ivy/functional/backends/torch/extensions.py index c00bd49c2fb14..d944cd89017f2 100644 --- a/ivy/functional/backends/torch/extensions.py +++ b/ivy/functional/backends/torch/extensions.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Sequence import ivy from ivy.functional.ivy.extensions import ( _verify_coo_components, @@ -183,3 +183,17 @@ def max_pool2d( max_pool2d.unsupported_dtypes = ("bfloat16", "float16") + + +def moveaxis( + a: torch.Tensor, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.moveaxis(a, source, destination) + + +moveaxis.support_native_out = False diff --git a/ivy/functional/ivy/extensions.py b/ivy/functional/ivy/extensions.py index 2213333182229..f029903f6346a 100644 --- a/ivy/functional/ivy/extensions.py +++ b/ivy/functional/ivy/extensions.py @@ -1,4 +1,4 @@ -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Sequence import ivy from ivy.func_wrapper import ( handle_out_argument, @@ -783,3 +783,48 @@ def max_pool2d( """ return ivy.current_backend(x).max_pool2d(x, kernel, strides, padding, out=out) + + +@to_native_arrays_and_back +@handle_out_argument +# @handle_nestable +def moveaxis( + a: Union[ivy.Array, ivy.NativeArray], + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[Union[ivy.Array, ivy.NativeArray]] = None, +) -> Union[ivy.Array, ivy.NativeArray]: + """Move axes of an array to new positions.. + + Parameters + ---------- + a + The array whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. + These must also be unique. + out + optional output array, for writing the result to. + + Returns + ------- + ret + Array with moved axes. This array is a view of the input array. + + Examples + -------- + With :class:`ivy.Array` input: + + >>> x = ivy.zeros((3, 4, 5)) + >>> ivy.moveaxis(x, 0, -1).shape + (4, 5, 3) + >>> ivy.moveaxis(x, -1, 0).shape + (5, 3, 4) + """ + return ivy.current_backend().moveaxis( + a, source, destination, out=out + ) diff --git a/ivy_tests/test_ivy/test_functional/test_extensions.py b/ivy_tests/test_ivy/test_functional/test_extensions.py index a313629a314f1..e66f19b3e6e35 100644 --- a/ivy_tests/test_ivy/test_functional/test_extensions.py +++ b/ivy_tests/test_ivy/test_functional/test_extensions.py @@ -350,3 +350,89 @@ def test_max_pool2d( strides=stride, padding=pad ) + + +# moveaxis +@st.composite +def _array_dual_axes(draw): + shape = draw( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + ), + dtype_and_a = draw( + helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-10, + max_value=10, + shape=st.shared( + shape, + key="shared_axes" + ) + ), + ) + source = draw( + st.lists( + helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + shape, + key="shared_axes" + ), + min_size=1, + force_tuple=True, + ), + ) + ) + destination = draw( + st.lists( + helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + shape, + key="shared_axes" + ), + min_size=1, + force_tuple=True, + ), + ) + ) + return dtype_and_a, source, destination + + +@handle_cmd_line_args +@given( + dtype_a_s_d=_array_dual_axes(), + num_positional_args=helpers.num_positional_args(fn_name="moveaxis"), +) +def test_moveaxis( + dtype_a_s_d, + as_variable, + with_out, + num_positional_args, + native_array, + container, + instance_method, + fw, +): + dtype_and_a, source, destination = dtype_a_s_d + input_dtype, a = dtype_and_a + helpers.test_function( + input_dtypes=input_dtype, + as_variable_flags=as_variable, + with_out=with_out, + num_positional_args=num_positional_args, + native_array_flags=native_array, + container_flags=container, + instance_method=instance_method, + fw=fw, + fn_name="moveaxis", + a=np.asarray(a[0], dtype=input_dtype[0]), + source=source[0], + destination=destination[0], + ) From ac812b15bdb41edb70944bf6dfe74b7ac71f6d21 Mon Sep 17 00:00:00 2001 From: nassimberrada Date: Tue, 11 Oct 2022 20:13:52 +0100 Subject: [PATCH 2/7] changes --- ivy_tests/test_ivy/test_functional/test_extensions.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy_tests/test_ivy/test_functional/test_extensions.py b/ivy_tests/test_ivy/test_functional/test_extensions.py index e66f19b3e6e35..0f8e9a8649d4a 100644 --- a/ivy_tests/test_ivy/test_functional/test_extensions.py +++ b/ivy_tests/test_ivy/test_functional/test_extensions.py @@ -384,7 +384,7 @@ def _array_dual_axes(draw): key="shared_axes" ), min_size=1, - force_tuple=True, + force_int=True, ), ) ) @@ -398,7 +398,7 @@ def _array_dual_axes(draw): key="shared_axes" ), min_size=1, - force_tuple=True, + force_int=True, ), ) ) @@ -433,6 +433,6 @@ def test_moveaxis( fw=fw, fn_name="moveaxis", a=np.asarray(a[0], dtype=input_dtype[0]), - source=source[0], - destination=destination[0], + source=source, + destination=destination, ) From bc2b190138e805f1da6fb8fa7673c503bfb38f1a Mon Sep 17 00:00:00 2001 From: nassimberrada Date: Wed, 12 Oct 2022 12:29:04 +0100 Subject: [PATCH 3/7] added container and array methods --- ivy/array/extensions.py | 45 +++++++- ivy/container/extensions.py | 108 +++++++++++++++++- ivy/functional/ivy/extensions.py | 2 +- .../test_functional/test_extensions.py | 100 ++++++++-------- 4 files changed, 198 insertions(+), 57 deletions(-) diff --git a/ivy/array/extensions.py b/ivy/array/extensions.py index cbaeb142dd2d3..1727c659315b4 100644 --- a/ivy/array/extensions.py +++ b/ivy/array/extensions.py @@ -1,6 +1,6 @@ # global import abc -from typing import Optional, Union, Tuple +from typing import Optional, Union, Tuple, Sequence # local import ivy @@ -230,3 +230,46 @@ def max_pool2d( data_format=data_format, out=out, ) + + def moveaxis( + self: ivy.Array, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ivy.Array instance method variant of ivy.moveaxis. This method simply + wraps the function, and so the docstring for ivy.unstack also applies to + this method with minimal changes. + + Parameters + ---------- + a + The array whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. + These must also be unique. + out + optional output array, for writing the result to. + + Returns + ------- + ret + Array with moved axes. This array is a view of the input array. + + Examples + -------- + >>> x = ivy.zeros((3, 4, 5)) + >>> print(x.moveaxis(0, -1).shape) + (4, 5, 3) + >>> print(x.moveaxis(-1, 0).shape) + (5, 3, 4) + """ + return ivy.flatten( + self._data, + source, + destination, + out=out) diff --git a/ivy/container/extensions.py b/ivy/container/extensions.py index 2fcfaac4042f7..8d8aeb405d69b 100644 --- a/ivy/container/extensions.py +++ b/ivy/container/extensions.py @@ -1,5 +1,5 @@ # global -from typing import Optional, Union, List, Dict, Tuple +from typing import Optional, Union, List, Dict, Tuple, Sequence # local import ivy @@ -493,3 +493,109 @@ def max_pool2d( map_sequences=map_sequences, out=out, ) + + @staticmethod + def static_moveaxis( + a: Union[ivy.Array, ivy.NativeArray, ivy.Container], + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ + ivy.Container static method variant of ivy.moveaxis. This method simply wraps + the function, and so the docstring for ivy.moveaxis also applies to this method + with minimal changes. + + Parameters + ---------- + a + The container with the arrays whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. + These must also be unique. + out + optional output container, for writing the result to. + + Returns + ------- + ret + Container including arrays with moved axes. + + Examples + -------- + With one :class:`ivy.Container` input: + + >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6))) + >>> ivy.static_moveaxis(x, 0, -1).shape + + { + a: (4, 5, 3) + b: (7, 6, 2) + } + """ + return ContainerBase.multi_map_in_static_method( + "moveaxis", + a, + source, + destination, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def moveaxis( + self: ivy.Container, + source: Union[int, Sequence[int]], + destination: Union[int, Sequence[int]], + /, + *, + out: Optional[ivy.Container] = None, + ) -> ivy.Container: + """ivy.Container instance method variant of ivy.moveaxis. This method simply + wraps the function, and so the docstring for ivy.flatten also applies to this + method with minimal changes. + + Parameters + ---------- + self + The container with the arrays whose axes should be reordered. + source + Original positions of the axes to move. These must be unique. + destination + Destination positions for each of the original axes. + These must also be unique. + out + optional output container, for writing the result to. + + Returns + ------- + ret + Container including arrays with moved axes. + + Examples + -------- + With one :class:`ivy.Container` input: + + >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6))) + >>> ivy.moveaxis(x, 0, -1).shape + + { + a: (4, 5, 3) + b: (7, 6, 2) + } + """ + return self.static_moveaxis( + self, + source, + destination, + out=out) diff --git a/ivy/functional/ivy/extensions.py b/ivy/functional/ivy/extensions.py index f029903f6346a..3d953e07a4dcc 100644 --- a/ivy/functional/ivy/extensions.py +++ b/ivy/functional/ivy/extensions.py @@ -787,7 +787,7 @@ def max_pool2d( @to_native_arrays_and_back @handle_out_argument -# @handle_nestable +@handle_nestable def moveaxis( a: Union[ivy.Array, ivy.NativeArray], source: Union[int, Sequence[int]], diff --git a/ivy_tests/test_ivy/test_functional/test_extensions.py b/ivy_tests/test_ivy/test_functional/test_extensions.py index 0f8e9a8649d4a..9501772bc0173 100644 --- a/ivy_tests/test_ivy/test_functional/test_extensions.py +++ b/ivy_tests/test_ivy/test_functional/test_extensions.py @@ -352,66 +352,59 @@ def test_max_pool2d( ) -# moveaxis -@st.composite -def _array_dual_axes(draw): - shape = draw( - helpers.get_shape( - min_num_dims=1, - max_num_dims=3, - min_dim_size=1, - max_dim_size=3, +# moveaxis +@handle_cmd_line_args +@given( + dtype_and_a=helpers.dtype_and_values( + available_dtypes=helpers.get_dtypes("float"), + min_value=-100, + max_value=100, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, + ), + key="a_s_d" ), ), - dtype_and_a = draw( - helpers.dtype_and_values( - available_dtypes=helpers.get_dtypes("float"), - min_value=-10, - max_value=10, - shape=st.shared( - shape, - key="shared_axes" - ) - ), - ) - source = draw( - st.lists( - helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - shape, - key="shared_axes" - ), - min_size=1, - force_int=True, + source=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), - ) - ) - destination = draw( - st.lists( - helpers.get_axis( - allow_none=False, - unique=True, - shape=st.shared( - shape, - key="shared_axes" - ), - min_size=1, - force_int=True, + key="a_s_d" + ), + min_size=1, + force_int=True, + ), + destination=helpers.get_axis( + allow_none=False, + unique=True, + shape=st.shared( + helpers.get_shape( + min_num_dims=1, + max_num_dims=3, + min_dim_size=1, + max_dim_size=3, ), - ) - ) - return dtype_and_a, source, destination - - -@handle_cmd_line_args -@given( - dtype_a_s_d=_array_dual_axes(), + key="a_s_d" + ), + min_size=1, + force_int=True, + ), num_positional_args=helpers.num_positional_args(fn_name="moveaxis"), ) def test_moveaxis( - dtype_a_s_d, + dtype_and_a, + source, + destination, as_variable, with_out, num_positional_args, @@ -420,7 +413,6 @@ def test_moveaxis( instance_method, fw, ): - dtype_and_a, source, destination = dtype_a_s_d input_dtype, a = dtype_and_a helpers.test_function( input_dtypes=input_dtype, From 091bd125ede3e743838c417229dfff1116fef198 Mon Sep 17 00:00:00 2001 From: nassimberrada Date: Wed, 12 Oct 2022 13:10:15 +0100 Subject: [PATCH 4/7] changes --- ivy/array/extensions.py | 4 ++-- ivy/container/extensions.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/ivy/array/extensions.py b/ivy/array/extensions.py index 1727c659315b4..2432fc06f9af9 100644 --- a/ivy/array/extensions.py +++ b/ivy/array/extensions.py @@ -263,9 +263,9 @@ def moveaxis( Examples -------- >>> x = ivy.zeros((3, 4, 5)) - >>> print(x.moveaxis(0, -1).shape) + >>> ivy.moveaxis(0, -1).shape (4, 5, 3) - >>> print(x.moveaxis(-1, 0).shape) + >>> ivy.moveaxis(-1, 0).shape (5, 3, 4) """ return ivy.flatten( diff --git a/ivy/container/extensions.py b/ivy/container/extensions.py index 8d8aeb405d69b..6456b6870c596 100644 --- a/ivy/container/extensions.py +++ b/ivy/container/extensions.py @@ -534,7 +534,7 @@ def static_moveaxis( With one :class:`ivy.Container` input: >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6))) - >>> ivy.static_moveaxis(x, 0, -1).shape + >>> ivy.Container.static_moveaxis(x, 0, -1).shape { a: (4, 5, 3) @@ -587,7 +587,7 @@ def moveaxis( With one :class:`ivy.Container` input: >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6))) - >>> ivy.moveaxis(x, 0, -1).shape + >>> x.moveaxis(, 0, -1).shape { a: (4, 5, 3) From 19f3fae1f971416c45c93566b6ff03b023e5dc0f Mon Sep 17 00:00:00 2001 From: nassimberrada Date: Wed, 12 Oct 2022 13:49:02 +0100 Subject: [PATCH 5/7] changed flatten to moveaxis --- ivy/array/extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/array/extensions.py b/ivy/array/extensions.py index 2432fc06f9af9..005ca376e77de 100644 --- a/ivy/array/extensions.py +++ b/ivy/array/extensions.py @@ -268,7 +268,7 @@ def moveaxis( >>> ivy.moveaxis(-1, 0).shape (5, 3, 4) """ - return ivy.flatten( + return ivy.moveaxis( self._data, source, destination, From 4f9789fdc71902c3020b1ad7a0b4ffc4ed0d6c83 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Wed, 12 Oct 2022 14:31:21 +0100 Subject: [PATCH 6/7] Update extensions.py --- ivy/container/extensions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ivy/container/extensions.py b/ivy/container/extensions.py index 6456b6870c596..1a6d9aee4bbd7 100644 --- a/ivy/container/extensions.py +++ b/ivy/container/extensions.py @@ -587,7 +587,7 @@ def moveaxis( With one :class:`ivy.Container` input: >>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6))) - >>> x.moveaxis(, 0, -1).shape + >>> x.moveaxis(0, -1).shape { a: (4, 5, 3) From b5a885ad7173ae06ae7b98cc9ec3ce03a5cfc002 Mon Sep 17 00:00:00 2001 From: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com> Date: Wed, 12 Oct 2022 14:32:01 +0100 Subject: [PATCH 7/7] Update extensions.py --- ivy/array/extensions.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/ivy/array/extensions.py b/ivy/array/extensions.py index 005ca376e77de..76c379680574d 100644 --- a/ivy/array/extensions.py +++ b/ivy/array/extensions.py @@ -263,9 +263,9 @@ def moveaxis( Examples -------- >>> x = ivy.zeros((3, 4, 5)) - >>> ivy.moveaxis(0, -1).shape + >>> x.moveaxis(0, -1).shape (4, 5, 3) - >>> ivy.moveaxis(-1, 0).shape + >>> x.moveaxis(-1, 0).shape (5, 3, 4) """ return ivy.moveaxis(