
Revert test_rnnrelu_sym to origin
zixuanweeei committed Oct 22, 2019
1 parent d9f4025 commit 843853e
Showing 1 changed file with 10 additions and 16 deletions.
tests/python/unittest/test_operator.py (26 changes: 10 additions & 16 deletions)
@@ -258,23 +258,17 @@ def test_rnntanh_bidirectional():
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
 def test_rnnrelu_sym():
-    if default_context().device_type == 'gpu':
-        print("Skip test `rnn_relu_sym` on gpu. This is tracked by https://github.com/apache/incubator-mxnet/issues/16548")
-        return
-    Ts = [1, 5]
-    Ns = [1, 32]
-    Is = [32, 128, 512]
-    Hs = [32, 128, 512]
-    for T, N, I, H in itertools.product(Ts, Ns, Is, Hs):
-        fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
-        stack = mx.rnn.SequentialRNNCell()
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
-        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+    T, N, I, H = 5, 32, 200, 200
 
-        check_rnn_consistency(fused, stack, T, N, I, H, 'write')
-        check_rnn_consistency(fused, stack, T, N, I, H, 'add')
-        check_rnn_consistency(fused, stack, T, N, I, H, 'null')
+    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='')
+    stack = mx.rnn.SequentialRNNCell()
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_'))
+    stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_'))
+
+    check_rnn_consistency(fused, stack, T, N, I, H, 'write')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'add')
+    check_rnn_consistency(fused, stack, T, N, I, H, 'null')
 
 @with_seed()
 @assert_raises_cudnn_not_satisfied(min_version='5.1.10')
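For readers without the surrounding file: the reverted test compares a fused 3-layer ReLU RNN (mx.rnn.FusedRNNCell) against an equivalent stack of three mx.rnn.RNNCell layers. The revert drops the parameterized sweep over T/N/I/H and the GPU skip (tracked in issue 16548), restoring the single fixed configuration T, N, I, H = 5, 32, 200, 200. The sketch below is a minimal, hypothetical illustration of the two symbolic graphs the test builds, assuming the MXNet 1.x mx.rnn API; check_rnn_consistency, defined elsewhere in test_operator.py and not shown in this diff, is what actually binds both graphs with shared weights and compares outputs and gradients under the 'write', 'add', and 'null' gradient request modes.

    import mxnet as mx

    # Minimal sketch (not the test itself): build the fused and stacked graphs
    # from the reverted test and confirm that both produce outputs of the same
    # shape. The value/gradient comparison is done by check_rnn_consistency,
    # which is defined elsewhere in test_operator.py.
    T, N, I, H = 5, 32, 200, 200

    fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu',
                                get_next_state=True, prefix='')
    stack = mx.rnn.SequentialRNNCell()
    for i in range(3):
        stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l%d_' % i))

    data = mx.sym.Variable('data')  # assumed layout NTC: (batch, time, input)
    out_fused, _ = fused.unroll(T, data, layout='NTC', merge_outputs=True)
    out_stack, _ = stack.unroll(T, data, layout='NTC', merge_outputs=True)

    _, fused_out, _ = out_fused.infer_shape(data=(N, T, I))
    _, stack_out, _ = out_stack.infer_shape(data=(N, T, I))
    assert fused_out[0] == stack_out[0] == (N, T, H)

The prefixes ('' for the fused cell, 'l0_'/'l1_'/'l2_' for the stacked cells) come from the test itself and are what lets the consistency check map the per-layer weights onto the fused parameter blob.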
