From 843853e3472e6d2b27963d04452702d97155befb Mon Sep 17 00:00:00 2001 From: Zixuan Wei Date: Tue, 22 Oct 2019 22:19:13 +0800 Subject: [PATCH] Revert test_rnnrelu_sym to origin --- tests/python/unittest/test_operator.py | 26 ++++++++++---------------- 1 file changed, 10 insertions(+), 16 deletions(-) diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index aae82086eb65..7ea106b2620f 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -258,23 +258,17 @@ def test_rnntanh_bidirectional(): @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10') def test_rnnrelu_sym(): - if default_context().device_type == 'gpu': - print("Skip test `rnn_relu_sym` on gpu. This is tracked by https://github.com/apache/incubator-mxnet/issues/16548") - return - Ts = [1, 5] - Ns = [1, 32] - Is = [32, 128, 512] - Hs = [32, 128, 512] - for T, N, I, H in itertools.product(Ts, Ns, Is, Hs): - fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='') - stack = mx.rnn.SequentialRNNCell() - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_')) - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_')) - stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_')) + T, N, I, H = 5, 32, 200, 200 - check_rnn_consistency(fused, stack, T, N, I, H, 'write') - check_rnn_consistency(fused, stack, T, N, I, H, 'add') - check_rnn_consistency(fused, stack, T, N, I, H, 'null') + fused = mx.rnn.FusedRNNCell(H, num_layers=3, mode='rnn_relu', get_next_state=True, prefix='') + stack = mx.rnn.SequentialRNNCell() + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l0_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l1_')) + stack.add(mx.rnn.RNNCell(H, activation='relu', prefix='l2_')) + + check_rnn_consistency(fused, stack, T, N, I, H, 'write') + check_rnn_consistency(fused, stack, T, N, I, H, 'add') + check_rnn_consistency(fused, stack, T, N, I, H, 'null') @with_seed() @assert_raises_cudnn_not_satisfied(min_version='5.1.10')