Skip to content

Commit

Permalink
[Frontend] [ONNX] Support sequence_lens of GRU.
Browse files Browse the repository at this point in the history
Support convert sequence_lens input of GRU.
  • Loading branch information
Jianjian.Guan authored and Jianjian.Guan committed Dec 15, 2022
1 parent c547bbb commit 67e7dc4
Show file tree
Hide file tree
Showing 3 changed files with 104 additions and 11 deletions.
57 changes: 54 additions & 3 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ def gru_cell(
n_act=_op.tanh,
backwards=False,
linear_before_reset=True,
sequence_lens=None,
):
"""
Common implementation of GRU cell for all frontends of TVM
Expand Down Expand Up @@ -765,15 +766,53 @@ def gru_cell(
activation function for new gate. it is tanh by default
backwards : bool
Flag for reverse pass of GRU
linear_before_reset : bool
Flag for applying the linear transformation before multiplying by the output of the reset
gate.
sequence_lens : relay.op
Tensor specifying lengths of the sequences in a batch.
Shape = (batch_size)
Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed result, final hidden and cell state
"""

outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):

seq_len = len(input_seqs)
input_dtype = infer_type(input_seqs[0]).checked_type.dtype

if sequence_lens is not None:
shape = infer_shape(sequence_lens)
dtype = infer_type(sequence_lens).checked_type.dtype

arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
arange = _op.expand_dims(arange, 1)
sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])

# cast to data dtype
mask = _op.less(arange, sequence_lens)
mask = _op.cast(mask, dtype=input_dtype)
mask = _op.expand_dims(mask, 2)
mask_seqs = unbind(mask)

res_mask = _op.greater_equal(arange, sequence_lens)
res_mask = _op.cast(res_mask, dtype=input_dtype)
res_mask = _op.expand_dims(res_mask, 2)
res_mask_seqs = unbind(res_mask)

if backwards:
# need a mask to keep intial_h_B correct
initial_h = hidden_state
initial_h_mask = _op.equal(arange, sequence_lens)
initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
initial_h_mask = _op.expand_dims(initial_h_mask, 2)
initial_h_mask_seqs = unbind(initial_h_mask)

output = _op.zeros(infer_shape(hidden_state), input_dtype)
for i in range(seq_len) if not backwards else reversed(range(seq_len)):
x_t = input_seqs[i]
xwt = _op.nn.dense(x_t, w_inp)
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
Expand Down Expand Up @@ -806,9 +845,21 @@ def gru_cell(

hidden_state = (hidden_state - n_gate) * z_gate + n_gate

if sequence_lens is not None:
hidden_state = hidden_state * mask_seqs[i]

outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state
if sequence_lens is not None:
output = output * res_mask_seqs[i] + hidden_state
else:
output = hidden_state

# make sure initial_h_B correct
if backwards and sequence_lens is not None:
hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]

return outputs_list, output


def lstm_cell(
Expand Down
18 changes: 12 additions & 6 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3126,8 +3126,7 @@ def _inputs_helper(cls, inputs, layout):
Wp = inputs[1]
Rp = inputs[2]
Bp = inputs[3]
# Sequence length currently unused as it can be inferred from shapes.
# sequence_lens = inputs['sequence_lens']
sequence_lens = inputs[4]
Hp_0 = inputs[5]

num_directions = infer_shape(Wp)[0]
Expand Down Expand Up @@ -3158,11 +3157,11 @@ def _inputs_helper(cls, inputs, layout):
Bs = None
if Bp is not None:
Bs = _op.split(Bp, num_directions)
return X_steps, H_ts, Ws, Rs, Bs, num_directions
return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 1, num_directions, "RNN")

weights_dicts = []
Expand Down Expand Up @@ -3261,7 +3260,7 @@ def _default_activations(cls, num_directions):

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
acts = cls._get_activations(attr, 3, num_directions, "LSTM")

# cell state
Expand Down Expand Up @@ -3346,6 +3345,7 @@ def bidir_gru_cell(
input_seqs,
weight_dicts,
acts,
sequence_lens=None,
):
"""
Bidirectional GRU cell
Expand All @@ -3356,6 +3356,7 @@ def bidir_gru_cell(
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
sequence_lens=sequence_lens,
)

reverse_outputs, rev_H_t = gru_cell(
Expand All @@ -3364,6 +3365,7 @@ def bidir_gru_cell(
rz_act=acts[2],
n_act=acts[3],
backwards=True,
sequence_lens=sequence_lens,
)

final_outputs = []
Expand All @@ -3383,7 +3385,9 @@ def _default_activations(cls, num_directions):

@classmethod
def _impl_common(cls, inputs, attr, layout):
X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
inputs, layout
)
acts = cls._get_activations(attr, 2, num_directions, "GRU")
linear_before_reset = attr.get("linear_before_reset", 0)

Expand Down Expand Up @@ -3412,6 +3416,7 @@ def _impl_common(cls, inputs, attr, layout):
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
sequence_lens=sequence_lens,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
Expand All @@ -3420,6 +3425,7 @@ def _impl_common(cls, inputs, attr, layout):
**weights_dicts[0],
rz_act=acts[0],
n_act=acts[1],
sequence_lens=sequence_lens,
)

# output shape = (seqs_num, num_directions, batch_size, hidden_size)
Expand Down
40 changes: 38 additions & 2 deletions tests/python/frontend/onnx/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -3897,6 +3897,7 @@ def verify_rnn(
atol=1e-5,
target=None,
dev=None,
use_sequence_lens=False,
):
"""verify_rnn"""
if rnn_type == "RNN":
Expand Down Expand Up @@ -3954,10 +3955,16 @@ def register(np_arr, name, shape=None):
)
register(b_np, "B")

if use_sequence_lens:
sequence_np = np.random.uniform(0, seq_length, size=(batch_size)).astype("int32")
register(sequence_np, "sequence_lens")

if use_initial_state:
assert use_bias is True, "Initial states must have bias specified."
sequence_np = np.repeat(seq_length, batch_size).astype("int32")
register(sequence_np, "sequence_lens")

if not use_sequence_lens:
sequence_np = np.repeat(seq_length, batch_size).astype("int32")
register(sequence_np, "sequence_lens")

if layout == 1:
initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
Expand Down Expand Up @@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
# dev=dev,
# )

# Testing with initial state
if rnn_type == "GRU":
verify_rnn(
seq_length=2,
batch_size=1,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True,
rnn_type=rnn_type,
directions=directions,
target=target,
dev=dev,
use_sequence_lens=True,
)
verify_rnn(
seq_length=8,
batch_size=8,
input_size=16,
hidden_size=32,
use_bias=True,
use_initial_state=True,
rnn_type=rnn_type,
directions=directions,
target=target,
dev=dev,
use_sequence_lens=True,
)

# Testing with peepholes
if rnn_type == "LSTM":
verify_rnn(
Expand Down

0 comments on commit 67e7dc4

Please sign in to comment.