diff --git a/python/mxnet/autograd.py b/python/mxnet/autograd.py index 6f1cc4367821..f10b5ce3ed6d 100644 --- a/python/mxnet/autograd.py +++ b/python/mxnet/autograd.py @@ -360,6 +360,8 @@ def get_symbol(x): Symbol The retrieved Symbol. """ + assert isinstance(x, NDArray), \ + "get_symbol: Expecting %s, got %s"%(NDArray, type(x)) hdl = SymbolHandle() check_call(_LIB.MXAutogradGetSymbol(x.handle, ctypes.byref(hdl))) return Symbol(hdl)