Skip to content

Commit

Permalink
moveaxis extension (#5609)
Browse files Browse the repository at this point in the history
* moveaxis_extension

* changes

* added container and array methods

* changes

* changed flatten to moveaxis

* Update extensions.py

* Update extensions.py

Co-authored-by: nassimberrada <Nassim>
Co-authored-by: Ishtiaq Hussain <53497039+Ishticode@users.noreply.github.com>
  • Loading branch information
nassimberrada and Ishticode authored Oct 12, 2022
1 parent f9dbd12 commit dadff1d
Show file tree
Hide file tree
Showing 8 changed files with 329 additions and 7 deletions.
45 changes: 44 additions & 1 deletion ivy/array/extensions.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence

# local
import ivy
Expand Down Expand Up @@ -216,3 +216,46 @@ def max_pool2d(
data_format=data_format,
out=out,
)

def moveaxis(
self: ivy.Array,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[ivy.Array] = None,
) -> ivy.Array:
"""ivy.Array instance method variant of ivy.moveaxis. This method simply
wraps the function, and so the docstring for ivy.unstack also applies to
this method with minimal changes.
Parameters
----------
a
The array whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output array, for writing the result to.
Returns
-------
ret
Array with moved axes. This array is a view of the input array.
Examples
--------
>>> x = ivy.zeros((3, 4, 5))
>>> x.moveaxis(0, -1).shape
(4, 5, 3)
>>> x.moveaxis(-1, 0).shape
(5, 3, 4)
"""
return ivy.moveaxis(
self._data,
source,
destination,
out=out)
108 changes: 107 additions & 1 deletion ivy/container/extensions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
# global
from typing import Optional, Union, List, Dict, Tuple
from typing import Optional, Union, List, Dict, Tuple, Sequence

# local
import ivy
Expand Down Expand Up @@ -679,3 +679,109 @@ def kaiser_window(
dtype=dtype,
out=out,
)

@staticmethod
def static_moveaxis(
a: Union[ivy.Array, ivy.NativeArray, ivy.Container],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
key_chains: Optional[Union[List[str], Dict[str, str]]] = None,
to_apply: bool = True,
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.moveaxis. This method simply wraps
the function, and so the docstring for ivy.moveaxis also applies to this method
with minimal changes.
Parameters
----------
a
The container with the arrays whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output container, for writing the result to.
Returns
-------
ret
Container including arrays with moved axes.
Examples
--------
With one :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6)))
>>> ivy.Container.static_moveaxis(x, 0, -1).shape
{
a: (4, 5, 3)
b: (7, 6, 2)
}
"""
return ContainerBase.multi_map_in_static_method(
"moveaxis",
a,
source,
destination,
key_chains=key_chains,
to_apply=to_apply,
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
)

def moveaxis(
self: ivy.Container,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[ivy.Container] = None,
) -> ivy.Container:
"""ivy.Container instance method variant of ivy.moveaxis. This method simply
wraps the function, and so the docstring for ivy.flatten also applies to this
method with minimal changes.
Parameters
----------
self
The container with the arrays whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output container, for writing the result to.
Returns
-------
ret
Container including arrays with moved axes.
Examples
--------
With one :class:`ivy.Container` input:
>>> x = ivy.Container(a=ivy.zeros((3, 4, 5)), b=ivy.zeros((2,7,6)))
>>> x.moveaxis(, 0, -1).shape
{
a: (4, 5, 3)
b: (7, 6, 2)
}
"""
return self.static_moveaxis(
self,
source,
destination,
out=out)
13 changes: 12 additions & 1 deletion ivy/functional/backends/jax/extensions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import logging
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -171,3 +171,14 @@ def kaiser_window(
return jnp.array(
jnp.kaiser(M=window_length + 1, beta=beta)[:-1],
dtype=dtype)


def moveaxis(
a: JaxArray,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[JaxArray] = None,
) -> JaxArray:
return jnp.moveaxis(a, source, destination)
16 changes: 15 additions & 1 deletion ivy/functional/backends/numpy/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import logging
import ivy
import numpy as np
Expand Down Expand Up @@ -205,3 +205,17 @@ def kaiser_window(


kaiser_window.support_native_out = False


def moveaxis(
a: np.ndarray,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
return np.moveaxis(a, source, destination)


moveaxis.support_native_out = False
13 changes: 12 additions & 1 deletion ivy/functional/backends/tensorflow/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Union, Optional, Tuple
from typing import Union, Optional, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -145,3 +145,14 @@ def kaiser_window(
else:
return tf.signal.kaiser_window(
window_length + 1, beta, dtype=dtype, name=None)[:-1]


def moveaxis(
a: Union[tf.Tensor, tf.Variable],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> Union[tf.Tensor, tf.Variable]:
return tf.experimental.numpy.moveaxis(a, source, destination)
16 changes: 15 additions & 1 deletion ivy/functional/backends/torch/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.functional.ivy.extensions import (
_verify_coo_components,
Expand Down Expand Up @@ -197,3 +197,17 @@ def kaiser_window(
layout=torch.strided,
device=None,
requires_grad=False)


def moveaxis(
a: torch.Tensor,
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[torch.Tensor] = None,
) -> torch.Tensor:
return torch.moveaxis(a, source, destination)


moveaxis.support_native_out = False
47 changes: 46 additions & 1 deletion ivy/functional/ivy/extensions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple
from typing import Optional, Union, Tuple, Sequence
import ivy
from ivy.func_wrapper import (
handle_out_argument,
Expand Down Expand Up @@ -816,3 +816,48 @@ def kaiser_window(
"""
return ivy.current_backend().kaiser_window(
window_length, periodic, beta, dtype=dtype, out=out)


