Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

[MXNET-771] Fix Flaky Test test_executor.py:test_dot #11978

Merged
merged 4 commits into from
Aug 2, 2018
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 11 additions & 15 deletions tests/python/unittest/test_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,7 @@
import numpy as np
import mxnet as mx
from common import setup_module, with_seed, teardown


def reldiff(a, b):
diff = np.sum(np.abs(a - b))
norm = np.sum(np.abs(a))
reldiff = diff / norm
return reldiff
from mxnet.test_utils import assert_almost_equal


def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
Expand Down Expand Up @@ -64,18 +58,18 @@ def check_bind_with_uniform(uf, gf, dim, sf=None, lshape=None, rshape=None):
out1 = uf(lhs_arr.asnumpy(), rhs_arr.asnumpy())
out3 = exec3.outputs[0].asnumpy()
out4 = exec4.outputs[0].asnumpy()
assert reldiff(out1, out2) < 1e-6
assert reldiff(out1, out3) < 1e-6
assert reldiff(out1, out4) < 1e-6
assert_almost_equal(out1, out2, rtol=1e-5, atol=1e-5)
assert_almost_equal(out1, out3, rtol=1e-5, atol=1e-5)
assert_almost_equal(out1, out4, rtol=1e-5, atol=1e-5)
# test gradient
out_grad = mx.nd.array(np.ones(out2.shape))
lhs_grad2, rhs_grad2 = gf(out_grad.asnumpy(),
lhs_arr.asnumpy(),
rhs_arr.asnumpy())
executor.backward([out_grad])

assert reldiff(lhs_grad.asnumpy(), lhs_grad2) < 1e-6
assert reldiff(rhs_grad.asnumpy(), rhs_grad2) < 1e-6
assert_almost_equal(lhs_grad.asnumpy(), lhs_grad2, rtol=1e-5, atol=1e-5)
assert_almost_equal(rhs_grad.asnumpy(), rhs_grad2, rtol=1e-5, atol=1e-5)


@with_seed(0)
Expand Down Expand Up @@ -118,20 +112,22 @@ def check_bind(disable_bulk_exec):
check_bind(False)


@with_seed(0)
# @roywei: Removing fixed seed as flakiness in this test is fixed
# tracked at https://github.com/apache/incubator-mxnet/issues/11686
@with_seed()
def test_dot():
nrepeat = 10
maxdim = 4
for repeat in range(nrepeat):
s =tuple(np.random.randint(1, 500, size=3))
s =tuple(np.random.randint(1, 200, size=3))
check_bind_with_uniform(lambda x, y: np.dot(x, y),
lambda g, x, y: (np.dot(g, y.T), np.dot(x.T, g)),
2,
lshape=(s[0], s[1]),
rshape=(s[1], s[2]),
sf = mx.symbol.dot)
for repeat in range(nrepeat):
s =tuple(np.random.randint(1, 500, size=1))
s =tuple(np.random.randint(1, 200, size=1))
check_bind_with_uniform(lambda x, y: np.dot(x, y),
lambda g, x, y: (g * y, g * x),
2,
Expand Down