diff --git a/ivy/functional/ivy/layers.py b/ivy/functional/ivy/layers.py
index 508a090f10e3..c04600cad390 100644
--- a/ivy/functional/ivy/layers.py
+++ b/ivy/functional/ivy/layers.py
@@ -2562,6 +2562,305 @@ def lstm(
     return output[:, -1], output, (h_outs, c_outs)


+@handle_exceptions
+@handle_nestable
+@handle_array_like_without_promotion
+@inputs_to_ivy_arrays
+@handle_array_function
+def rnn_tanh_update(
+    x: Union[ivy.Array, ivy.NativeArray],
+    init_h: Union[ivy.Array, ivy.NativeArray],
+    kernel: Union[ivy.Array, ivy.NativeArray],
+    recurrent_kernel: Union[ivy.Array, ivy.NativeArray],
+    /,
+    *,
+    bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
+    recurrent_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
+    time_major: bool = False,
+) -> Tuple[ivy.Array, ivy.Array]:
+    """Perform RNN-tanh update by unrolling the time dimension of the input array.
+
+    Parameters
+    ----------
+    x
+        input tensor of RNN layer *[batch_shape, t, in]* if time_major=False,
+        else *[t, batch_shape, in]*.
+    init_h
+        initial state tensor for the cell output *[batch_shape, out]*.
+    kernel
+        weights for cell kernel *[in, out]*.
+    recurrent_kernel
+        weights for cell recurrent kernel *[out, out]*.
+    bias
+        bias for cell kernel *[out]*. (Default value = None)
+    recurrent_bias
+        bias for cell recurrent kernel *[out]*. (Default value = None)
+    time_major
+        whether or not the input tensor `x` has the time dimension before batch dim.
+
+    Returns
+    -------
+    ret
+        hidden state for all timesteps of shape *[batch_shape,t,out]* if time_major
+        is False, else *[t, batch_shape, out]*, and the final hidden state of shape
+        *[batch_shape,out]*.
+    """
+    if time_major:
+        x = ivy.swapaxes(x, 0, 1)
+
+    # get shapes
+    x_shape = list(x.shape)
+    batch_shape = x_shape[:-2]
+    timesteps = x_shape[-2]
+    input_channels = x_shape[-1]
+    x_flat = ivy.reshape(x, (-1, input_channels))
+
+    # input kernel
+    Wx = kernel
+    Wx_x = ivy.reshape(
+        ivy.matmul(x_flat, Wx) + (bias if bias is not None else 0),
+        batch_shape + [timesteps, -1],
+    )
+
+    # recurrent kernel
+    Wh = recurrent_kernel
+
+    # rnn states
+    ht = init_h
+
+    # rnn outputs
+    hts_list = []
+
+    # unrolled time dimension with rnn steps
+    for Wx_xt in ivy.unstack(Wx_x, axis=-2):
+        htm1 = ht
+
+        Wh_htm1 = ivy.matmul(htm1, Wh) + (
+            recurrent_bias if recurrent_bias is not None else 0
+        )
+
+        ht = ivy.tanh(Wx_xt + Wh_htm1)
+
+        hts_list.append(ivy.expand_dims(ht, axis=-2))
+
+    ret = ivy.concat(hts_list, axis=-2)
+    if time_major:
+        ret = ivy.swapaxes(ret, 0, 1)
+
+    return ret, ht
+
+
+@handle_exceptions
+@handle_nestable
+@handle_array_like_without_promotion
+@inputs_to_ivy_arrays
+@handle_array_function
+def rnn_tanh(
+    input: ivy.Array,
+    initial_state: ivy.Array,
+    all_weights: Tuple[ivy.Array],
+    num_layers: int,
+    dropout: float,
+    train: bool,
+    bidirectional: bool,
+    batch_first: bool = False,
+    batch_sizes: Sequence = None,
+    weights_transposed: bool = False,
+    has_bias: bool = True,
+):
+    """Applies a multi-layer RNN-tanh to an input sequence.
+
+    Parameters
+    ----------
+    input
+        input array of shape (seq_len, batch, input_size) when `batch_first` is False
+        or (batch, seq_len, input_size) when `batch_first` is True.
+    initial_state
+        initial hidden state of shape (num_layers * num_directions, batch, hidden_size).
+    all_weights
+        tuple of arrays representing the learnable weights of the rnn, with each
+        layer having up to four arrays (w_ih, w_hh, b_ih, b_hh) representing the weights
+        and biases (if biases are being used).
+
+        w_ih: weight of shape (hidden_size, input_size)
+        w_hh: weight of shape (hidden_size, hidden_size)
+        b_ih: bias of shape (hidden_size,)
+        b_hh: bias of shape (hidden_size,)
+    num_layers
+        number of layers for the rnn to use.
+    dropout
+        dropout rate.
+    train
+        whether to run the rnn in train mode or eval mode.
+    bidirectional
+        whether the rnn is bidirectional or unidirectional.
+    batch_first
+        defines the data format of the input and output arrays.
+    batch_sizes
+        specifies the batch size at each timestep, when the input is a packed sequence.
+    weights_transposed
+        whether the weights are transposed compared to the format in which they
+        are expected, i.e. (input_size, hidden_size) rather than
+        (hidden_size, input_size).
+    has_bias
+        whether the `all_weights` argument includes biases.
+
+    Returns
+    -------
+    output
+        output array of shape (seq_len, batch, num_directions * hidden_size) or
+        (batch, seq_len, num_directions * hidden_size), depending on `batch_first`.
+    h_out
+        final hidden state of shape (num_layers * num_directions, batch, hidden_size).
+    """
+    if weights_transposed:
+        # transpose the weights if they are in the wrong format
+        all_weights = [
+            ivy.swapaxes(weight, 1, 0) if weight.dim() == 2 else weight
+            for weight in all_weights
+        ]
+    else:
+        all_weights = list(all_weights)
+
+    weights_per_layer = 2 if not has_bias else 4
+
+    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
+    layer_weights = [
+        all_weights[i : i + weights_per_layer]
+        for i in range(0, len(all_weights), weights_per_layer)
+    ]
+
+    if batch_sizes is not None:
+        input, batch_sizes = _pad_packed_sequence(input, batch_sizes)
+
+    if batch_first:
+        input = ivy.swapaxes(input, 0, 1)
+
+    if dropout and train:
+        raise ivy.utils.exceptions.IvyNotImplementedException()
+
+    unidirectional = not bidirectional
+
+    h0 = initial_state
+    h_outs = []
+
+    output = input
+    for i in range(num_layers):
+        if unidirectional:
+            if has_bias:
+                weight_ih, weight_hh, (bias_i, bias_h) = _transform_weights(
+                    layer_weights, i
+                )
+            else:
+                weight_ih, weight_hh = _transform_weights_no_bias(layer_weights, i)
+                bias_i = bias_h = None
+
+            state_indices = i, i + 1
+        else:
+            if has_bias:
+                weight_ih_f, weight_hh_f, (bias_i_f, bias_h_f) = _transform_weights(
+                    layer_weights, 2 * i
+                )
+                weight_ih_b, weight_hh_b, (bias_i_b, bias_h_b) = _transform_weights(
+                    layer_weights, 2 * i + 1
+                )
+            else:
+                weight_ih_f, weight_hh_f = _transform_weights_no_bias(
+                    layer_weights, 2 * i
+                )
+                weight_ih_b, weight_hh_b = _transform_weights_no_bias(
+                    layer_weights, 2 * i + 1
+                )
+                bias_i_f = bias_h_f = bias_i_b = bias_h_b = None
+
+            weight_ih = weight_ih_f, weight_ih_b
+            weight_hh = weight_hh_f, weight_hh_b
+            bias_i = bias_i_f, bias_i_b
+            bias_h = bias_h_f, bias_h_b
+
+            state_indices = 2 * i, 2 * i + 2
+
+        output, h_out = _rnn_tanh_layer(
+            output,
+            _retrieve_state(h0, *state_indices, num_layers),
+            (weight_ih, weight_hh),
+            (bias_i, bias_h),
+            bidirectional,
+            batch_first=False,
+            batch_sizes=batch_sizes,
+        )
+        h_outs.append(h_out)
+
+    if batch_first:
+        output = ivy.swapaxes(output, 0, 1)
+
+    h_outs = h_out if num_layers == 1 else ivy.concat(h_outs, axis=0)
+
+    if batch_sizes is not None:
+        output = _pack_padded_sequence(output, batch_sizes)[0]
+
+    return output[:, -1], output, h_outs
+
+
+def rnn_tanh_cell(
+    x,
+    init_h,
+    kernel,
+    recurrent_kernel,
+    bias,
+    recurrent_bias,
+    batch_first,
+    batch_sizes=None,
+):
+    if init_h.shape[0] == 1:
+        init_h = ivy.squeeze(init_h, axis=0)
+
+    out, ht = ivy.rnn_tanh_update(
+        x,
+        init_h,
+        kernel,
+        recurrent_kernel,
+        bias=bias,
+        recurrent_bias=recurrent_bias,
+        time_major=not batch_first,
+    )
+    return out, ivy.expand_dims(ht, axis=0)
+
+
+def _rnn_tanh_layer(
+    x,
+    initial_state,
+    all_weights,
+    all_biases,
+    bidirectional,
+    batch_first,
+    batch_sizes=None,
+):
+    out, ht = rnn_tanh_cell(
+        x,
+        initial_state,
+        *all_weights,
+        *all_biases,
+        batch_first=batch_first,
+        batch_sizes=batch_sizes,
+    )
+    if bidirectional:
+        x_rev = ivy.flip(x, axis=0)
+        initial_state_rev = ivy.flip(initial_state, axis=0)
+
+        out_rev, ht_rev = rnn_tanh_cell(
+            x_rev,
+            initial_state_rev,
+            *all_weights,
+            *all_biases,
+            batch_first=batch_first,
+            batch_sizes=batch_sizes,
+        )
+        out_rev = ivy.flip(out_rev, axis=0)
+        ht_rev = ivy.flip(ht_rev, axis=0)
+        return ivy.concat([out, out_rev], axis=-1), ivy.concat([ht, ht_rev], axis=0)
+
+    return out, ht
+
 # Helpers #
 #

diff --git a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
index 385b19b1a4c1..e22dea1ca514 100644
--- a/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
+++ b/ivy_tests/test_ivy/test_functional/test_nn/test_layers.py
@@ -948,6 +948,61 @@ def _x_and_lstm(draw, dtypes):
     )


+# RNN #
+# ----#
+
+@st.composite
+def _x_and_rnn_tanh(draw, dtypes):
+    dtype = draw(dtypes)
+    batch_shape = (1,)
+
+    t = draw(helpers.ints(min_value=1, max_value=2))
+    _in_ = draw(helpers.ints(min_value=1, max_value=2))
+    _out_ = draw(helpers.ints(min_value=1, max_value=2))
+
+    x_rnn_shape = batch_shape + (t,) + (_in_,)
+    init_h_shape = batch_shape + (_out_,)
+    kernel_shape = (_in_,) + (_out_,)
+    recurrent_kernel_shape = (_out_,) + (_out_,)
+    bias_shape = (_out_,)
+
+    x_rnn = draw(
+        helpers.array_values(
+            dtype=dtype[0], shape=x_rnn_shape, min_value=0, max_value=1
+        )
+    )
+    init_h = draw(
+        helpers.array_values(
+            dtype=dtype[0], shape=init_h_shape, min_value=0, max_value=1
+        )
+    )
+    kernel = draw(
+        helpers.array_values(
+            dtype=dtype[0], shape=kernel_shape, min_value=0, max_value=1
+        )
+    )
+    recurrent_kernel = draw(
+        helpers.array_values(
+            dtype=dtype[0], shape=recurrent_kernel_shape, min_value=0, max_value=1
+        )
+    )
+    bias = draw(
+        helpers.array_values(dtype=dtype[0], shape=bias_shape, min_value=0, max_value=1)
+    )
+    recurrent_bias = draw(
+        helpers.array_values(dtype=dtype[0], shape=bias_shape, min_value=0, max_value=1)
+    )
+    return (
+        dtype,
+        x_rnn,
+        init_h,
+        kernel,
+        recurrent_kernel,
+        bias,
+        recurrent_bias,
+    )
+
+
 # Attention #
 # ----------#

@@ -1601,6 +1656,39 @@ def test_lstm_update(*, dtype_lstm, test_flags, backend_fw, fn_name, on_device):
         recurrent_bias=recurrent_bias,
     )

+# rnn_tanh_update
+@handle_test(
+    fn_tree="functional.ivy.rnn_tanh_update",
+    dtype_rnn=_x_and_rnn_tanh(
+        dtypes=helpers.get_dtypes("numeric"),
+    ),
+    test_with_out=st.just(False),
+)
+def test_rnn_tanh_update(*, dtype_rnn, test_flags, backend_fw, fn_name, on_device):
+    (
+        dtype,
+        x_rnn,
+        init_h,
+        kernel,
+        recurrent_kernel,
+        bias,
+        recurrent_bias,
+    ) = dtype_rnn
+    helpers.test_function(
+        input_dtypes=dtype,
+        test_flags=test_flags,
+        backend_to_test=backend_fw,
+        fn_name=fn_name,
+        on_device=on_device,
+        rtol_=1e-01,
+        atol_=1e-01,
+        x=x_rnn,
+        init_h=init_h,
+        kernel=kernel,
+        recurrent_kernel=recurrent_kernel,
+        bias=bias,
+        recurrent_bias=recurrent_bias,
+    )

 # multi_head_attention
 @handle_test(
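A minimal usage sketch of the new `ivy.rnn_tanh_update` added by this patch (not part of the diff; it assumes the patch is applied, that the numpy backend is installed, and uses arbitrary illustrative sizes):

import ivy

ivy.set_backend("numpy")  # assumption: any installed backend works, numpy chosen here

batch, t, in_feats, out_feats = 2, 5, 3, 4  # illustrative sizes

x = ivy.random_uniform(shape=(batch, t, in_feats))        # *[batch, t, in]*
init_h = ivy.zeros((batch, out_feats))                    # *[batch, out]*
kernel = ivy.random_uniform(shape=(in_feats, out_feats))  # *[in, out]*
recurrent_kernel = ivy.random_uniform(shape=(out_feats, out_feats))  # *[out, out]*
bias = ivy.zeros((out_feats,))
recurrent_bias = ivy.zeros((out_feats,))

# hidden states for every timestep, plus the final hidden state
hs, h_final = ivy.rnn_tanh_update(
    x,
    init_h,
    kernel,
    recurrent_kernel,
    bias=bias,
    recurrent_bias=recurrent_bias,
)

print(hs.shape)       # [batch, t, out] == (2, 5, 4)
print(h_final.shape)  # [batch, out]    == (2, 4)

With time_major=True the same call would instead expect `x` as *[t, batch, in]* and return the stacked hidden states as *[t, batch, out]*, as described in the docstring above.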