Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Frontend] [Torch] [ONNX] GRU layer #8781

Merged
merged 9 commits into from
Aug 25, 2021
84 changes: 84 additions & 0 deletions python/tvm/relay/frontend/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -658,6 +658,90 @@ def unbind(data, axis=0):
return _expr.TupleWrapper(_expr.Tuple(ret), selections)


def gru_cell(
input_seqs,
hidden_state,
w_inp,
w_hid,
b_inp=None,
b_hid=None,
rz_act=_op.sigmoid,
n_act=_op.tanh,
backwards=False,
linear_before_reset=True,
):
"""
Common implementation of GRU cell for all frontends of TVM
TODO(vvchernov): currently it is used by pytorch and ONNX. Extend for other frontends

Parameters
----------
input_seqs : List[relay.Expr]
The sequence of input tensors
Input tensor should be 2d while issue #8412 is not resolved
Shape = (batch, feature_size)
hidden_state : relay.Expr
Hidden state. shape = (batch_size, hidden_size)
w_inp, w_hid : relay.Expr
weight matrices. wi shape = (3 * hidden_size, feature_size)
wh shape = (3 * hidden_size, hidden_size)
NOTE: wi = (w_ir|w_iz|w_in) for reset, update and new gates.
The order is important for correct GRU calculation!
b_inp, b_hid : relay.Expr
bias matrices. The same order of internal parts as for weights. shape = (3 * hidden_size)
r_act : relay.op
activation funtion for reset gate. it is sigmoid by default
z_act : relay.op
activation funtion for update gate. it is sigmoid by default
n_act : relay.op
activation funtion for new gate. it is tanh by default
backwards : bool
Flag for reverse pass of GRU

Returns
-------
result : List[relay.Expr], relay.Expr, relay.Expr
The sequence of computed result, final hidden and cell state
"""

outputs_list = []
for x_t in input_seqs if not backwards else reversed(input_seqs):
xwt = _op.nn.dense(x_t, w_inp)
if linear_before_reset:
hwt = _op.nn.dense(hidden_state, w_hid)
if b_inp is not None and b_hid is not None:
xwt += b_inp
hwt += b_hid
i_r, i_z, i_n = _op.split(xwt, 3, axis=-1)
h_r, h_z, h_n = _op.split(hwt, 3, axis=-1)
r_gate = rz_act(i_r + h_r)
z_gate = rz_act(i_z + h_z)
n_gate = n_act(i_n + r_gate * h_n)
else:
i_r, i_z, i_n = _op.split(xwt, 3, axis=1)
w_hr, w_hz, w_hn = _op.split(w_hid, 3, axis=0)
r_gate = i_r + _op.nn.dense(hidden_state, w_hr)
z_gate = i_z + _op.nn.dense(hidden_state, w_hz)
if b_inp is not None and b_hid is not None:
b_ir, b_iz, b_in = _op.split(b_inp, 3, axis=-1)
b_hr, b_hz, b_hn = _op.split(b_hid, 3, axis=-1)
r_gate += b_ir + b_hr
z_gate += b_iz + b_hz
i_n += b_in
h_n = _op.nn.dense((r_gate * hidden_state), w_hn) + b_hn
else:
h_n = _op.nn.dense((r_gate * hidden_state), w_hn)
r_gate = rz_act(r_gate)
z_gate = rz_act(z_gate)
n_gate = n_act(i_n + h_n)

hidden_state = (hidden_state - n_gate) * z_gate + n_gate

outputs_list.append(hidden_state) # [seq_num, (batch, hidden_size)]

return outputs_list, hidden_state


def lstm_cell(
input_seqs,
hidden_state,
Expand Down
149 changes: 72 additions & 77 deletions python/tvm/relay/frontend/onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
infer_value,
new_var,
unbind,
gru_cell,
lstm_cell,
)

Expand Down Expand Up @@ -2349,56 +2350,41 @@ class GRU(RNN):
"""Operator convert for GRU"""

@classmethod
def generate_gru(
cls, X_steps, H_t, W, R, B, linear_before_reset, f_act, g_act, W_dtype, backwards=False
def bidir_gru_cell(
cls,
input_seqs,
weight_dicts,
acts,
):
"""Create an unrolled gru loop.

See https://github.com/onnx/onnx/blob/master/docs/Operators.md for math.
"""
h_list = []
seq_length = len(X_steps)
for i in range(seq_length):
step = X_steps[i] if not backwards else X_steps[seq_length - (i + 1)]
step = _op.squeeze(step, axis=[0])
current = _op.nn.dense(step, W)
cz, cr, ch = _op.split(current, 3, axis=1)
rz, rr, rh = _op.split(R, 3, axis=0)
z = cz + _op.nn.dense(H_t, rz)
r = cr + _op.nn.dense(H_t, rr)
if B is not None:
WB, RB = _op.split(B, 2)
wbz, wbr, wbh = _op.split(WB, 3, axis=-1)
rbz, rbr, rbh = _op.split(RB, 3, axis=-1)
z += wbz + rbz
r += wbr + rbr
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh) + rbh)) + wbh
else:
h = ch + _op.nn.dense((r * H_t), rh) + wbh + rbh
else:
if linear_before_reset:
h = ch + (r * (_op.nn.dense(H_t, rh)))
else:
h = ch + _op.nn.dense((r * H_t), rh)

z = f_act(z)
r = f_act(r)
h = g_act(h)

