Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
enable test dropout with cudnn off
Browse files Browse the repository at this point in the history
  • Loading branch information
apeforest committed Jan 22, 2020
1 parent 3de5cae commit 109b5aa
Showing 1 changed file with 23 additions and 23 deletions.
46 changes: 23 additions & 23 deletions tests/python/unittest/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6955,8 +6955,9 @@ def test_stack():
check_numeric_gradient(out, inputs)


## TODO: test fails intermittently when cudnn on. temporarily disabled cudnn until gets fixed.
## tracked at https://github.com/apache/incubator-mxnet/issues/14288
@with_seed()
@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/14288")
def test_dropout():
def zero_count(array, ratio):
zeros = 0
Expand Down Expand Up @@ -7083,18 +7084,18 @@ def check_passthrough(ratio, shape, cudnn_off=True):
check_dropout_ratio(1.0, shape)
check_dropout_ratio(0.75, shape)
check_dropout_ratio(0.25, shape)
check_dropout_ratio(0.5, shape, cudnn_off=False)
check_dropout_ratio(0.0, shape, cudnn_off=False)
check_dropout_ratio(1.0, shape, cudnn_off=False)
check_dropout_ratio(0.75, shape, cudnn_off=False)
check_dropout_ratio(0.25, shape, cudnn_off=False)
# check_dropout_ratio(0.5, shape, cudnn_off=False)
# check_dropout_ratio(0.0, shape, cudnn_off=False)
# check_dropout_ratio(1.0, shape, cudnn_off=False)
# check_dropout_ratio(0.75, shape, cudnn_off=False)
# check_dropout_ratio(0.25, shape, cudnn_off=False)

check_passthrough(0.5, shape)
check_passthrough(0.0, shape)
check_passthrough(1.0, shape)
check_passthrough(0.5, shape, cudnn_off=False)
check_passthrough(0.0, shape, cudnn_off=False)
check_passthrough(1.0, shape, cudnn_off=False)
# check_passthrough(0.5, shape, cudnn_off=False)
# check_passthrough(0.0, shape, cudnn_off=False)
# check_passthrough(1.0, shape, cudnn_off=False)

nshape = (10, 10, 10, 10)
with mx.autograd.train_mode():
Expand All @@ -7111,20 +7112,19 @@ def check_passthrough(ratio, shape, cudnn_off=True):
check_dropout_axes(0.25, nshape, axes = (0, 1, 2))
check_dropout_axes(0.25, nshape, axes = (0, 2, 3))
check_dropout_axes(0.25, nshape, axes = (1, 2, 3))
check_dropout_axes(0.25, nshape, axes = (0,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (2,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (3,), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 1), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (2, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 1, 2), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (0, 2, 3), cudnn_off=False)
check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False)

# check_dropout_axes(0.25, nshape, axes = (0,), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (1,), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (2,), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (3,), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (0, 1), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (0, 2), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (0, 3), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (1, 2), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (1, 3), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (2, 3), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (0, 1, 2), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (0, 2, 3), cudnn_off=False)
# check_dropout_axes(0.25, nshape, axes = (1, 2, 3), cudnn_off=False)


@unittest.skip("test fails intermittently. temporarily disabled till it gets fixed. tracked at https://github.com/apache/incubator-mxnet/issues/11290")
Expand Down

0 comments on commit 109b5aa

Please sign in to comment.