From 4aadd0f5af14a4a82176fd046890bb5988731856 Mon Sep 17 00:00:00 2001
From: Jianjian Guan
Date: Thu, 15 Dec 2022 14:08:15 +0800
Subject: [PATCH] [Frontend] [ONNX] Support sequence_lens of GRU (#13587)

[Frontend] [ONNX] Support sequence_lens of GRU.
Support converting the sequence_lens input of GRU.
---
 python/tvm/relay/frontend/common.py        | 57 ++++++++++++++++++++--
 python/tvm/relay/frontend/onnx.py          | 18 ++++---
 tests/python/frontend/onnx/test_forward.py | 40 ++++++++++++++-
 3 files changed, 104 insertions(+), 11 deletions(-)

diff --git a/python/tvm/relay/frontend/common.py b/python/tvm/relay/frontend/common.py
index 5f961f1ae0e8..660426fb4ad5 100755
--- a/python/tvm/relay/frontend/common.py
+++ b/python/tvm/relay/frontend/common.py
@@ -737,6 +737,7 @@ def gru_cell(
     n_act=_op.tanh,
     backwards=False,
     linear_before_reset=True,
+    sequence_lens=None,
 ):
     """
     Common implementation of GRU cell for all frontends of TVM
@@ -765,7 +766,12 @@ def gru_cell(
         activation function for new gate. it is tanh by default
     backwards : bool
         Flag for reverse pass of GRU
-
+    linear_before_reset : bool
+        Flag for applying the linear transformation before multiplying by the output of the reset
+        gate.
+    sequence_lens : relay.Expr
+        Tensor specifying lengths of the sequences in a batch.
+        Shape = (batch_size,)
     Returns
     -------
     result : List[relay.Expr], relay.Expr, relay.Expr
@@ -773,7 +779,40 @@ def gru_cell(
     """
 
     outputs_list = []
-    for x_t in input_seqs if not backwards else reversed(input_seqs):
+
+    seq_len = len(input_seqs)
+    input_dtype = infer_type(input_seqs[0]).checked_type.dtype
+
+    if sequence_lens is not None:
+        shape = infer_shape(sequence_lens)
+        dtype = infer_type(sequence_lens).checked_type.dtype
+
+        arange = _op.arange(_op.const(0), _op.const(seq_len), dtype=dtype)
+        arange = _op.expand_dims(arange, 1)
+        sequence_lens = _op.broadcast_to(sequence_lens, [seq_len, shape[0]])
+
+        # per-step validity mask, cast to the data dtype
+        mask = _op.less(arange, sequence_lens)
+        mask = _op.cast(mask, dtype=input_dtype)
+        mask = _op.expand_dims(mask, 2)
+        mask_seqs = unbind(mask)
+
+        res_mask = _op.greater_equal(arange, sequence_lens)
+        res_mask = _op.cast(res_mask, dtype=input_dtype)
+        res_mask = _op.expand_dims(res_mask, 2)
+        res_mask_seqs = unbind(res_mask)
+
+        if backwards:
+            # need a mask to keep initial_h_B correct
+            initial_h = hidden_state
+            initial_h_mask = _op.equal(arange, sequence_lens)
+            initial_h_mask = _op.cast(initial_h_mask, dtype=input_dtype)
+            initial_h_mask = _op.expand_dims(initial_h_mask, 2)
+            initial_h_mask_seqs = unbind(initial_h_mask)
+
+    output = _op.zeros(infer_shape(hidden_state), input_dtype)
+    for i in range(seq_len) if not backwards else reversed(range(seq_len)):
+        x_t = input_seqs[i]
         xwt = _op.nn.dense(x_t, w_inp)
         if linear_before_reset:
             hwt = _op.nn.dense(hidden_state, w_hid)
@@ -806,9 +845,21 @@ def gru_cell(
 
         hidden_state = (hidden_state - n_gate) * z_gate + n_gate
 
+        if sequence_lens is not None:
+            hidden_state = hidden_state * mask_seqs[i]
+
         outputs_list.append(hidden_state)  # [seq_num, (batch, hidden_size)]
 
-    return outputs_list, hidden_state
+        if sequence_lens is not None:
+            output = output * res_mask_seqs[i] + hidden_state
+        else:
+            output = hidden_state
+
+        # make sure initial_h_B is correct
+        if backwards and sequence_lens is not None:
+            hidden_state = hidden_state + initial_h * initial_h_mask_seqs[i]
+
+    return outputs_list, output
 
 
 def lstm_cell(
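A note on the masking scheme above: mask zeroes the hidden state at padded steps, res_mask lets the returned output latch the hidden state produced at each sequence's last valid step, and the backward-only initial_h_mask re-injects the initial state where the reverse pass crosses a sequence's true end. A minimal NumPy sketch of this behavior (illustrative only, not TVM code; the shapes, the stand-in update, and the printed values are assumptions):

import numpy as np

seq_len, batch_size, hidden_size = 4, 2, 3
sequence_lens = np.array([2, 4])  # sequence 0 is padded after step 1

arange = np.arange(seq_len).reshape(seq_len, 1)               # (seq_len, 1)
lens = np.broadcast_to(sequence_lens, (seq_len, batch_size))  # (seq_len, batch)

# mask[t, b] == 1 while step t is inside sequence b, else 0
mask = (arange < lens).astype("float32")[:, :, None]
# res_mask[t, b] == 1 once sequence b has ended, so output keeps
# its previously latched value instead of the zeroed hidden state
res_mask = (arange >= lens).astype("float32")[:, :, None]

hidden_state = np.zeros((batch_size, hidden_size), "float32")
output = np.zeros_like(hidden_state)
outputs_list = []
for i in range(seq_len):
    hidden_state = np.tanh(hidden_state + 1.0)  # stand-in for the GRU update
    hidden_state = hidden_state * mask[i]       # zero out padded steps
    outputs_list.append(hidden_state)
    output = output * res_mask[i] + hidden_state

print(outputs_list[2][0])  # all zeros: sequence 0 ended after step 1
print(output[0])           # equals outputs_list[1][0], the last valid state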
diff --git a/python/tvm/relay/frontend/onnx.py b/python/tvm/relay/frontend/onnx.py
index 3470099100d4..a8ab62602573 100644
--- a/python/tvm/relay/frontend/onnx.py
+++ b/python/tvm/relay/frontend/onnx.py
@@ -3126,8 +3126,7 @@ def _inputs_helper(cls, inputs, layout):
         Wp = inputs[1]
         Rp = inputs[2]
         Bp = inputs[3]
-        # Sequence length currently unused as it can be inferred from shapes.
-        # sequence_lens = inputs['sequence_lens']
+        sequence_lens = inputs[4]
         Hp_0 = inputs[5]
 
         num_directions = infer_shape(Wp)[0]
@@ -3158,11 +3157,11 @@ def _inputs_helper(cls, inputs, layout):
         Bs = None
         if Bp is not None:
             Bs = _op.split(Bp, num_directions)
-        return X_steps, H_ts, Ws, Rs, Bs, num_directions
+        return X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 1, num_directions, "RNN")
 
         weights_dicts = []
@@ -3261,7 +3260,7 @@ def _default_activations(cls, num_directions):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, _ = cls._inputs_helper(inputs, layout)
         acts = cls._get_activations(attr, 3, num_directions, "LSTM")
 
         # cell state
@@ -3346,6 +3345,7 @@ def bidir_gru_cell(
         input_seqs,
         weight_dicts,
         acts,
+        sequence_lens=None,
     ):
         """
         Bidirectional GRU cell
@@ -3356,6 +3356,7 @@ def bidir_gru_cell(
             **weight_dicts[0],
             rz_act=acts[0],
             n_act=acts[1],
+            sequence_lens=sequence_lens,
         )
 
         reverse_outputs, rev_H_t = gru_cell(
@@ -3364,6 +3365,7 @@ def bidir_gru_cell(
             rz_act=acts[2],
             n_act=acts[3],
             backwards=True,
+            sequence_lens=sequence_lens,
         )
 
         final_outputs = []
@@ -3383,7 +3385,9 @@ def _default_activations(cls, num_directions):
 
     @classmethod
     def _impl_common(cls, inputs, attr, layout):
-        X_steps, H_ts, Ws, Rs, Bs, num_directions = cls._inputs_helper(inputs, layout)
+        X_steps, H_ts, Ws, Rs, Bs, num_directions, sequence_lens = cls._inputs_helper(
+            inputs, layout
+        )
         acts = cls._get_activations(attr, 2, num_directions, "GRU")
         linear_before_reset = attr.get("linear_before_reset", 0)
 
@@ -3412,6 +3416,7 @@ def _impl_common(cls, inputs, attr, layout):
                 input_seqs=X_steps,
                 weight_dicts=weights_dicts,
                 acts=acts,
+                sequence_lens=sequence_lens,
             )
         else:
             # outputs shape = [seqs_num, (batch_size, hidden_size)]
@@ -3420,6 +3425,7 @@ def _impl_common(cls, inputs, attr, layout):
                 **weights_dicts[0],
                 rz_act=acts[0],
                 n_act=acts[1],
+                sequence_lens=sequence_lens,
            )
 
         # output shape = (seqs_num, num_directions, batch_size, hidden_size)
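For reference, the indexing used by _inputs_helper above follows the ONNX GRU operator signature, where input 4 is sequence_lens and input 5 is initial_h. A minimal sketch of a GRU node carrying a sequence_lens input, built with onnx.helper (the tensor names and shapes here are illustrative assumptions, not part of the patch):

import onnx
from onnx import TensorProto, helper

seq_len, batch, input_size, hidden_size, directions = 4, 2, 8, 16, 1

# Positional GRU inputs 0..5: X, W, R, B, sequence_lens, initial_h
gru = helper.make_node(
    "GRU",
    inputs=["X", "W", "R", "B", "sequence_lens", "initial_h"],
    outputs=["Y", "Y_h"],
    hidden_size=hidden_size,
)
graph = helper.make_graph(
    [gru],
    "gru_sequence_lens",
    inputs=[
        helper.make_tensor_value_info("X", TensorProto.FLOAT, [seq_len, batch, input_size]),
        helper.make_tensor_value_info("W", TensorProto.FLOAT, [directions, 3 * hidden_size, input_size]),
        helper.make_tensor_value_info("R", TensorProto.FLOAT, [directions, 3 * hidden_size, hidden_size]),
        helper.make_tensor_value_info("B", TensorProto.FLOAT, [directions, 6 * hidden_size]),
        helper.make_tensor_value_info("sequence_lens", TensorProto.INT32, [batch]),
        helper.make_tensor_value_info("initial_h", TensorProto.FLOAT, [directions, batch, hidden_size]),
    ],
    outputs=[
        helper.make_tensor_value_info("Y", TensorProto.FLOAT, [seq_len, directions, batch, hidden_size]),
        helper.make_tensor_value_info("Y_h", TensorProto.FLOAT, [directions, batch, hidden_size]),
    ],
)
onnx.checker.check_model(helper.make_model(graph))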
diff --git a/tests/python/frontend/onnx/test_forward.py b/tests/python/frontend/onnx/test_forward.py
index dcd4f2defbe8..92a87ff6a72c 100644
--- a/tests/python/frontend/onnx/test_forward.py
+++ b/tests/python/frontend/onnx/test_forward.py
@@ -3897,6 +3897,7 @@ def verify_rnn(
     atol=1e-5,
     target=None,
     dev=None,
+    use_sequence_lens=False,
 ):
     """verify_rnn"""
     if rnn_type == "RNN":
@@ -3954,10 +3955,16 @@ def register(np_arr, name, shape=None):
         )
         register(b_np, "B")
 
+    if use_sequence_lens:
+        sequence_np = np.random.uniform(0, seq_length, size=(batch_size,)).astype("int32")
+        register(sequence_np, "sequence_lens")
+
     if use_initial_state:
         assert use_bias is True, "Initial states must have bias specified."
-        sequence_np = np.repeat(seq_length, batch_size).astype("int32")
-        register(sequence_np, "sequence_lens")
+
+        if not use_sequence_lens:
+            sequence_np = np.repeat(seq_length, batch_size).astype("int32")
+            register(sequence_np, "sequence_lens")
 
         if layout == 1:
             initial_h_np = np.random.uniform(size=(batch_size, directions, hidden_size)).astype(
@@ -4211,6 +4218,35 @@ def verify_rnn_helper(target, dev, rnn_type):
         #     dev=dev,
         # )
 
+        # Testing with sequence_lens
+        if rnn_type == "GRU":
+            verify_rnn(
+                seq_length=2,
+                batch_size=1,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+            verify_rnn(
+                seq_length=8,
+                batch_size=8,
+                input_size=16,
+                hidden_size=32,
+                use_bias=True,
+                use_initial_state=True,
+                rnn_type=rnn_type,
+                directions=directions,
+                target=target,
+                dev=dev,
+                use_sequence_lens=True,
+            )
+
         # Testing with peepholes
         if rnn_type == "LSTM":
             verify_rnn(
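One property of the new use_sequence_lens input is worth keeping in mind: np.random.uniform(0, seq_length) samples from the half-open interval [0, seq_length), so after truncation to int32 a length of 0 can be drawn and the full seq_length never is. A small sketch of what the branch above registers (the printed values are illustrative):

import numpy as np

seq_length, batch_size = 8, 8
sequence_np = np.random.uniform(0, seq_length, size=(batch_size,)).astype("int32")
print(sequence_np)  # e.g. [3 0 6 2 7 1 4 5]: lengths in 0..seq_length - 1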