Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ivy pad backend consistency #5710

Merged
merged 15 commits into from
Oct 22, 2022
6 changes: 4 additions & 2 deletions ivy/array/extensions/layers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# global
import abc
from typing import Optional, Union, Tuple, Iterable, Callable, Literal
from typing import Optional, Union, Tuple, Iterable, Callable, Literal, Any
from numbers import Number

# local
Expand Down Expand Up @@ -157,8 +157,8 @@ def max_pool2d(

def pad(
self: ivy.Array,
/,
pad_width: Union[Iterable[Tuple[int]], int],
/,
*,
mode: Optional[
Union[
Expand All @@ -183,6 +183,7 @@ def pad(
end_values: Optional[Union[Iterable[Tuple[Number]], Number]] = 0,
reflect_type: Optional[Literal["even", "odd"]] = "even",
out: Optional[ivy.Array] = None,
**kwargs: Optional[Any],
) -> ivy.Array:
"""
ivy.Array instance method variant of ivy.pad. This method simply
Expand All @@ -198,4 +199,5 @@ def pad(
end_values=end_values,
reflect_type=reflect_type,
out=out,
**kwargs,
)
13 changes: 9 additions & 4 deletions ivy/container/extensions/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Callable,
Literal,
Iterable,
Any,
)
from numbers import Number

Expand Down Expand Up @@ -460,10 +461,10 @@ def kaiser_window(

@staticmethod
def static_pad(
x: ivy.Container,
input: ivy.Container,
pad_width: Union[Iterable[Tuple[int]], int],
/,
*,
pad_width: Union[Iterable[Tuple[int]], int],
mode: Optional[
Union[
Literal[
Expand Down Expand Up @@ -491,6 +492,7 @@ def static_pad(
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
**kwargs: Optional[Any],
) -> ivy.Container:
"""
ivy.Container static method variant of ivy.pad. This method simply
Expand All @@ -499,7 +501,7 @@ def static_pad(
"""
return ContainerBase.multi_map_in_static_method(
"pad",
x,
input,
pad_width,
mode=mode,
stat_length=stat_length,
Expand All @@ -511,13 +513,14 @@ def static_pad(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
**kwargs,
)

def pad(
self: ivy.Container,
pad_width: Union[Iterable[Tuple[int]], int],
/,
*,
pad_width: Union[Iterable[Tuple[int]], int],
mode: Optional[
Union[
Literal[
Expand Down Expand Up @@ -545,6 +548,7 @@ def pad(
prune_unapplied: bool = False,
map_sequences: bool = False,
out: Optional[ivy.Container] = None,
**kwargs: Optional[Any],
) -> ivy.Container:
"""
ivy.Container instance method variant of ivy.pad. This method simply
Expand All @@ -564,4 +568,5 @@ def pad(
prune_unapplied=prune_unapplied,
map_sequences=map_sequences,
out=out,
**kwargs,
)
52 changes: 30 additions & 22 deletions ivy/functional/backends/jax/extensions/layers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Optional, Union, Tuple, Callable, Literal, Sequence
from typing import Optional, Union, Tuple, Callable, Literal, Sequence, Any
from numbers import Number
import ivy
from ivy.functional.backends.jax import JaxArray
Expand Down Expand Up @@ -98,14 +98,28 @@ def max_pool2d(
return res


def kaiser_window(
window_length: int,
periodic: bool = True,
beta: float = 12.0,
*,
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if periodic is False:
return jnp.array(jnp.kaiser(M=window_length, beta=beta), dtype=dtype)
else:
return jnp.array(jnp.kaiser(M=window_length + 1, beta=beta)[:-1], dtype=dtype)


def _flat_array_to_1_dim_array(x):
return x.reshape((1,)) if x.shape == () else x


def pad(
x: JaxArray,
/,
input: JaxArray,
pad_width: Union[Sequence[Sequence[int]], JaxArray, int],
/,
*,
mode: Optional[
Union[
Expand All @@ -130,52 +144,46 @@ def pad(
end_values: Optional[Union[Sequence[Sequence[Number]], Number]] = 0,
reflect_type: Optional[Literal["even", "odd"]] = "even",
out: Optional[JaxArray] = None,
**kwargs: Optional[Any],
) -> JaxArray:
if callable(mode):
return jnp.pad(
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
**kwargs,
)
if mode in ["maximum", "mean", "median", "minimum"]:
return jnp.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
stat_length=stat_length,
)
elif mode == "constant":
return jnp.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
constant_values=constant_values,
)
elif mode == "linear_ramp":
return jnp.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
end_values=end_values,
)
elif mode in ["reflect", "symmetric"]:
return jnp.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
reflect_type=reflect_type,
)
else:
return jnp.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
)


def kaiser_window(
window_length: int,
periodic: bool = True,
beta: float = 12.0,
*,
dtype: Optional[jnp.dtype] = None,
out: Optional[JaxArray] = None,
) -> JaxArray:
if periodic is False:
return jnp.array(jnp.kaiser(M=window_length, beta=beta), dtype=dtype)
else:
return jnp.array(jnp.kaiser(M=window_length + 1, beta=beta)[:-1], dtype=dtype)
58 changes: 33 additions & 25 deletions ivy/functional/backends/numpy/extensions/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import math
from numbers import Number
import numpy as np
from typing import Optional, Union, Tuple, Sequence, Callable, Literal
from typing import Optional, Union, Tuple, Sequence, Callable, Literal, Any

# local
import ivy
Expand Down Expand Up @@ -107,14 +107,31 @@ def max_pool2d(
return res


def kaiser_window(
window_length: int,
periodic: bool = True,
beta: float = 12.0,
*,
dtype: Optional[np.dtype] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if periodic is False:
return np.array(np.kaiser(M=window_length, beta=beta), dtype=dtype)
else:
return np.array(np.kaiser(M=window_length + 1, beta=beta)[:-1], dtype=dtype)


kaiser_window.support_native_out = False


def _flat_array_to_1_dim_array(x):
return x.reshape((1,)) if x.shape == () else x


def pad(
x: np.ndarray,
/,
input: np.ndarray,
pad_width: Union[Sequence[Sequence[int]], np.ndarray, int],
/,
*,
mode: Optional[
Union[
Expand All @@ -139,55 +156,46 @@ def pad(
end_values: Optional[Union[Sequence[Sequence[Number]], Number]] = 0,
reflect_type: Optional[Literal["even", "odd"]] = "even",
out: Optional[np.ndarray] = None,
**kwargs: Optional[Any],
) -> np.ndarray:
if callable(mode):
return np.pad(
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
**kwargs,
)
if mode in ["maximum", "mean", "median", "minimum"]:
return np.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
stat_length=stat_length,
)
elif mode == "constant":
return np.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
constant_values=constant_values,
)
elif mode == "linear_ramp":
return np.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
end_values=end_values,
)
elif mode in ["reflect", "symmetric"]:
return np.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
reflect_type=reflect_type,
)
else:
return np.pad(
_flat_array_to_1_dim_array(x),
_flat_array_to_1_dim_array(input),
pad_width,
mode=mode,
)


def kaiser_window(
window_length: int,
periodic: bool = True,
beta: float = 12.0,
*,
dtype: Optional[np.dtype] = None,
out: Optional[np.ndarray] = None,
) -> np.ndarray:
if periodic is False:
return np.array(np.kaiser(M=window_length, beta=beta), dtype=dtype)
else:
return np.array(np.kaiser(M=window_length + 1, beta=beta)[:-1], dtype=dtype)


kaiser_window.support_native_out = False
28 changes: 1 addition & 27 deletions ivy/functional/backends/tensorflow/extensions/layers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import Union, Optional, Tuple, Literal
from numbers import Number
from typing import Union, Optional, Tuple
import tensorflow as tf


Expand Down Expand Up @@ -42,31 +41,6 @@ def max_pool2d(
return res


def pad(
x: tf.Tensor,
/,
pad_width: tf.Tensor,
*,
mode: Optional[Literal["constant", "reflect", "symmetric"]] = "constant",
stat_length: Optional[Union[tf.Tensor, int]] = None,
constant_values: Optional[Number] = 0,
end_values: Optional[Number] = 0,
reflect_type: Optional[Literal["even", "odd"]] = "even",
out: Optional[Union[tf.Tensor, tf.Variable]] = None,
) -> tf.Tensor:
if x.shape == ():
x = tf.reshape(x, (-1,))
if mode == "constant":
return tf.pad(
x,
pad_width,
mode=mode,
constant_values=constant_values,
)
else:
return tf.pad(x, pad_width, mode=mode)


def kaiser_window(
window_length: int,
periodic: bool = True,
Expand Down
Loading