Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
test for the new infer storage
Browse files Browse the repository at this point in the history
  • Loading branch information
Hao Jin committed Jul 10, 2018
1 parent 198cae3 commit 9a5e2c1
Showing 1 changed file with 21 additions and 1 deletion.
22 changes: 21 additions & 1 deletion tests/python/mkl/test_mkldnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import os
import numpy as np
import mxnet as mx
from mxnet.test_utils import assert_almost_equal
from mxnet.test_utils import rand_ndarray, assert_almost_equal
from mxnet import gluon
from mxnet.gluon import nn
from mxnet.test_utils import *
Expand Down Expand Up @@ -240,5 +240,25 @@ def check_batchnorm_training(stype):
for stype in stypes:
check_batchnorm_training(stype)


@with_seed()
def test_fullyconnected():
def check_fullyconnected_training(stype):
data_shape = rand_shape_nd(2)
weight_shape = rand_shape_nd(2)
weight_shape = (weight_shape[0], data_shape[1])
for density in [1.0, 0.5, 0.0]:
x = rand_ndarray(shape=data_shape, stype=stype, density=density)
w = rand_ndarray(shape=weight_shape, stype=stype, density=density)
x_sym = mx.sym.Variable("data")
w_sym = mx.sym.Variable("weight")
sym = mx.sym.FullyConnected(data=x_sym, weight=w_sym, num_hidden=weight_shape[0], no_bias=True)
in_location = [x, w]
check_numeric_gradient(sym, in_location, numeric_eps=1e-3, rtol=1e-3, atol=5e-3)
stypes = ['row_sparse', 'default']
for stype in stypes:
check_fullyconnected_training(stype)


if __name__ == '__main__':
test_mkldnn_install()

0 comments on commit 9a5e2c1

Please sign in to comment.