diff --git a/ivy/data_classes/array/experimental/layers.py b/ivy/data_classes/array/experimental/layers.py index 19970aeac64ed..f9092b71cd227 100644 --- a/ivy/data_classes/array/experimental/layers.py +++ b/ivy/data_classes/array/experimental/layers.py @@ -957,3 +957,70 @@ def fft2( 0. +0.j , 0. +0.j ]]) """ return ivy.fft2(self._data, s=s, dim=dim, norm=norm, out=out) + + def ifftn( + self: ivy.Array, + s: Optional[Union[int, Tuple[int, ...]]] = None, + axes: Optional[Union[int, Tuple[int, ...]]] = None, + *, + norm: str = "backward", + out: Optional[ivy.Array] = None, + ) -> ivy.Array: + """ + Compute the N-dimensional inverse discrete Fourier Transform. + + Parameters + ---------- + x + Input array of complex numbers. + + s + sequence of ints, optional + Shape (length of transformed axis) of the output (`s[0]` refers to axis 0, + `s[1]` to axis 1, etc.). If given shape is smaller than that of the input, + the input is cropped. If larger, input is padded with zeros. If `s` is not + given, shape of input along axes specified by axes is used. + axes + axes over which to compute the IFFT. If not given, last `len(s)` axes are + used, or all axes if `s` is also not specified. Repeated indices in axes + means inverse transform over that axis is performed multiple times. + norm + Optional argument, "backward", "ortho" or "forward". Defaults to be + "backward". + "backward" indicates no normalization. + "ortho" indicates normalization by 1/sqrt(n). + "forward" indicates normalization by 1/n. + out + Optional output array, for writing the result to. It must have a shape that + the inputs broadcast to. + + Returns + ------- + ret + The truncated or zero-padded input, transformed along the axes indicated + by axes, or by a combination of s or x, as explained in the parameters + section above. + + Examples + -------- + >>> x = ivy.array([[0.24730653+0.90832391j, 0.49495562+0.9039565j, + 0.98193269+0.49560517j], + [0.93280757+0.48075343j, 0.28526384+0.3351205j, + 0.2343787 +0.83528011j], + [0.18791352+0.30690572j, 0.82115787+0.96195183j, + 0.44719226+0.72654048j]]) + >>> y = ivy.ifftn(x) + >>> print(y) + ivy.array([[ 0.51476765+0.66160417j, -0.04319742-0.05411636j, + -0.015561 -0.04216015j], + [ 0.06310689+0.05347854j, -0.13392983+0.16052352j, + -0.08371392+0.17252843j], + [-0.0031429 +0.05421245j, -0.10446617-0.17747098j, + 0.05344324+0.07972424j]]) + + >>> b = ivy.ifftn(x, s=[2, 1], axes=[0, 1], norm='ortho') + >>> print(b) + ivy.array([[ 0.8344667 +0.98222595j], + [-0.48472244+0.30233797j]]) + """ + return ivy.ifftn(self._data, s=s, axes=axes, norm=norm, out=out) diff --git a/ivy/data_classes/container/experimental/layers.py b/ivy/data_classes/container/experimental/layers.py index ce2f4f7867f20..9d6cf8865cea7 100644 --- a/ivy/data_classes/container/experimental/layers.py +++ b/ivy/data_classes/container/experimental/layers.py @@ -1907,3 +1907,119 @@ def adaptive_avg_pool2d( prune_unapplied=prune_unapplied, map_sequences=map_sequences, ) + + @staticmethod + def static_ifftn( + x: ivy.Container, + s: Optional[Union[int, Tuple[int, ...]]] = None, + axes: Optional[Union[int, Tuple[int, ...]]] = None, + *, + norm: str = "backward", + key_chains: Optional[Union[List[str], Dict[str, str]]] = None, + to_apply: bool = True, + prune_unapplied: bool = False, + map_sequences: bool = False, + out: Optional[ivy.Container] = None, + ): + """ + ivy.Container static method variant of ivy.ifftn. + + This method simply wraps the function, and so the docstring for + ivy.ifftn also applies to this method with minimal changes. + + Parameters + ---------- + x + Input array of complex numbers. + + s + sequence of ints, optional + Shape (length of transformed axis) of the output (`s[0]` refers to axis 0, + `s[1]` to axis 1, etc.). If given shape is smaller than that of the input, + the input is cropped. If larger, input is padded with zeros. If `s` is not + given, shape of input along axes specified by axes is used. + axes + axes over which to compute the IFFT. If not given, last `len(s)` axes are + used, or all axes if `s` is also not specified. Repeated indices in axes + means inverse transform over that axis is performed multiple times. + norm + Optional argument, "backward", "ortho" or "forward". + Defaults to be "backward". + "backward" indicates no normalization. + "ortho" indicates normalization by 1/sqrt(n). + "forward" indicates normalization by 1/n. + out + Optional output array, for writing the result to. It must have a shape that + the inputs broadcast to. + + Returns + ------- + ret + The truncated or zero-padded input, transformed along the axes indicated + by axes, or by a combination of s or x, as explained in the parameters + section above. + """ + return ContainerBase.cont_multi_map_in_function( + "ifftn", + x, + s=s, + axes=axes, + norm=norm, + key_chains=key_chains, + to_apply=to_apply, + prune_unapplied=prune_unapplied, + map_sequences=map_sequences, + out=out, + ) + + def ifftn( + self: ivy.Container, + s: Optional[Union[int, Tuple[int, ...]]] = None, + axes: Optional[Union[int, Tuple[int, ...]]] = None, + *, + norm: str = "backward", + out: Optional[ivy.Array] = None, + ): + """ + ivy.Container static method variant of ivy.ifftn. + + This method simply wraps the function, and so the docstring for + ivy.ifftn also applies to this method with minimal changes. + + Parameters + ---------- + x + Input array of complex numbers. + + s + sequence of ints, optional + Shape (length of transformed axis) of the output (`s[0]` refers to axis 0, + `s[1]` to axis 1, etc.). If given shape is smaller than that of the input, + the input is cropped. If larger, input is padded with zeros. If `s` is not + given, shape of input along axes specified by axes is used. + axes + axes over which to compute the IFFT. If not given, last `len(s)` axes are + used, or all axes if `s` is also not specified. Repeated indices in axes + means inverse transform over that axis is performed multiple times. + norm + Optional argument, "backward", "ortho" or "forward". + Defaults to be "backward". + "backward" indicates no normalization. + "ortho" indicates normalization by 1/sqrt(n). + "forward" indicates normalization by 1/n. + out + Optional output array, for writing the result to. It must have a shape that + the inputs broadcast to. + + Returns + ------- + ret + Container containing the transformed inputs + """ + return self.static_ifftn( + self, + s=s, + axes=axes, + norm=norm, + out=out, + ) diff --git a/ivy/functional/backends/jax/experimental/layers.py b/ivy/functional/backends/jax/experimental/layers.py index 553e0709b115f..913b70a8ddaea 100644 --- a/ivy/functional/backends/jax/experimental/layers.py +++ b/ivy/functional/backends/jax/experimental/layers.py @@ -766,6 +766,17 @@ def fft2( return jnp.fft.fft2(x, s, dim, norm).astype(jnp.complex128) + +def ifftn( + x: JaxArray, + s: Optional[Union[int, Tuple[int]]] = None, + axes: Optional[Union[int, Tuple[int]]] = None, + *, + norm: str = "backward", + out: Optional[JaxArray] = None, +) -> JaxArray: + return jnp.fft.ifftn(x, s, axes, norm) + @with_unsupported_dtypes( {"0.4.12 and below": ("bfloat16", "float16", "complex")}, backend_version ) @@ -786,4 +797,4 @@ def embedding( embeddings = jnp.where( norms < -max_norm, embeddings * -max_norm / norms, embeddings ) - return embeddings + diff --git a/ivy/functional/backends/numpy/experimental/layers.py b/ivy/functional/backends/numpy/experimental/layers.py index 49b4bfec92bc9..4fe13d8ad1df3 100644 --- a/ivy/functional/backends/numpy/experimental/layers.py +++ b/ivy/functional/backends/numpy/experimental/layers.py @@ -892,6 +892,16 @@ def fft2( return np.fft.fft2(x, s, dim, norm).astype(np.complex128) +def ifftn( + x: np.ndarray, + s: Optional[Union[int, Tuple[int]]] = None, + axes: Optional[Union[int, Tuple[int]]] = None, + *, + norm: str = "backward", + out: Optional[np.ndarray] = None, +) -> np.ndarray: + return np.fft.ifftn(x, s, axes, norm).astype(x.dtype) + @with_unsupported_dtypes({"1.25.0 and below": ("complex",)}, backend_version) def embedding( weights: np.ndarray, @@ -911,3 +921,4 @@ def embedding( norms < -max_norm, embeddings * -max_norm / norms, embeddings ) return embeddings + diff --git a/ivy/functional/backends/paddle/experimental/layers.py b/ivy/functional/backends/paddle/experimental/layers.py index 4a01d32b09783..86a201428751c 100644 --- a/ivy/functional/backends/paddle/experimental/layers.py +++ b/ivy/functional/backends/paddle/experimental/layers.py @@ -316,3 +316,14 @@ def interpolate( antialias: Optional[bool] = False, ): raise IvyNotImplementedException() + + +def ifftn( + x: paddle.Tensor, + s: Optional[Union[int, Tuple[int]]] = None, + axes: Optional[Union[int, Tuple[int]]] = None, + *, + norm: Optional[str] = "backward", + out: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + return paddle.fft.ifftn(x, s, axes, norm) diff --git a/ivy/functional/backends/tensorflow/experimental/layers.py b/ivy/functional/backends/tensorflow/experimental/layers.py index 911a0c136975a..c98e1095a3a96 100644 --- a/ivy/functional/backends/tensorflow/experimental/layers.py +++ b/ivy/functional/backends/tensorflow/experimental/layers.py @@ -916,3 +916,274 @@ def fft2( # Apply the same normalization as 'backward' in NumPy tf_fft2 = _fft2_norm(tf_fft2, s, dim, norm) return tf_fft2 + + +# --- IFFTN --- # +def fft_input_validation(x): + if not x.dtype.is_complex: + raise TypeError( + "Invalid FFT input: `x` must be of a complex dtype. Received: {}".format( + x.dtype + ) + ) + return x + + +def shape_and_axes_validation(shape, axes, input_rank_tensor): + if shape is not None: + shape = tf.convert_to_tensor(shape, dtype=tf.dtypes.int32) + checks_shape = [ + tf.debugging.assert_less_equal( + tf.size(shape), + input_rank_tensor, + message=( + "Argument `shape` cannot have length greater than the rank of `x`. " + "Received: {}" + ).format(shape), + ) + ] + with tf.control_dependencies(checks_shape): + shape = tf.identity(shape) + + if axes is not None: + axes = tf.convert_to_tensor(axes, dtype=tf.dtypes.int32) + checks_axes = [ + tf.debugging.assert_less_equal( + tf.size(axes), + input_rank_tensor, + message=( + "Argument `axes` cannot have length greater than the rank of `x`. " + "Received: {}" + ).format(axes), + ), + tf.debugging.assert_less( + axes, + input_rank_tensor, + message=( + "Argument `axes` contains invalid indices. Received: {}" + ).format(axes), + ), + tf.debugging.assert_greater_equal( + axes, + -input_rank_tensor, + message=( + "Argument `axes` contains invalid indices. Received: {}" + ).format(axes), + ), + ] + with tf.control_dependencies(checks_axes): + axes = tf.identity(axes) + + if shape is not None and axes is not None: + checks_shape_axes = [ + tf.debugging.assert_equal( + tf.size(shape), + tf.size(axes), + message=( + "Arguments `shape` and `axes` must have equal length. " + "Received: {}, {}" + ).format(shape, axes), + ) + ] + with tf.control_dependencies(checks_shape_axes): + shape, axes = tf.identity_n([shape, axes]) + + return shape, axes + + +def axes_initialization(shape, axes, input_shape, input_rank_tensor): + if axes is None: + axes = ( + tf.range(-tf.size(input_shape), 0) + if shape is None + else tf.range(-tf.size(shape), 0) + ) + axes = tf.where(tf.math.less(axes, 0), axes + input_rank_tensor, axes) + return axes + + +def perform_actions_initialization(shape, axes, input_shape, input_rank_tensor): + perform_padding = shape is not None + perform_transpose = tf.math.logical_not( + tf.math.reduce_all( + tf.math.equal( + axes, tf.range(input_rank_tensor - tf.size(axes), input_rank_tensor) + ) + ) + ) + return perform_padding, perform_transpose + + +def shape_initialization(shape, axes, x): + if shape is None: + shape = tf.gather(tf.shape(x), axes, axis=0) + return shape + + +def rank_initialization(axes): + rank = tf.size(axes) + with tf.control_dependencies( + [ + tf.debugging.assert_less_equal( + rank, 3, message="N-D FFT supported only up to 3-D." + ) + ] + ): + rank = tf.identity(rank) + + return rank + + +def norm_initialization(norm, shape, x): + if norm == "backward": + norm_factor = tf.constant(1, x.dtype) + elif norm == "forward" or norm == "ortho": + norm_factor = tf.cast(tf.math.reduce_prod(shape), x.dtype) + if norm == "ortho": + norm_factor = tf.math.sqrt(norm_factor) + return norm_factor + + +def get_x_after_pad_or_crop(x, shape, axes, perform_padding, input_rank_tensor): + if perform_padding: + pad_shape = -tf.ones([input_rank_tensor], dtype=tf.int32) + pad_shape = tf.tensor_scatter_nd_update( + pad_shape, tf.expand_dims(axes, -1), shape + ) + x = _right_pad_or_crop(x, pad_shape) + return x + + +def get_perm(input_rank_tensor, axes): + all_dims = tf.range(input_rank_tensor, dtype=tf.dtypes.int32) + perm = tf.concat( + [ + tf.boolean_mask( + all_dims, + tf.foldl( + lambda acc, elem: tf.math.logical_and( + acc, tf.math.not_equal(all_dims, elem) + ), + axes, + initializer=tf.fill(all_dims.shape, True), + ), + ), + axes, + ], + 0, + ) + return perm + + +def ifft_operations(x, rank, norm_factor): + if x.shape.rank == 1: + x = tf.signal.ifft(x) + elif x.shape.rank == 2: + x = tf.switch_case( + rank - 1, {0: lambda: tf.signal.ifft(x), 1: lambda: tf.signal.ifft2d(x)} + ) + else: + x = tf.switch_case( + rank - 1, + { + 0: lambda: tf.signal.ifft(x), + 1: lambda: tf.signal.ifft2d(x), + 2: lambda: tf.signal.ifft3d(x), + }, + ) + x = x * norm_factor + return x + + +def transpose_x(x, perm, perform_transpose): + x = tf.cond(perform_transpose, lambda: tf.transpose(x, perm=perm), lambda: x) + return x + + +def static_output_shape(input_shape, shape, axes): + output_shape = input_shape.as_list() + if shape is not None: + if axes is None: + axes = list(range(-len(shape), 0)) + if isinstance(shape, tf.Tensor): + if isinstance(axes, tf.Tensor): + output_shape = [None] * len(output_shape) + else: + for ax in axes: + output_shape[ax] = None + else: + for idx, ax in enumerate(axes): + output_shape[ax] = shape[idx] + return tf.TensorShape(output_shape) + + +def _right_pad_or_crop(tensor, shape): + input_shape = tf.shape(tensor) + shape = tf.convert_to_tensor(shape, dtype=tf.dtypes.int32) + with tf.control_dependencies( + [tf.debugging.assert_less_equal(tf.size(shape), tf.size(input_shape))] + ): + shape = tf.identity(shape) + shape = tf.concat([input_shape[: tf.size(input_shape) - tf.size(shape)], shape], 0) + + pad_sizes = tf.math.maximum(shape - input_shape, 0) + pad_sizes = tf.expand_dims(pad_sizes, -1) + pad_sizes = tf.concat( + [tf.zeros(pad_sizes.shape, dtype=tf.dtypes.int32), pad_sizes], -1 + ) + tensor = tf.pad(tensor, pad_sizes, constant_values=0) + + crop_tensor = tf.zeros(shape.shape, dtype=tf.dtypes.int32) + tensor = tf.slice(tensor, crop_tensor, shape) + return tensor + + +def _ifftn_helper(x, shape, axes, norm): + x = fft_input_validation(tf.convert_to_tensor(x)) + input_shape = x.shape + input_rank_tensor = tf.rank(x) + + shape_, axes_ = shape_and_axes_validation(shape, axes, input_rank_tensor) + + axes = axes_initialization(shape, axes, input_shape, input_rank_tensor) + + perform_padding, perform_transpose = perform_actions_initialization( + shape, axes, input_shape, input_rank_tensor + ) + + shape = shape_initialization(shape, axes, x) + + rank = rank_initialization(axes) + + norm_factor = norm_initialization(norm, shape, x) + + x = get_x_after_pad_or_crop(x, shape, axes, perform_padding, input_rank_tensor) + + perm = get_perm(input_rank_tensor, axes) + + x = transpose_x(x, perm, perform_transpose) + + x = ifft_operations(x, rank, norm_factor) + + x = transpose_x(x, tf.argsort(perm), perform_transpose) + + x = tf.ensure_shape(x, static_output_shape(input_shape, shape_, axes_)) + + return x + + +def ifftn( + x: Union[tf.Tensor, tf.Variable], + s: Optional[Union[int, Tuple[int]]] = None, + axes: Optional[Union[int, Tuple[int]]] = None, + *, + norm: Optional[str] = "backward", + out: Optional[Union[tf.Tensor, tf.Variable]] = None, +) -> Union[tf.Tensor, tf.Variable]: + result = _ifftn_helper(x, s, axes, norm) + + if out is not None: + out = result + return out + else: + return result diff --git a/ivy/functional/backends/torch/experimental/layers.py b/ivy/functional/backends/torch/experimental/layers.py index c2962b2af478c..3403e0f0f52a6 100644 --- a/ivy/functional/backends/torch/experimental/layers.py +++ b/ivy/functional/backends/torch/experimental/layers.py @@ -896,3 +896,14 @@ def fft2( return torch.tensor( torch.fft.fft2(x, s, dim, norm, out=out), dtype=torch.complex128 ) + + +def ifftn( + x: torch.Tensor, + s: Optional[Union[int, Tuple[int]]] = None, + axes: Optional[Union[int, Tuple[int]]] = None, + *, + norm: Optional[str] = "backward", + out: Optional[torch.Tensor] = None, +) -> torch.Tensor: + return torch.fft.ifftn(x, s=s, dim=axes, norm=norm, out=out) diff --git a/ivy/functional/ivy/experimental/layers.py b/ivy/functional/ivy/experimental/layers.py index 91e5dede085bf..561af90c65d17 100644 --- a/ivy/functional/ivy/experimental/layers.py +++ b/ivy/functional/ivy/experimental/layers.py @@ -2399,3 +2399,80 @@ def fft2( 0. +0.j , 0. +0.j ]]) """ return ivy.current_backend(x).fft2(x, s=s, dim=dim, norm=norm, out=out) + + +@handle_exceptions +@handle_nestable +@handle_array_like_without_promotion +@handle_out_argument +@to_native_arrays_and_back +def ifftn( + x: Union[ivy.Array, ivy.NativeArray], + s: Optional[Union[int, Tuple[int, ...]]] = None, + axes: Optional[Union[int, Tuple[int, ...]]] = None, + *, + norm: str = "backward", + out: Optional[ivy.Array] = None, +) -> ivy.Array: + r""" + Compute the N-dimensional inverse discrete Fourier Transform. + + Parameters + ---------- + x + Input array of complex numbers. + s + Shape (length of transformed axis) of the output (`s[0]` refers to axis 0, + `s[1]` to axis 1, etc.). If given shape is smaller than that of the input, + the input is cropped. If larger, input is padded with zeros. If `s` is not + given, shape of input along axes specified by axes is used. + axes + Axes over which to compute the IFFT. If not given, last `len(s)` axes are + used, or all axes if `s` is also not specified. Repeated indices in axes + means inverse transform over that axis is performed multiple times. + norm + Indicates direction of the forward/backward pair of transforms is scaled + and with what normalization factor. "backward" indicates no normalization. + "ortho" indicates normalization by $\frac{1}{\sqrt{n}}$. "forward" + indicates normalization by $\frac{1}{n}$. + out + Optional output array for writing the result to. It must have a shape that + the inputs broadcast to. + + Returns + ------- + out + The truncated or zero-padded input, transformed along the axes indicated + by axes, or by a combination of s or x, as explained in the parameters + section above. + + Raises + ------ + ValueError + If `s` and `axes` have different length. + IndexError + If an element of axes is larger than the number of axes of x. + + Examples + -------- + >>> x = ivy.array([[0.24730653+0.90832391j, 0.49495562+0.9039565j, + 0.98193269+0.49560517j], + [0.93280757+0.48075343j, 0.28526384+0.3351205j, + 0.2343787 +0.83528011j], + [0.18791352+0.30690572j, 0.82115787+0.96195183j, + 0.44719226+0.72654048j]]) + >>> y = ivy.ifftn(x) + >>> print(y) + ivy.array([[ 0.51476765+0.66160417j, -0.04319742-0.05411636j, + -0.015561 -0.04216015j], + [ 0.06310689+0.05347854j, -0.13392983+0.16052352j, + -0.08371392+0.17252843j], + [-0.0031429 +0.05421245j, -0.10446617-0.17747098j, + 0.05344324+0.07972424j]]) + + >>> b = ivy.ifftn(x, s=[2, 1], axes=[0, 1], norm='ortho') + >>> print(b) + ivy.array([[ 0.8344667 +0.98222595j], + [-0.48472244+0.30233797j]]) + """ + return ivy.current_backend(x).ifftn(x, s=s, axes=axes, norm=norm, out=out) diff --git a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py index 86329c8bb175e..c9505aa9c3426 100644 --- a/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py +++ b/ivy_tests/test_ivy/test_functional/test_experimental/test_nn/test_layers.py @@ -1040,3 +1040,69 @@ def test_fft2( dim=dim, norm=norm, ) + + +@st.composite +def x_and_ifftn(draw): + min_fft_points = 2 + dtype = draw(helpers.get_dtypes("complex")) + x_dim = draw( + helpers.get_shape( + min_dim_size=2, max_dim_size=100, min_num_dims=1, max_num_dims=4 + ) + ) + x = draw( + helpers.array_values( + dtype=dtype[0], + shape=tuple(x_dim), + min_value=-1e-10, + max_value=1e10, + ) + ) + axes = draw( + st.lists( + st.integers(0, len(x_dim) - 1), min_size=1, max_size=len(x_dim), unique=True + ) + ) + norm = draw(st.sampled_from(["forward", "ortho", "backward"])) + + # Shape for s can be larger, smaller or equal to the size of the input + # along the axes specified by axes. + # Here, we're generating a list of integers corresponding to each axis in axes. + s = draw( + st.lists( + st.integers(min_fft_points, 256), min_size=len(axes), max_size=len(axes) + ) + ) + + return dtype, x, axes, norm, s + + +@handle_test( + fn_tree="functional.ivy.experimental.ifftn", + d_x_d_s_n=x_and_ifftn(), + ground_truth_backend="numpy", + test_gradients=st.just(False), +) +def test_ifftn( + *, + d_x_d_s_n, + test_flags, + backend_fw, + fn_name, + on_device, + ground_truth_backend, +): + dtype, x, axes, norm, s = d_x_d_s_n + helpers.test_function( + ground_truth_backend=ground_truth_backend, + input_dtypes=dtype, + test_flags=test_flags, + fw=backend_fw, + on_device=on_device, + fn_name=fn_name, + x=x, + s=s, + axes=axes, + norm=norm, + )