diff --git a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc index e52009d7add7..66378d74f5d7 100644 --- a/src/runtime/contrib/dnnl/dnnl_json_runtime.cc +++ b/src/runtime/contrib/dnnl/dnnl_json_runtime.cc @@ -338,7 +338,7 @@ class DNNLJSONRuntime : public JSONRuntimeBase { auto data_entry = node.GetInputs()[0]; dnnl::memory::dims shape = nodes_[data_entry.id_].GetOpShape()[data_entry.index_]; - auto data_md = dnnl::memory::desc{{shape}, dt::f32, tag::abcd}; + dnnl::memory::desc data_md = GenDNNLMemDescByShape(shape, dt::f32); auto relu_desc = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_inference, dnnl::algorithm::eltwise_relu, data_md, 0); @@ -349,9 +349,8 @@ class DNNLJSONRuntime : public JSONRuntimeBase { net_.push_back(relu); auto data_memory = BindDNNLMemory(data_entry, data_md); - auto out_md = dnnl::memory::desc(shape, dt::f32, tag::abcd); JSONGraphNodeEntry out_entry(nid, 0); - auto out_memory = BindDNNLMemory(out_entry, out_md); + auto out_memory = BindDNNLMemory(out_entry, data_md); net_args_.push_back({{DNNL_ARG_SRC, data_memory}, {DNNL_ARG_DST, out_memory}}); } diff --git a/tests/python/relay/test_json_runtime.py b/tests/python/relay/test_json_runtime.py index 8107dc231adb..721271ac70f1 100644 --- a/tests/python/relay/test_json_runtime.py +++ b/tests/python/relay/test_json_runtime.py @@ -225,7 +225,7 @@ def test_relu(): dtype = "float32" shape = (1, 32, 14, 14) - def gen_relu(): + def gen_relu(shape): data0 = relay.var("data0", shape=shape, dtype=dtype) out = relay.nn.relu(data0) @@ -250,18 +250,22 @@ def gen_relu(): return mod, ref_mod - mod, ref_mod = gen_relu() + def check(shape): + mod, ref_mod = gen_relu(shape) + + data0 = np.random.uniform(-1, 1, shape).astype(dtype) + check_result( + mod, + ref_mod, + { + "data0": data0, + }, + shape, + tol=1e-5, + ) - data0 = np.random.uniform(-1, 1, shape).astype(dtype) - check_result( - mod, - ref_mod, - { - "data0": data0, - }, - (1, 32, 14, 14), - tol=1e-5, - ) + check(shape=(1, 32, 14, 14)) + check(shape=(1, 32)) def test_dense():