diff --git a/benchmark/opperf/utils/ndarray_utils.py b/benchmark/opperf/utils/ndarray_utils.py index 7d9cd5256a2e..366e84e68f62 100644 --- a/benchmark/opperf/utils/ndarray_utils.py +++ b/benchmark/opperf/utils/ndarray_utils.py @@ -116,8 +116,8 @@ def get_mx_ndarray(ctx, in_tensor, dtype, initializer, attach_grad=True): tensor = nd.array(in_tensor, ctx=ctx, dtype=dtype) elif isinstance(in_tensor, np.ndarray): tensor = nd.array(in_tensor, ctx=ctx, dtype=dtype) - elif isinstance(in_tensor, mx.ndarray): - tensor = in_tensor.as_in_context(ctx=ctx).astype(dtype=dtype) + elif isinstance(in_tensor, nd.NDArray): + tensor = in_tensor.as_in_context(ctx).astype(dtype=dtype) else: raise ValueError("Invalid input type for creating input tensor. Input can be tuple() of shape or Numpy Array or" " MXNet NDArray. Given - ", in_tensor)