Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

moveaxis extension #5609

Merged
merged 8 commits into from
Oct 12, 2022
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion ivy/array/extensions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence

# local
import ivy
Expand Down Expand Up @@ -230,3 +230,46 @@ def max_pool2d(
data_format=data_format,
out=out,
)

def moveaxis(
self: ivy.Array,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.moveaxis. This method simply
wraps the function, and so the docstring for ivy.unstack also applies to
this method with minimal changes.

Parameters
----------
a
The array whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output array, for writing the result to.

Returns
-------
ret
Array with moved axes. This array is a view of the input array.

Examples
--------
>>> x = ivy.zeros((3, 4, 5))
>>> x.moveaxis(0, -1).shape
(4, 5, 3)
>>> x.moveaxis(-1, 0).shape
(5, 3, 4)
"""
return ivy.moveaxis(
self._data,
source,
destination,
out=out)
108 changes: 107 additions & 1 deletion ivy/container/extensions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union, List, Dict, Tuple
from typing import Optional, Union, List, Dict, Tuple, Sequence

# local
import ivy
Expand Down Expand Up @@ -493,3 +493,109 @@ def max_pool2d(
map_sequences=map_sequences,
out=out,
)

@staticmethod
def static_moveaxis(
a: Union[ivy.Array, ivy.NativeArray, ivy.Container],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.moveaxis. This method simply wraps
the function, and so the docstring for ivy.moveaxis also applies to this method
with minimal changes.

Parameters
----------
a
The container with the arrays whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output container, for writing the result to.

Returns
-------
ret
Container including arrays with moved axes.

Examples
--------
With one :class:`ivy.Container` input:

>>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6)))
>>> ivy.Container.static_moveaxis(x, 0, -1).shape

{
a: (4, 5, 3)
b: (7, 6, 2)
}
"""
return ContainerBase.multi_map_in_static_method(
"moveaxis",
a,
source,
destination,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def moveaxis(
self: ivy.Container,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.moveaxis. This method simply
wraps the function, and so the docstring for ivy.flatten also applies to this
method with minimal changes.

Parameters
----------
self
The container with the arrays whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output container, for writing the result to.

Returns
-------
ret
Container including arrays with moved axes.

Examples
--------
With one :class:`ivy.Container` input:

>>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6)))
>>> x.moveaxis(0, -1).shape

{
a: (4, 5, 3)
b: (7, 6, 2)
}
"""
return self.static_moveaxis(
self,
source,
destination,
out=out)
13 changes: 12 additions & 1 deletion ivy/functional/backends/jax/extensions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -156,3 +156,14 @@ def max_pool2d(
return jnp.transpose(res, (0, 3, 1, 2))

return res


def moveaxis(
a: JaxArray,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.moveaxis(a, source, destination)
16 changes: 15 additions & 1 deletion ivy/functional/backends/numpy/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import logging
import ivy
import numpy as np
Expand Down Expand Up @@ -176,3 +176,17 @@ def max_pool2d(
if data_format == "NCHW":
return np.transpose(res, (0, 3, 1, 2))
return res


def moveaxis(
a: np.ndarray,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return np.moveaxis(a, source, destination)


moveaxis.support_native_out = False
13 changes: 12 additions & 1 deletion ivy/functional/backends/tensorflow/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, Tuple
from typing import Union, Optional, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -138,3 +138,14 @@ def max_pool2d(
if data_format == "NCHW":
return tf.transpose(res, (0, 3, 1, 2))
return res


def moveaxis(
a: Union[tf.Tensor, tf.Variable],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.experimental.numpy.moveaxis(a, source, destination)
16 changes: 15 additions & 1 deletion ivy/functional/backends/torch/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -183,3 +183,17 @@ def max_pool2d(


max_pool2d.unsupported_dtypes = ("bfloat16", "float16")


def moveaxis(
a: torch.Tensor,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.moveaxis(a, source, destination)


moveaxis.support_native_out = False
47 changes: 46 additions & 1 deletion ivy/functional/ivy/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.func_wrapper import (
handle_out_argument,
Expand Down Expand Up @@ -783,3 +783,48 @@ def max_pool2d(
"""

return ivy.current_backend(x).max_pool2d(x, kernel, strides, padding, out=out)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
def moveaxis(
a: Union[ivy.Array, ivy.NativeArray],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
) -> Union[ivy.Array, ivy.NativeArray]:
"""Move axes of an array to new positions..

Parameters
----------
a
The array whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output array, for writing the result to.

Returns
-------
ret
Array with moved axes. This array is a view of the input array.

Examples
--------
With :class:`ivy.Array` input:

>>> x = ivy.zeros((3, 4, 5))
>>> ivy.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> ivy.moveaxis(x, -1, 0).shape
(5, 3, 4)
"""
return ivy.current_backend().moveaxis(
a, source, destination, out=out
)
78 changes: 78 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,3 +350,81 @@ def test_max_pool2d(
strides=stride,
padding=pad
)


# moveaxis
@handle_cmd_line_args
@given(
dtype_and_a=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-100,
max_value=100,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
),
source=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
min_size=1,
force_int=True,
),
destination=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
min_size=1,
force_int=True,
),
num_positional_args=helpers.num_positional_args(fn_name="moveaxis"),
)
def test_moveaxis(
dtype_and_a,
source,
destination,
as_variable,
with_out,
num_positional_args,
native_array,
container,
instance_method,
fw,
):
input_dtype, a = dtype_and_a
helpers.test_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
native_array_flags=native_array,
container_flags=container,
instance_method=instance_method,
fw=fw,
fn_name="moveaxis",
a=np.asarray(a[0], dtype=input_dtype[0]),
source=source,
destination=destination,
)