Skip to content

Commit

Permalink
update test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Aug 4, 2018
1 parent ebf26a9 commit 371d2eb
Showing 1 changed file with 4 additions and 7 deletions.
11 changes: 4 additions & 7 deletions nnvm/tests/python/compiler/test_op_fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ def build_and_run(sym, params, data, out_shape, target, ctx, opt_level=2):
module.set_input("data", data)
module.run()
out = module.get_output(0, tvm.nd.empty(out_shape))
return out.asnumpy(), graph
return out.asnumpy()


def test_fuse_conv2d_elu():
Expand All @@ -112,12 +112,9 @@ def get_sym(out_channel):
sym2 = get_sym(out_channel)
_, params1 = utils.create_workload(sym1, 1, dshape[1:], seed=0)
_, params2 = utils.create_workload(sym2, 1, dshape[1:], seed=0)
print(params1.keys())
print(params2.keys())
print("Running on target", target)
output, g1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
output2, g2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
np.testing.assert_allclose(output, output2, rtol=1e-5, atol=1e-5)
output1 = build_and_run(sym1, params1, data, oshape, target, ctx, opt_level=2)
output2 = build_and_run(sym2, params2, data, oshape, target, ctx, opt_level=0)
np.testing.assert_allclose(output1, output2, rtol=1e-5, atol=1e-5)

if __name__ == "__main__":
test_injective_reduce_injective()
Expand Down

0 comments on commit 371d2eb

Please sign in to comment.