H_t = ((_expr.const(1, dtype=W_dtype) - z) * h) + (z * H_t)
h_list.append(_op.expand_dims(H_t, axis=0))
Bidirectional GRU cell
"""
seq_len = len(input_seqs)
forward_outputs, fw_H_t = gru_cell(
input_seqs,
**weight_dicts[0],
rz_act=acts[0],
n_act=acts[1],
)

if backwards:
# Canonical view is hidden states from the first token not last
h_list = h_list[::-1]
reverse_outputs, rev_H_t = gru_cell(
input_seqs,
**weight_dicts[1],
rz_act=acts[2],
n_act=acts[3],
backwards=True,
)

# Concatenate outputs and add back in direction axis.
concatenated = _op.concatenate(h_list, 0)
output = _op.expand_dims(concatenated, axis=1)
H_t = _op.expand_dims(H_t, axis=0)
final_outputs = []
for i in range(seq_len):
final_outputs.append(
_op.stack([forward_outputs[i], reverse_outputs[seq_len - 1 - i]], axis=0)
)

return output, H_t
return (
_op.stack(final_outputs, axis=0),
_op.stack([fw_H_t, rev_H_t], axis=0),
)

@classmethod
def _impl_v7(cls, inputs, attr, params):
Expand All @@ -2416,20 +2402,14 @@ def _impl_v7(cls, inputs, attr, params):
W_dtype = infer_type(Wp).checked_type.dtype

if num_directions not in [1, 2]:
raise NotImplementedError(
f"Directions for GRUs should be either 1 or 2 got {num_directions}"
)
raise ValueError("num_directions must be either 1 or 2!")

X_shape = infer_shape(X)
hidden_size = infer_shape(Rp)[-1]
batch_size = X_shape[1]

# Initialize state if not provided.
# Otherwise remove bidirectional axis.
if Hp_0 is None:
Hp_0 = _op.zeros((num_directions, batch_size, hidden_size), W_dtype)
if Bp is None:
Bp = _op.zeros((num_directions, hidden_size * 6), W_dtype)

if "activations" in attr:
activations = attr["activations"]
Expand Down Expand Up @@ -2460,39 +2440,54 @@ def _impl_v7(cls, inputs, attr, params):
else:
acts = [_op.sigmoid, _op.tanh] * 2

result_output = []
result_H = []
# TODO (vvchernov): It can be replaced by _op.split if issue #8412 is resolved
X_steps = unbind(X, axis=0)

X_steps = _op.split(X, indices_or_sections=X_shape[0], axis=0)
H_ts = _op.split(Hp_0, num_directions)
Ws = _op.split(Wp, num_directions)
Rs = _op.split(Rp, num_directions)
Bs = _op.split(Bp, num_directions)

if Bp is not None:
Bs = _op.split(Bp, num_directions)

weights_dicts = []
for i in range(num_directions):
H_t = _op.squeeze(H_ts[i], axis=[0])
W = _op.squeeze(Ws[i], axis=[0])
R = _op.squeeze(Rs[i], axis=[0])
B = _op.squeeze(Bs[i], axis=[0])
f_act, g_act = acts[i * 2 : (i + 1) * 2]
output, H = GRU.generate_gru(
X_steps=X_steps,
H_t=H_t,
W=W,
R=R,
B=B,
linear_before_reset=linear_before_reset,
f_act=f_act,
g_act=g_act,
W_dtype=W_dtype,
backwards=i == 1,
)
weights_dict = {}

weights_dict["hidden_state"] = _op.squeeze(H_ts[i], axis=[0])
weights_dict["linear_before_reset"] = linear_before_reset

# Weights permutation: onnx format i-o-f-c, lstm cell format i-f-c-o
matz, matr, matn = _op.split(_op.squeeze(Ws[i], axis=[0]), 3)
weights_dict["w_inp"] = _op.concatenate([matr, matz, matn], axis=0)
matz, matr, matn = _op.split(_op.squeeze(Rs[i], axis=[0]), 3)
weights_dict["w_hid"] = _op.concatenate([matr, matz, matn], axis=0)
if Bp is not None:
Bi, Bh = _op.split(Bs[i], 2, -1)
matz, matr, matn = _op.split(_op.squeeze(Bi, axis=[0]), 3)
weights_dict["b_inp"] = _op.concatenate([matr, matz, matn], axis=0)
matz, matr, matn = _op.split(_op.squeeze(Bh, axis=[0]), 3)
weights_dict["b_hid"] = _op.concatenate([matr, matz, matn], axis=0)
weights_dicts.append(weights_dict)

result_output.append(output)
result_H.append(H)
if num_directions == 2:
output, H = GRU.bidir_gru_cell(
input_seqs=X_steps,
weight_dicts=weights_dicts,
acts=acts,
)
else:
# outputs shape = [seqs_num, (batch_size, hidden_size)]
outputs, H = gru_cell(
input_seqs=X_steps,
**weights_dicts[0],
rz_act=acts[0],
n_act=acts[1],
)

output = _op.concatenate(result_output, axis=1)
H = _op.concatenate(result_H, axis=0)
# output shape = (seqs_num, num_directions, batch_size, hidden_size)
output = _op.expand_dims(_op.stack(outputs, axis=0), axis=1)
H = _op.expand_dims(H, axis=0)

return _expr.TupleWrapper(_expr.Tuple((output, H)), 2)

Expand Down
Loading