From 002d4f1d91aac28d28d3147ddd6479c4e7dce016 Mon Sep 17 00:00:00 2001 From: dw_sjtu <46704444+sjtuWangDing@users.noreply.github.com> Date: Tue, 7 Apr 2020 14:39:53 +0800 Subject: [PATCH] * impl - FFi for linalg op (#17795) * fix - cpplint * impl - benchmark ffi for ops * rm - FFI for ops with param * fix - makefile * fix - not include unordered_map and use num_inputs * ci - compiler error * fix - change cholesky interface Co-authored-by: Ubuntu --- Makefile | 4 +- benchmark/python/ffi/benchmark_ffi.py | 8 ++ python/mxnet/ndarray/numpy/linalg.py | 18 ++--- python/mxnet/symbol/numpy/linalg.py | 2 +- src/api/operator/numpy/linalg/np_eigvals.cc | 61 ++++++++++++++++ src/api/operator/numpy/linalg/np_inv.cc | 43 +++++++++++ src/api/operator/numpy/linalg/np_pinv.cc | 73 +++++++++++++++++++ src/api/operator/numpy/linalg/np_potrf.cc | 48 ++++++++++++ src/api/operator/numpy/linalg/np_solve.cc | 43 +++++++++++ src/api/operator/numpy/linalg/np_tensorinv.cc | 48 ++++++++++++ .../operator/numpy/linalg/np_tensorsolve.cc | 56 ++++++++++++++ src/api/operator/ufunc_helper.cc | 1 + src/api/operator/utils.cc | 5 ++ src/api/operator/utils.h | 5 +- src/operator/numpy/linalg/np_eigvals-inl.h | 6 ++ src/operator/numpy/linalg/np_pinv-inl.h | 14 ++++ src/operator/numpy/linalg/np_potrf.cc | 3 +- src/operator/numpy/linalg/np_tensorinv-inl.h | 6 ++ .../numpy/linalg/np_tensorsolve-inl.h | 6 ++ src/operator/tensor/la_op.h | 6 ++ tests/python/unittest/test_numpy_op.py | 1 - 21 files changed, 440 insertions(+), 17 deletions(-) create mode 100644 src/api/operator/numpy/linalg/np_eigvals.cc create mode 100644 src/api/operator/numpy/linalg/np_inv.cc create mode 100644 src/api/operator/numpy/linalg/np_pinv.cc create mode 100644 src/api/operator/numpy/linalg/np_potrf.cc create mode 100644 src/api/operator/numpy/linalg/np_solve.cc create mode 100644 src/api/operator/numpy/linalg/np_tensorinv.cc create mode 100644 src/api/operator/numpy/linalg/np_tensorsolve.cc diff --git a/Makefile b/Makefile index 49ba8fe00e7f..e5d6bb288134 100644 --- a/Makefile +++ b/Makefile @@ -464,9 +464,9 @@ endif all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs -SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) +SRC = $(wildcard src/*/*/*/*/*.cc src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc) OBJ = $(patsubst %.cc, build/%.o, $(SRC)) -CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu) +CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu) CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC)) ifeq ($(USE_TVM_OP), 1) diff --git a/benchmark/python/ffi/benchmark_ffi.py b/benchmark/python/ffi/benchmark_ffi.py index ee3fccfaa185..21209853dafb 100644 --- a/benchmark/python/ffi/benchmark_ffi.py +++ b/benchmark/python/ffi/benchmark_ffi.py @@ -62,6 +62,14 @@ def prepare_workloads(): OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2']) OpArgMngr.add_workload("add", pool['2x2'], pool['2x2']) OpArgMngr.add_workload("linalg.svd", pool['3x3']) + OpArgMngr.add_workload("linalg.cholesky", pool['1x1']) + OpArgMngr.add_workload("linalg.eigvals", pool['1x1']) + OpArgMngr.add_workload("linalg.eigvalsh", pool['1x1'], UPLO='L') + OpArgMngr.add_workload("linalg.inv", pool['1x1']) + OpArgMngr.add_workload("linalg.pinv", pool['2x3x3'], pool['1'], hermitian=False) + OpArgMngr.add_workload("linalg.solve", pool['1x1'], pool['1']) + OpArgMngr.add_workload("linalg.tensorinv", pool['1x1'], ind=2) + OpArgMngr.add_workload("linalg.tensorsolve", pool['1x1x1'], pool['1x1x1'], (2, 0, 1)) OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1) OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1) OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1) diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index fdcbdac2247a..9c2344770bfa 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -92,9 +92,7 @@ def pinv(a, rcond=1e-15, hermitian=False): """ if hermitian is True: raise NotImplementedError("hermitian is not supported yet...") - if _mx_nd_np._np.isscalar(rcond): - return _npi.pinv_scalar_rcond(a, rcond, hermitian) - return _npi.pinv(a, rcond, hermitian) + return _api_internal.pinv(a, rcond, hermitian) # pylint: disable=too-many-return-statements @@ -389,7 +387,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _npi.cholesky(a) + return _api_internal.cholesky(a, True) def inv(a): @@ -431,7 +429,7 @@ def inv(a): [[-1.2500001 , 0.75000006], [ 0.75000006, -0.25000003]]]) """ - return _npi.inv(a) + return _api_internal.inv(a) def det(a): @@ -595,7 +593,7 @@ def solve(a, b): >>> np.allclose(np.dot(a, x), b) True """ - return _npi.solve(a, b) + return _api_internal.solve(a, b) def tensorinv(a, ind=2): @@ -650,7 +648,7 @@ def tensorinv(a, ind=2): >>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) True """ - return _npi.tensorinv(a, ind) + return _api_internal.tensorinv(a, ind) def tensorsolve(a, b, axes=None): @@ -698,7 +696,7 @@ def tensorsolve(a, b, axes=None): >>> np.allclose(np.tensordot(a, x, axes=3), b) True """ - return _npi.tensorsolve(a, b, axes) + return _api_internal.tensorsolve(a, b, axes) def eigvals(a): @@ -766,7 +764,7 @@ def eigvals(a): >>> LA.eigvals(A) array([ 1., -1.]) # random """ - return _npi.eigvals(a) + return _api_internal.eigvals(a) def eigvalsh(a, UPLO='L'): @@ -825,7 +823,7 @@ def eigvalsh(a, UPLO='L'): >>> LA.eigvalsh(a, UPLO='L') array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order """ - return _npi.eigvalsh(a, UPLO) + return _api_internal.eigvalsh(a, UPLO) def eig(a): diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index d326b37f0635..c05144abe4f5 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -378,7 +378,7 @@ def cholesky(a): array([[16., 4.], [ 4., 10.]]) """ - return _npi.cholesky(a) + return _npi.cholesky(a, True) def inv(a): diff --git a/src/api/operator/numpy/linalg/np_eigvals.cc b/src/api/operator/numpy/linalg/np_eigvals.cc new file mode 100644 index 000000000000..acde49f87b74 --- /dev/null +++ b/src/api/operator/numpy/linalg/np_eigvals.cc @@ -0,0 +1,61 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_eigvals.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_eigvals.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_eigvals-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.eigvals") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_eigvals"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +MXNET_REGISTER_API("_npi.eigvalsh") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_eigvalsh"); + nnvm::NodeAttrs attrs; + op::EigvalshParam param; + param.UPLO = *((args[1].operator std::string()).c_str()); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_inv.cc b/src/api/operator/numpy/linalg/np_inv.cc new file mode 100644 index 000000000000..238f666f29bd --- /dev/null +++ b/src/api/operator/numpy/linalg/np_inv.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_inv.cc + * \brief Implementation of the API of functions in src/operator/tensor/la_op.cc + */ +#include +#include +#include "../../utils.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.inv") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_inv"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_pinv.cc b/src/api/operator/numpy/linalg/np_pinv.cc new file mode 100644 index 000000000000..b14407c7b69f --- /dev/null +++ b/src/api/operator/numpy/linalg/np_pinv.cc @@ -0,0 +1,73 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_pinv.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_pinv.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_pinv-inl.h" + +namespace mxnet { + +inline static void _npi_pinv(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_pinv"); + op::PinvParam param; + nnvm::NodeAttrs attrs; + param.hermitian = args[2].operator bool(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +} + +inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_pinv_scalar_rcond"); + op::PinvScalarRcondParam param; + nnvm::NodeAttrs attrs; + param.rcond = args[1].operator double(); + param.hermitian = args[2].operator bool(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +} + +MXNET_REGISTER_API("_npi.pinv") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) { + _npi_pinv_scalar_rcond(args, ret); + } else { + _npi_pinv(args, ret); + } +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_potrf.cc b/src/api/operator/numpy/linalg/np_potrf.cc new file mode 100644 index 000000000000..811ce74f8692 --- /dev/null +++ b/src/api/operator/numpy/linalg/np_potrf.cc @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_potrf.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_potrf.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_potrf-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.cholesky") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_cholesky"); + nnvm::NodeAttrs attrs; + op::LaCholeskyParam param; + param.lower = args[1].operator bool(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_solve.cc b/src/api/operator/numpy/linalg/np_solve.cc new file mode 100644 index 000000000000..d0d263881701 --- /dev/null +++ b/src/api/operator/numpy/linalg/np_solve.cc @@ -0,0 +1,43 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_solve.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_solve.cc + */ +#include +#include +#include "../../utils.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.solve") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_solve"); + nnvm::NodeAttrs attrs; + attrs.op = op; + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_tensorinv.cc b/src/api/operator/numpy/linalg/np_tensorinv.cc new file mode 100644 index 000000000000..c3062eee637f --- /dev/null +++ b/src/api/operator/numpy/linalg/np_tensorinv.cc @@ -0,0 +1,48 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_tensorinv.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_tensorinv.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_tensorinv-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.tensorinv") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensorinv"); + nnvm::NodeAttrs attrs; + op::TensorinvParam param; + param.ind = args[1].operator int(); + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 1; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/numpy/linalg/np_tensorsolve.cc b/src/api/operator/numpy/linalg/np_tensorsolve.cc new file mode 100644 index 000000000000..5a50c22ea94e --- /dev/null +++ b/src/api/operator/numpy/linalg/np_tensorsolve.cc @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file np_tensorsolve.cc + * \brief Implementation of the API of functions in src/operator/numpy/linalg/np_tensorsolve.cc + */ +#include +#include +#include "../../utils.h" +#include "../../../../operator/numpy/linalg/np_tensorsolve-inl.h" + +namespace mxnet { + +MXNET_REGISTER_API("_npi.tensorsolve") +.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) { + using namespace runtime; + const nnvm::Op* op = Op::Get("_npi_tensorsolve"); + nnvm::NodeAttrs attrs; + op::TensorsolveParam param; + if (args[2].type_code() == kNull) { + param.a_axes = Tuple(); + } else { + if (args[2].type_code() == kDLInt) { + param.a_axes = Tuple(1, args[2].operator int64_t()); + } else { + param.a_axes = Tuple(args[2].operator ObjectRef()); + } + } + attrs.parsed = param; + attrs.op = op; + SetAttrDict(&attrs); + int num_inputs = 2; + int num_outputs = 0; + NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()}; + auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr); + *ret = reinterpret_cast(ndoutputs[0]); +}); + +} // namespace mxnet diff --git a/src/api/operator/ufunc_helper.cc b/src/api/operator/ufunc_helper.cc index 67bc68031417..787166df576a 100644 --- a/src/api/operator/ufunc_helper.cc +++ b/src/api/operator/ufunc_helper.cc @@ -23,6 +23,7 @@ */ #include "ufunc_helper.h" #include "utils.h" +#include "../../imperative/imperative_utils.h" namespace mxnet { diff --git a/src/api/operator/utils.cc b/src/api/operator/utils.cc index 79e94cffbf70..307bb290080b 100644 --- a/src/api/operator/utils.cc +++ b/src/api/operator/utils.cc @@ -22,9 +22,14 @@ * \brief Utility functions for operator invoke */ #include "utils.h" +#include "../../imperative/imperative_utils.h" namespace mxnet { +bool is_recording() { + return Imperative::Get()->is_recording(); +} + void SetInOut(std::vector* ndinputs, std::vector* ndoutputs, int num_inputs, diff --git a/src/api/operator/utils.h b/src/api/operator/utils.h index 49ee6bf2c9af..53e62ee7635b 100644 --- a/src/api/operator/utils.h +++ b/src/api/operator/utils.h @@ -28,7 +28,6 @@ #include #include #include -#include "../../imperative/imperative_utils.h" namespace mxnet { @@ -48,9 +47,11 @@ std::vector Invoke(const nnvm::Op* op, int* num_outputs, NDArray** outputs); +bool is_recording(); + template void SetAttrDict(nnvm::NodeAttrs* attrs) { - if (Imperative::Get()->is_recording()) { + if (is_recording()) { ::dmlc::get(attrs->parsed).SetAttrDict(&(attrs->dict)); } } diff --git a/src/operator/numpy/linalg/np_eigvals-inl.h b/src/operator/numpy/linalg/np_eigvals-inl.h index 81b46d237206..26b351ac8eab 100644 --- a/src/operator/numpy/linalg/np_eigvals-inl.h +++ b/src/operator/numpy/linalg/np_eigvals-inl.h @@ -27,6 +27,7 @@ #include #include +#include #include "../../operator_common.h" #include "../../mshadow_op.h" #include "../../tensor/la_op.h" @@ -312,6 +313,11 @@ struct EigvalshParam : public dmlc::Parameter { .set_default('L') .describe("Specifies whether the calculation is done with the lower or upper triangular part."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream UPLO_s; + UPLO_s << UPLO; + (*dict)["UPLO"] = UPLO_s.str(); + } }; template diff --git a/src/operator/numpy/linalg/np_pinv-inl.h b/src/operator/numpy/linalg/np_pinv-inl.h index 76bcc9a1ab64..b3b8e0c76c64 100644 --- a/src/operator/numpy/linalg/np_pinv-inl.h +++ b/src/operator/numpy/linalg/np_pinv-inl.h @@ -27,6 +27,7 @@ #include #include +#include #include #include "../../operator_common.h" #include "../../mshadow_op.h" @@ -48,6 +49,11 @@ struct PinvParam : public dmlc::Parameter { .set_default(false) .describe("If True, A is assumed to be Hermitian (symmetric if real-valued)."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream hermitian_s; + hermitian_s << hermitian; + (*dict)["hermitian"] = hermitian_s.str(); + } }; struct PinvScalarRcondParam : public dmlc::Parameter { @@ -61,6 +67,14 @@ struct PinvScalarRcondParam : public dmlc::Parameter { .set_default(false) .describe("If True, A is assumed to be Hermitian (symmetric if real-valued)."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream rcond_s; + std::ostringstream hermitian_s; + rcond_s << rcond; + hermitian_s << hermitian; + (*dict)["rcond"] = rcond_s.str(); + (*dict)["hermitian"] = hermitian_s.str(); + } }; template diff --git a/src/operator/numpy/linalg/np_potrf.cc b/src/operator/numpy/linalg/np_potrf.cc index cad2b3084c21..40e900872365 100644 --- a/src/operator/numpy/linalg/np_potrf.cc +++ b/src/operator/numpy/linalg/np_potrf.cc @@ -56,7 +56,8 @@ NNVM_REGISTER_OP(_npi_cholesky) { return std::vector>{{0, 0}}; }) .set_attr("FCompute", LaOpForward) .set_attr("FGradient", ElemwiseGradUseOut{"_backward_linalg_potrf"}) -.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be decomposed"); +.add_argument("A", "NDArray-or-Symbol", "Tensor of input matrices to be decomposed") +.add_arguments(LaCholeskyParam::__FIELDS__()); } // namespace op } // namespace mxnet diff --git a/src/operator/numpy/linalg/np_tensorinv-inl.h b/src/operator/numpy/linalg/np_tensorinv-inl.h index 4f92ccf9d125..414c3f09ec45 100644 --- a/src/operator/numpy/linalg/np_tensorinv-inl.h +++ b/src/operator/numpy/linalg/np_tensorinv-inl.h @@ -27,6 +27,7 @@ #include #include +#include #include "../../operator_common.h" #include "../../mshadow_op.h" #include "../../tensor/la_op.h" @@ -44,6 +45,11 @@ struct TensorinvParam : public dmlc::Parameter { .set_default(2) .describe("Number of first indices that are involved in the inverse sum."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream ind_s; + ind_s << ind; + (*dict)["ind"] = ind_s.str(); + } }; template diff --git a/src/operator/numpy/linalg/np_tensorsolve-inl.h b/src/operator/numpy/linalg/np_tensorsolve-inl.h index 829a119b64a2..bbde4d40434a 100644 --- a/src/operator/numpy/linalg/np_tensorsolve-inl.h +++ b/src/operator/numpy/linalg/np_tensorsolve-inl.h @@ -27,6 +27,7 @@ #include #include +#include #include "../../operator_common.h" #include "../../mshadow_op.h" #include "../../tensor/la_op.h" @@ -46,6 +47,11 @@ struct TensorsolveParam : public dmlc::Parameter { .set_default(mxnet::Tuple()) .describe("Tuple of ints, optional. Axes in a to reorder to the right, before inversion."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream a_axes_s; + a_axes_s << a_axes; + (*dict)["a_axes"] = a_axes_s.str(); + } }; // Fix negative axes. diff --git a/src/operator/tensor/la_op.h b/src/operator/tensor/la_op.h index e15390ecde5a..cf80f28cb8b2 100644 --- a/src/operator/tensor/la_op.h +++ b/src/operator/tensor/la_op.h @@ -29,6 +29,7 @@ #include #include #include +#include #include "../mshadow_op.h" #include "../mxnet_op.h" #include "../operator_common.h" @@ -91,6 +92,11 @@ struct LaCholeskyParam : public dmlc::Parameter { .describe ("True if the triangular matrix is lower triangular, false if it is upper triangular."); } + void SetAttrDict(std::unordered_map* dict) { + std::ostringstream lower_s; + lower_s << lower; + (*dict)["lower"] = lower_s.str(); + } }; // Parameters for matrix-matrix multiplication where one is a triangular matrix. diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 25b2098a0796..62f9f56088ac 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -5910,7 +5910,6 @@ def hybrid_forward(self, F, a): assert mx_out.shape == np_out.shape assert_almost_equal(mx_out.asnumpy(), np_out, rtol=1e-1, atol=1e-1) if grad_req != 'null': - print(shape, grad_req) mx_out.backward() # Test imperative once again