@to_native_arrays_and_back
@handle_out_argument
@handle_nestable
def moveaxis(
a: Union[ivy.Array, ivy.NativeArray],
source: Union[int, Sequence[int]],
destination: Union[int, Sequence[int]],
/,
*,
out: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
) -> Union[ivy.Array, ivy.NativeArray]:
"""Move axes of an array to new positions..
Parameters
----------
a
The array whose axes should be reordered.
source
Original positions of the axes to move. These must be unique.
destination
Destination positions for each of the original axes.
These must also be unique.
out
optional output array, for writing the result to.
Returns
-------
ret
Array with moved axes. This array is a view of the input array.
Examples
--------
With :class:`ivy.Array` input:
>>> x = ivy.zeros((3, 4, 5))
>>> ivy.moveaxis(x, 0, -1).shape
(4, 5, 3)
>>> ivy.moveaxis(x, -1, 0).shape
(5, 3, 4)
"""
return ivy.current_backend().moveaxis(
a, source, destination, out=out
)
78 changes: 78 additions & 0 deletions ivy_tests/test_ivy/test_functional/test_extensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,81 @@ def test_kaiser_window(
beta=beta,
dtype=dtype
)


# moveaxis
@handle_cmd_line_args
@given(
dtype_and_a=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("float"),
min_value=-100,
max_value=100,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
),
source=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
min_size=1,
force_int=True,
),
destination=helpers.get_axis(
allow_none=False,
unique=True,
shape=st.shared(
helpers.get_shape(
min_num_dims=1,
max_num_dims=3,
min_dim_size=1,
max_dim_size=3,
),
key="a_s_d"
),
min_size=1,
force_int=True,
),
num_positional_args=helpers.num_positional_args(fn_name="moveaxis"),
)
def test_moveaxis(
dtype_and_a,
source,
destination,
as_variable,
with_out,
num_positional_args,
native_array,
container,
instance_method,
fw,
):
input_dtype, a = dtype_and_a
helpers.test_function(
input_dtypes=input_dtype,
as_variable_flags=as_variable,
with_out=with_out,
num_positional_args=num_positional_args,
native_array_flags=native_array,
container_flags=container,
instance_method=instance_method,
fw=fw,
fn_name="moveaxis",
a=np.asarray(a[0], dtype=input_dtype[0]),
source=source,
destination=destination,
)

0 comments on commit dadff1d

Please sign in to comment.