Fix LSTM and GRU layers gradient calculations (apache#18203)
* Fix input gradient calculation for bidirectional LSTM

For bidirectional LSTM with number of layers > 2, the input gradient calculation was incorrect.
The wrong results came from overwriting the y derivative (dy) tensor with the
calculated x derivative (dx) tensor before the right2left layer could use dy for its own
gradient calculation.
The proposed fix uses additional workspace to avoid the overwrite.
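A rough illustration of the idea, as a standalone sketch with made-up names (BackwardOverLayers, scratch) rather than the MXNet code itself: the backward pass walks the layers from top to bottom and alternates between two dy-sized scratch buffers, so the dx produced for one layer never overwrites the dy that the layer below still has to read.

```cpp
#include <cstddef>

// Sketch only: two scratch buffers of y_size elements sit back to back in the
// workspace. Odd/even iterations alternate between them, so the dx written for
// the current layer never lands in the buffer still holding the dy it reads.
void BackwardOverLayers(int num_layers, std::size_t y_size, float* scratch) {
  float* dy_tmp_ptr = scratch;  // first of the two temporary dy/dx buffers
  for (int i = num_layers - 1; i >= 0; --i) {
    // ... a per-layer backward would write this layer's dx at dy_tmp_ptr here;
    //     the dy it reads (the caller's dy or the other buffer) stays untouched ...
    const int loop_iteration = (num_layers - 1) - i;
    dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size   // back to buffer 0
                                    : dy_tmp_ptr + y_size;  // over to buffer 1
  }
}
```

The actual change in LstmBackward below follows the same alternation via dy_tmp_ptr and y_size.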

* Fix gradient calculation for GRU

For GRU with number of layers > 2, the i2h_weight gradient for
the middle layers (all except the first and the last) was incorrect.
The wrong calculations were caused by assigning the output pointer to
the input instead of computing a new input pointer.
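The pointer arithmetic at fault can be sketched as follows. This is illustrative only: it assumes each layer's output occupies one contiguous block of T * N * H * D elements, which is not necessarily MXNet's exact buffer layout, and the helper name is hypothetical.

```cpp
#include <cstddef>

// Sketch only: assume each layer's output is stored as one contiguous block of
// block_size = T * N * H * D elements, stacked layer after layer in y_blocks.
// The input that feeds layer l (l >= 1) is the output of layer l - 1.
const float* MiddleLayerInput(const float* y_blocks, std::size_t block_size, int layer) {
  // Correct: step back one block to the previous layer's output.
  return y_blocks + (layer - 1) * block_size;
  // The bug amounted to handing the layer its own output block
  // (y_blocks + layer * block_size) as input, which corrupts the
  // i2h_weight gradient of every middle layer.
}
```

The one-line change in GruBackward below replaces y_tmp = y_l with y_tmp = y_tmp - T * N * H * D so that the input pointer stays one block behind the current layer's output.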

* Enable tests for GRU and LSTM gradients

* Fix comments

* Change loop iteration deduction

* Add more test cases for fused rnn layers
bgawrych authored and AntiZpvoh committed Jul 6, 2020
1 parent f603590 commit 44200fd
Showing 3 changed files with 39 additions and 62 deletions.
src/operator/rnn-inl.h (4 changes: 3 additions & 1 deletion)
@@ -195,7 +195,9 @@ inline size_t GetRNNWorkspaceSize(index_t seq_length,
     case rnn_enum::kLstm:
       size = seq_length * batch_size * hidden_size * (4 + direction) +  // wx*x + inter-y
              batch_size * hidden_size * 6 +                             // wh*h + h + c
-             seq_length * hidden_size * 8;  // Used in Backward, Δbx, Δbh
+             seq_length * hidden_size * 8 +  // Used in Backward, Δbx, Δbh
+             // temporary dy in backward computation for bidirectional layers
+             seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0);
       break;
     case rnn_enum::kGru:
       // Differs with Lstm, the outputs of three gates are also held in memory
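The added term reserves space only for bidirectional runs: with direction == 1 the expression (direction - 1 ? direction : 0) evaluates to 0, while with direction == 2 it evaluates to 2, i.e. one full dy tensor of seq_length * batch_size * hidden_size * 2 elements. A small standalone check of that arithmetic (the helper name is made up):

```cpp
#include <cstddef>
#include <iostream>

// Mirrors the added workspace term: only bidirectional runs (direction == 2)
// reserve the extra temporary-dy space.
std::size_t ExtraDySpace(std::size_t seq_length, std::size_t batch_size,
                         std::size_t hidden_size, std::size_t direction) {
  return seq_length * batch_size * hidden_size * (direction - 1 ? direction : 0);
}

int main() {
  std::cout << ExtraDySpace(10, 4, 8, 1) << "\n";  // 0   (unidirectional)
  std::cout << ExtraDySpace(10, 4, 8, 2) << "\n";  // 640 (= 10 * 4 * 8 * 2)
  return 0;
}
```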
src/operator/rnn_impl.h (7 changes: 6 additions & 1 deletion)
@@ -564,6 +564,7 @@ void LstmBackward(DType* ws,
   const index_t w_size1 = (I + H) * H * 4;      // first layer
   const index_t w_size2 = (D * H + H) * H * 4;  // other layers
   const index_t cell_size = N * H;
+  const index_t y_size = T * N * H * D;
   DType* dy_tmp_ptr = ws2 + T * cell_size * 4 + cell_size * 3;
   for (int i = L - 1; i >= 0; --i) {
     const index_t input_size = i ? H * D : I;
@@ -594,6 +595,10 @@
           x, hx[idx], cx[idx], y, dy, dx, dhx[idx], dcx[idx],
           dhy_cur_ptr, dcy_cur_ptr, w_cur_ptr, dw_cur_ptr, db_cur_ptr,
           req_data, req_params, req_state, req_statecell);
+
+      // Prevent overwriting dy while calculating dx in left2right layer
+      const int loop_iteration = (L - 1) - i;
+      dy_tmp_ptr = loop_iteration % 2 ? dy_tmp_ptr - y_size : dy_tmp_ptr + y_size;
     }
     if (dropout > 0.0f && i > 0 && req_data != kNullOp) {
       dropout_random = dropout_random - T * N * D * H;
@@ -1507,7 +1512,7 @@ void GruBackward(DType* ws,
     if (dhy_l)
       dhy_l = dhy_l - D * N * H;
     y_l = y_l - T * N * H * D;
-    y_tmp = y_l;
+    y_tmp = y_tmp - T * N * H * D;
     if (l == 1) {
       wx_l = wx_l - (inputsize + H) * H * 3 * D;
       wh_l = wx_l + inputsize * 3 * H;
tests/python/unittest/test_gluon_rnn.py (90 changes: 30 additions & 60 deletions)
@@ -714,15 +714,10 @@ def check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size,
     stack_input_grad = sx.grad.asnumpy()

     assert_allclose(fused_out.asnumpy(), stack_out.asnumpy(), rtol=rtol, atol=atol)
-    if mx.context.current_context().device_type == 'cpu' and \
-            not mx.runtime.Features().is_enabled('MKLDNN') and \
-            'rnn' not in fused_layer.prefix:
-        print("LSTM and GRU on native CPU give wrong gradients. "
-              "Tracking issue: https://github.com/apache/incubator-mxnet/issues/17898.")
-    else:
-        assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
-        for key, value in fused_grads.items():
-            assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)
+    assert_allclose(fused_input_grad, stack_input_grad, rtol=rtol, atol=atol)
+    for key, value in fused_grads.items():
+        assert_allclose(value.asnumpy(), stack_grads[key].asnumpy(), rtol=rtol, atol=atol)

     num_layers = fused_begin_state[0].shape[0] // (2 if bidirectional else 1)
     check_rnn_states(fused_states, stack_states, num_layers, bidirectional, len(fused_begin_state) == 2)

@@ -748,61 +743,32 @@ def create_op_by_mode(mode):
     return fused_op, stack_op, recurrent_block_prefix


-def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_unidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
-
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-    stack_layer.initialize()
-
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=False, prefix=recurrent_block_prefix)
     fused_layer.initialize()

     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(stack_op(hidden_size, prefix='l0_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l1_'))
-        stack_layer.add(stack_op(hidden_size, prefix='l2_'))
+        for n in range(num_layers):
+            stack_layer.add(stack_op(hidden_size, prefix=f'l{n}_'))
     stack_layer.initialize()

     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size)


-def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
+def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, num_layers, loss):
     fused_op, stack_op, recurrent_block_prefix = create_op_by_mode(mode)
-    # ==== Single layer ====
-    fused_layer = fused_op(hidden_size, num_layers=1, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
-    fused_layer.initialize()
-
-    stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
-    with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-    stack_layer.initialize()
-
-    check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)
-
-    # ==== Multiple layer ====
-    fused_layer = fused_op(hidden_size, num_layers=3, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
+    fused_layer = fused_op(hidden_size, num_layers=num_layers, layout='NTC', bidirectional=True, prefix=recurrent_block_prefix)
     fused_layer.initialize()

     stack_layer = mx.gluon.rnn.HybridSequentialRNNCell(prefix=recurrent_block_prefix)
     with stack_layer.name_scope():
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l0_'),
-                                                    stack_op(hidden_size, prefix='r0_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l1_'),
-                                                    stack_op(hidden_size, prefix='r1_')))
-        stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix='l2_'),
-                                                    stack_op(hidden_size, prefix='r2_')))
-    stack_layer.initialize()
-
+        for n in range(num_layers):
+            stack_layer.add(gluon.rnn.BidirectionalCell(stack_op(hidden_size, prefix=f'l{n}_'),
+                                                        stack_op(hidden_size, prefix=f'r{n}_')))
+    stack_layer.initialize()
     check_rnn_consistency(fused_layer, stack_layer, loss, input_size, hidden_size, bidirectional=True)


@@ -811,43 +777,47 @@ def check_rnn_bidir_layer_gradients(mode, input_size, hidden_size, loss):
 def test_fused_lstm_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('lstm', input_size, hidden_size, num_layers, loss)


 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_fused_gru_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('gru', input_size, hidden_size, num_layers, loss)


 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_fused_rnnrelu_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_relu', input_size, hidden_size, num_layers, loss)


 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_fused_rnntanh_layer():
     input_sizes = [8]
     hidden_sizes = [8, 16]
-    for input_size, hidden_size in product(input_sizes, hidden_sizes):
+    num_layers = [1, 2, 3, 4]
+    for input_size, hidden_size, num_layers in product(input_sizes, hidden_sizes, num_layers):
         loss = mx.gluon.loss.L2Loss()
-        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
-        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, loss)
+        check_rnn_unidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)
+        check_rnn_bidir_layer_gradients('rnn_tanh', input_size, hidden_size, num_layers, loss)


 @pytest.mark.serial
