This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Fix LSTM and GRU layers gradient calculations #18203

Merged: 6 commits, May 14, 2020
4 changes: 3 additions & 1 deletion src/operator/rnn-inl.h
@@ -195,7 +195,9 @@ inline size_t GetRNNWorkspaceSize(index_t seq_length,
case rnn_enum::kLstm:
size = seq_length * batch_size * hidden_size * (4 + direction) + // wx*x + inter-y
batch_size * hidden_size * 6 + // wh*h + h + c
seq_length * hidden_size * 8; // Used in Backward, Δbx, Δbh
seq_length * hidden_size * 8 + // Used in Backward, Δbx, Δbh
// temporary dy in backward computation for bidirectional layers
seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0);
break;
case rnn_enum::kGru:
// Differs with Lstm, the outputs of three gates are also held in memory
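For readers skimming the diff, here is a minimal standalone sketch of the updated kLstm workspace formula, showing that the extra temporary-dy region is reserved only for bidirectional layers. The free-standing function LstmWorkspaceSizeSketch is hypothetical; only the arithmetic mirrors the hunk above.

```cpp
#include <cstddef>

// Illustrative only: mirrors the kLstm branch of GetRNNWorkspaceSize after
// this change. Variable names follow the diff; the function is hypothetical.
std::size_t LstmWorkspaceSizeSketch(std::size_t seq_length, std::size_t batch_size,
                                    std::size_t hidden_size, int direction) {
  std::size_t size =
      seq_length * batch_size * hidden_size * (4 + direction) +  // wx*x + inter-y
      batch_size * hidden_size * 6 +                             // wh*h + h + c
      seq_length * hidden_size * 8;                              // backward: dbx, dbh
  // New term: temporary dy used in the backward pass of bidirectional layers.
  // (direction - 1 ? direction : 0) is 2 when direction == 2 and 0 when
  // direction == 1, so unidirectional layers pay nothing extra.
  size += seq_length * batch_size * hidden_size *
          static_cast<std::size_t>(direction - 1 ? direction : 0);
  return size;
}
```

With direction == 2 the new term reserves one extra seq_length * batch_size * hidden_size block per direction; with direction == 1 it adds nothing.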
7 changes: 6 additions & 1 deletion src/operator/rnn_impl.h
@@ -564,6 +564,7 @@ void LstmBackward(DType* ws,
const index_t w_size1 = (I + H) * H * 4; // first layer
const index_t w_size2 = (D * H + H) * H * 4; // other layers
const index_t cell_size = N * H;
const index_t y_size = T * N * H * D;
DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3;
for (int i = L - 1; i >= 0; --i) {
const index_t input_size = i ? H * D : I;
@@ -594,6 +595,10 @@ void LstmBackward(DType* ws,
x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
req_data, req_params, req_state, req_statecell);

// Prevent overwritting dy while calculating dx in left2right layer
const int loop_iteration = (L - 1) - i;
Reviewer comment (Contributor): Can you add an additional case with an even number of layers in the UT? (The current UT only covers 1 and 3 layers.)

dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size : dy_tmp_ptr + y_size;
}
if (dropout > 0.0f && i > 0 && req_data != kNullOp) {
dropout_random = dropout_random - T * N * D * H;
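
The lines added above (loop_iteration and the dy_tmp_ptr toggle) implement a double-buffer ("ping-pong") scheme: dx of layer i becomes dy of layer i-1, so it must be written into a scratch region other than the one currently being read. Below is a self-contained toy of that idea; the sizes and the printf are made up for illustration, and only the toggle expression is taken from the diff.

```cpp
#include <cstdio>

int main() {
  const int L = 4;        // hypothetical layer count
  const long y_size = 6;  // stands in for T * N * H * D
  float scratch[2 * y_size] = {0.0f};  // two adjacent y-sized regions

  float* dy_tmp_ptr = scratch;  // region the current layer writes dx into
  for (int i = L - 1; i >= 0; --i) {
    // ... LstmBackwardSingleLayer would write dx (the next layer's dy) here ...
    std::printf("layer %d writes at offset %ld\n", i,
                static_cast<long>(dy_tmp_ptr - scratch));
    // Toggle taken from the diff: even iterations step forward by y_size,
    // odd iterations step back, so reads and writes never share a region.
    const int loop_iteration = (L - 1) - i;
    dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size
                                    : dy_tmp_ptr + y_size;
  }
  return 0;
}
```

Because the region in use depends on the parity of the layer index, even layer counts exercise a different final buffer than odd ones, which is presumably why the reviewer asks above for an even-layer case in the unit tests.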
@@ -1507,7 +1512,7 @@ void GruBackward(DType* ws,
if (dhy_l)
dhy_l = dhy_l - D * N * H;
y_l = y_l - T * N * H * D;
y_tmp = y_l;
y_tmp = y_tmp - T * N * H * D;
if (l == 1) {
wx_l = wx_l - (inputsize + H) * H * 3 * D;
wh_l = wx_l + inputsize * 3 * H;
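The GruBackward change above is a single line: y_tmp now steps backward by one layer's output size together with y_l instead of being re-pointed at y_l. A small offset-only sketch of how the two pointers now move in lockstep from the top layer downward; the layer count and dimensions are hypothetical, and only the two decrement statements mirror the diff.

```cpp
#include <cstdio>

int main() {
  const int L = 3, T = 2, N = 2, H = 2, D = 2;   // hypothetical sizes
  const long layer_y = static_cast<long>(T) * N * H * D;

  // Offsets stand in for the y_l and y_tmp pointers; each is assumed to index
  // into its own buffer holding one T*N*H*D block per layer, starting at the
  // topmost layer.
  long y_l_off   = (L - 1) * layer_y;
  long y_tmp_off = (L - 1) * layer_y;
  for (int l = L - 1; l >= 1; --l) {
    // ... backward pass for layer l would run here ...
    y_l_off   -= layer_y;   // y_l   = y_l   - T * N * H * D;
    y_tmp_off -= layer_y;   // y_tmp = y_tmp - T * N * H * D;  (was: y_tmp = y_l)
    std::printf("entering layer %d: y_l block %ld, y_tmp block %ld\n",
                l - 1, y_l_off / layer_y, y_tmp_off / layer_y);
  }
  return 0;
}
```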
90 changes: 30 additions & 60 deletions tests/python/unittest/test_gluon_rnn.py
@@ -714,15 +714,10 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_siz
stack_input_grad = sx.grad.asnumpy()

assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol)
if mx.context.current_context().device_type == 'cpu' and \
not mx.runtime.Features().is_enabled('MKLDNN') and \
'rnn' not in fused_layer.prefix:
print("LSTM and GRU on native CPU give wrong gradients. "
"Tracking issue: https://github.com/apache/incubator-mxnet/issues/17898.")
else:
assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
for key, value in fused_grads.items():
assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)
assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
for key, value in fused_grads.items():
assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)

num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1)
check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2)

@@ -748,61 +743,32 @@ def create_op_by_mode(mode):
return fused_op, stack_op, recurrent_block_prefix


def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
# ==== Single layer ====
fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
fused_layer.initialize()

stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
with stack_layer.name_scope():
stack_layer.add(stack_op(hidden_size, prefix='l0_'))
stack_layer.initialize()

check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)

# ==== Multiple layer ====
fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
fused_layer.initialize()

stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
with stack_layer.name_scope():
stack_layer.add(stack_op(hidden_size, prefix='l0_'))
stack_layer.add(stack_op(hidden_size, prefix='l1_'))
stack_layer.add(stack_op(hidden_size, prefix='l2_'))
for n in range(num_layers):
stack_layer.add(stack_op(hidden_size, prefix=f'l{n}_'))
stack_layer.initialize()

check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)


def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
# ==== Single layer ====
fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
fused_layer.initialize()

stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
with stack_layer.name_scope():
stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
stack_op(hidden_size, prefix='r0_')))
stack_layer.initialize()

check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)

# ==== Multiple layer ====
fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
fused_layer.initialize()

stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
with stack_layer.name_scope():
stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
stack_op(hidden_size, prefix='r0_')))
stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'),
stack_op(hidden_size, prefix='r1_')))
stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'),
stack_op(hidden_size, prefix='r2_')))
stack_layer.initialize()

for n in range(num_layers):
stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix=f'l{n}_'),
stack_op(hidden_size, prefix=f'r{n}_')))
stack_layer.initialize()
check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)


@@ -811,43 +777,47 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
def test_fused_lstm_layer():
input_sizes = [8]
hidden_sizes = [8, 16]
for input_size, hidden_size in product(input_sizes, hidden_sizes):
num_layers = [1, 2, 3, 4]
for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
loss = mx.gluon.loss.L2Loss()
check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss)
check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss)
check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_fused_gru_layer():
input_sizes = [8]
hidden_sizes = [8, 16]
for input_size, hidden_size in product(input_sizes, hidden_sizes):
num_layers = [1, 2, 3, 4]
for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
loss = mx.gluon.loss.L2Loss()
check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss)
check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss)
check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_fused_rnnrelu_layer():
input_sizes = [8]
hidden_sizes = [8, 16]
for input_size, hidden_size in product(input_sizes, hidden_sizes):
num_layers = [1, 2, 3, 4]
for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
loss = mx.gluon.loss.L2Loss()
check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)


@with_seed()
@assert_raises_cudnn_not_satisfied(min_version='5.1.10')
def test_fused_rnntanh_layer():
input_sizes = [8]
hidden_sizes = [8, 16]
for input_size, hidden_size in product(input_sizes, hidden_sizes):
num_layers = [1, 2, 3, 4]
for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
loss = mx.gluon.loss.L2Loss()
check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)


@pytest.mark.serial