feat: Add multi-layer, bidirectional RNN with tanh activation #28804

Open · wants to merge 6 commits into base: main · Changes from 4 commits
299 changes: 299 additions & 0 deletions ivy/functional/ivy/layers.py
@@ -2562,6 +2562,305 @@ def lstm(

    return output[:, -1], output, (h_outs, c_outs)

@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def rnn_tanh_update(
    x: Union[ivy.Array, ivy.NativeArray],
    init_h: Union[ivy.Array, ivy.NativeArray],
    kernel: Union[ivy.Array, ivy.NativeArray],
    recurrent_kernel: Union[ivy.Array, ivy.NativeArray],
    /,
    *,
    bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
    recurrent_bias: Optional[Union[ivy.Array, ivy.NativeArray]] = None,
    time_major: bool = False,
) -> Tuple[ivy.Array, ivy.Array]:
    """Perform RNN-tanh update by unrolling time dimension of input array.

    Parameters
    ----------
    x
        input tensor of RNN layer *[batch_shape, t, in]* if time_major=False,
        else *[t, batch_shape, in]*.
    init_h
        initial state tensor for the cell output *[batch_shape, out]*.
    kernel
        weights for cell kernel *[in, out]*.
    recurrent_kernel
        weights for cell recurrent kernel *[out, out]*.
    bias
        bias for cell kernel *[out]*. (Default value = None)
    recurrent_bias
        bias for cell recurrent kernel *[out]*. (Default value = None)
    time_major
        whether or not the input tensor `x` has the time dimension before batch dim.

    Returns
    -------
    ret
        hidden state for all timesteps of shape *[batch_shape,t,out]* if time_major
        is False, else *[t, batch_shape, out]*, and the final hidden state of shape
        *[batch_shape,out]*.
    """
    if time_major:
        x = ivy.swapaxes(x, 0, 1)

    # get shapes
    x_shape = list(x.shape)
    batch_shape = x_shape[:-2]
    timesteps = x_shape[-2]
    input_channels = x_shape[-1]
    x_flat = ivy.reshape(x, (-1, input_channels))

    # input kernel
    Wx = kernel
    Wx_x = ivy.reshape(
        ivy.matmul(x_flat, Wx) + (bias if bias is not None else 0),
        batch_shape + [timesteps, -1],
    )

    # recurrent kernel
    Wh = recurrent_kernel

    # rnn states
    ht = init_h

    # rnn outputs
    hts_list = []

    # unrolled time dimension with rnn steps
    for Wx_xt in ivy.unstack(Wx_x, axis=-2):
        htm1 = ht

        Wh_htm1 = ivy.matmul(htm1, Wh) + (
            recurrent_bias if recurrent_bias is not None else 0
        )

        ht = ivy.tanh(Wx_xt + Wh_htm1)

        hts_list.append(ivy.expand_dims(ht, axis=-2))

    ret = ivy.concat(hts_list, axis=-2)
    if time_major:
        ret = ivy.swapaxes(ret, 0, 1)

    return ret, ht

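The update above is the standard Elman recurrence, h_t = tanh(x_t @ kernel + bias + h_{t-1} @ recurrent_kernel + recurrent_bias), unrolled over the time axis. Below is a minimal usage sketch with made-up shapes (batch 2, 5 timesteps, 3 input features, hidden size 4); it assumes this branch is installed and any available backend is set.

import ivy

ivy.set_backend("numpy")  # any installed ivy backend would work here

x = ivy.random_uniform(shape=(2, 5, 3))              # [batch, t, in]
init_h = ivy.zeros((2, 4))                           # [batch, out]
kernel = ivy.random_uniform(shape=(3, 4))            # [in, out]
recurrent_kernel = ivy.random_uniform(shape=(4, 4))  # [out, out]

all_h, last_h = ivy.rnn_tanh_update(x, init_h, kernel, recurrent_kernel)
# all_h has shape [2, 5, 4]: the hidden state at every timestep
# last_h has shape [2, 4]: the final hidden state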

@handle_exceptions
@handle_nestable
@handle_array_like_without_promotion
@inputs_to_ivy_arrays
@handle_array_function
def rnn_tanh(
    input: ivy.Array,
    initial_state: ivy.Array,
    all_weights: Tuple[ivy.Array],
    num_layers: int,
    dropout: float,
    train: bool,
    bidirectional: bool,
    batch_first: bool = False,
    batch_sizes: Sequence = None,
    weights_transposed: bool = False,
    has_bias: bool = True,
):
    """Applies a multi-layer RNN-tanh to an input sequence.

    Parameters
    ----------
    input
        input array of shape (seq_len, batch, input_size) when `batch_first` is False
        or (batch, seq_len, input_size) when `batch_first` is True.
    initial_state
        initial hidden state of shape (num_layers * num_directions, batch, hidden_size).
    all_weights
        tuple of arrays representing the learnable weights of the rnn, with each
        layer (and each direction, if bidirectional) having up to four arrays
        (w_ih, w_hh, b_ih, b_hh) representing the weights and biases (if biases
        are being used).

        w_ih: weight of shape (hidden_size, input_size)
        w_hh: weight of shape (hidden_size, hidden_size)
        b_ih: bias of shape (hidden_size,)
        b_hh: bias of shape (hidden_size,)
    num_layers
        number of layers for the rnn to use.
    dropout
        dropout rate.
    train
        whether to run the rnn in train mode or eval mode.
    bidirectional
        whether the rnn is bidirectional or unidirectional.
    batch_first
        defines the data format of the input and output arrays.
    batch_sizes
        specifies the batch size at each timestep, when the input is a packed sequence.
    weights_transposed
        whether the weights are transposed compared to the format in which they
        are expected, i.e. (input_size, hidden_size) rather than
        (hidden_size, input_size).
    has_bias
        whether the `all_weights` argument includes biases.

    Returns
    -------
    output
        output array of shape (seq_len, batch, num_directions * hidden_size) or
        (batch, seq_len, num_directions * hidden_size), depending on `batch_first`.
    h_out
        final hidden state of shape (num_layers * num_directions, batch, hidden_size).
    """
    if weights_transposed:
        # transpose the weights if they are in the wrong format
        all_weights = [
            ivy.swapaxes(weight, 1, 0) if weight.dim() == 2 else weight
            for weight in all_weights
        ]
    else:
        all_weights = list(all_weights)

    weights_per_layer = 2 if not has_bias else 4

    assert len(all_weights) == num_layers * weights_per_layer * (1 + bidirectional)
    layer_weights = [
        all_weights[i : i + weights_per_layer]
        for i in range(0, len(all_weights), weights_per_layer)
    ]

    if batch_sizes is not None:
        input, batch_sizes = _pad_packed_sequence(input, batch_sizes)

    if batch_first:
        input = ivy.swapaxes(input, 0, 1)

    if dropout and train:
        raise ivy.utils.exceptions.IvyNotImplementedException()

    unidirectional = not bidirectional

    h0 = initial_state
    h_outs = []

    output = input
    for i in range(num_layers):
        if unidirectional:
            if has_bias:
                weight_ih, weight_hh, (bias_i, bias_h) = _transform_weights(
                    layer_weights, i
                )
            else:
                weight_ih, weight_hh = _transform_weights_no_bias(layer_weights, i)
                bias_i = bias_h = None

            state_indices = i, i + 1
        else:
            if has_bias:
                weight_ih_f, weight_hh_f, (bias_i_f, bias_h_f) = _transform_weights(
                    layer_weights, 2 * i
                )
                weight_ih_b, weight_hh_b, (bias_i_b, bias_h_b) = _transform_weights(
                    layer_weights, 2 * i + 1
                )
            else:
                weight_ih_f, weight_hh_f = _transform_weights_no_bias(
                    layer_weights, 2 * i
                )
                weight_ih_b, weight_hh_b = _transform_weights_no_bias(
                    layer_weights, 2 * i + 1
                )
                bias_i_f = bias_h_f = bias_i_b = bias_h_b = None

            weight_ih = weight_ih_f, weight_ih_b
            weight_hh = weight_hh_f, weight_hh_b
            bias_i = bias_i_f, bias_i_b
            bias_h = bias_h_f, bias_h_b

            state_indices = 2 * i, 2 * i + 2

        output, h_out = _rnn_tanh_layer(
            output,
            _retrieve_state(h0, *state_indices, num_layers),
            (weight_ih, weight_hh),
            (bias_i, bias_h),
            bidirectional,
            batch_first=False,
            batch_sizes=batch_sizes,
        )
        h_outs.append(h_out)

    if batch_first:
        output = ivy.swapaxes(output, 0, 1)

    h_outs = h_out if num_layers == 1 else ivy.concat(h_outs, axis=0)

    if batch_sizes is not None:
        output = _pack_padded_sequence(output, batch_sizes)[0]

    return output[:, -1], output, h_outs

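For the flat all_weights layout handled above: with has_bias=True each layer contributes one (w_ih, w_hh, b_ih, b_hh) group, and with bidirectional=True the reverse-direction group for a layer directly follows its forward group (indices 2*i and 2*i + 1 in the code). Below is a minimal sketch of a two-layer, unidirectional call with made-up sizes (input_size=3, hidden_size=4); it assumes the helpers shared with the existing ivy.lstm implementation (_transform_weights, _retrieve_state) are available and transpose PyTorch-layout weights as they do there.

import ivy

ivy.set_backend("numpy")

input_size, hidden_size, batch, seq_len = 3, 4, 2, 5

def weight_group(in_features):
    # one (w_ih, w_hh, b_ih, b_hh) group in the expected (hidden_size, in_features) layout
    return (
        ivy.random_uniform(shape=(hidden_size, in_features)),
        ivy.random_uniform(shape=(hidden_size, hidden_size)),
        ivy.zeros((hidden_size,)),
        ivy.zeros((hidden_size,)),
    )

# layer 0 consumes the input features, layer 1 consumes layer 0's hidden states
all_weights = (*weight_group(input_size), *weight_group(hidden_size))

x = ivy.random_uniform(shape=(seq_len, batch, input_size))  # batch_first=False
h0 = ivy.zeros((2, batch, hidden_size))  # num_layers * num_directions = 2

ret = ivy.rnn_tanh(
    x,
    h0,
    all_weights,
    2,      # num_layers
    0.0,    # dropout
    False,  # train
    False,  # bidirectional
)
# per the implementation above, ret is a tuple whose second element is the full
# output of shape (seq_len, batch, hidden_size) and whose last element is the
# final hidden state of shape (num_layers, batch, hidden_size)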

Review comment (Contributor): We could make this a public rather than private method, so ivy.rnn_tanh_cell can be used in the torch frontend rnn_tanh_cell implementation?

def _rnn_tanh_cell(
    x,
    init_h,
    kernel,
    recurrent_kernel,
    bias,
    recurrent_bias,
    batch_first,
    batch_sizes=None,
):
    if init_h.shape[0] == 1:
        init_h = ivy.squeeze(init_h, axis=0)

    out, ht = ivy.rnn_tanh_update(
        x,
        init_h,
        kernel,
        recurrent_kernel,
        bias=bias,
        recurrent_bias=recurrent_bias,
        time_major=not batch_first,
    )
    return out, ivy.expand_dims(ht, axis=0)

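To illustrate the review comment on _rnn_tanh_cell above, here is a rough sketch of what a public single-step wrapper could look like; the name rnn_tanh_cell and its signature are hypothetical and not part of this diff. It pushes a length-1 time dimension through ivy.rnn_tanh_update and transposes PyTorch-layout weights into the [in, out] / [out, out] layout that function expects.

def rnn_tanh_cell(x, hx, w_ih, w_hh, b_ih=None, b_hh=None):
    # hypothetical sketch: one tanh step,
    # h' = tanh(x @ w_ih.T + b_ih + hx @ w_hh.T + b_hh)
    x_seq = ivy.expand_dims(x, axis=-2)  # [batch, in] -> [batch, 1, in]
    _, h_next = ivy.rnn_tanh_update(
        x_seq,
        hx,
        ivy.swapaxes(w_ih, 0, 1),  # [hidden, in] -> [in, hidden]
        ivy.swapaxes(w_hh, 0, 1),  # [hidden, hidden] -> [hidden, hidden]
        bias=b_ih,
        recurrent_bias=b_hh,
    )
    return h_next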

def _rnn_tanh_layer(
    x,
    initial_state,
    all_weights,
    all_biases,
    bidirectional,
    batch_first,
    batch_sizes=None,
):
    out, ht = _rnn_tanh_cell(
        x,
        initial_state,
        *all_weights,
        *all_biases,
        batch_first=batch_first,
        batch_sizes=batch_sizes,
    )
    if bidirectional:
        # run the cell again over the time-reversed input and reversed initial
        # state, then flip its outputs back so both directions share a time axis
        x_rev = ivy.flip(x, axis=0)
        initial_state_rev = ivy.flip(initial_state, axis=0)

        out_rev, ht_rev = _rnn_tanh_cell(
            x_rev,
            initial_state_rev,
            *all_weights,
            *all_biases,
            batch_first=batch_first,
            batch_sizes=batch_sizes,
        )
        out_rev = ivy.flip(out_rev, axis=0)
        ht_rev = ivy.flip(ht_rev, axis=0)
        # forward and reverse features are concatenated along the feature axis,
        # and their final hidden states along the direction axis
        return ivy.concat([out, out_rev], axis=-1), ivy.concat([ht, ht_rev], axis=0)

    return out, ht


# Helpers #
