diff --git a/python/mxnet/ndarray/numpy/linalg.py b/python/mxnet/ndarray/numpy/linalg.py index 74ba41f22979..05700e5ee09d 100644 --- a/python/mxnet/ndarray/numpy/linalg.py +++ b/python/mxnet/ndarray/numpy/linalg.py @@ -21,7 +21,7 @@ from . import _op as _mx_nd_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -352,3 +352,58 @@ def slogdet(a): (1., -1151.2925464970228) """ return _npi.slogdet(a) + + +def tensorinv(a, ind=2): + r""" + Compute the 'inverse' of an N-dimensional array. + + The result is an inverse for `a` relative to the tensordot operation + ``tensordot(a, b, ind)``, i. e., up to floating-point accuracy, + ``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the + tensordot operation. + + Parameters + ---------- + a : array_like + Tensor to 'invert'. Its shape must be 'square', i. e., + ``prod(a.shape[:ind]) == prod(a.shape[ind:])``. + ind : int, optional + Number of first indices that are involved in the inverse sum. + Must be a positive integer, default is 2. + + Returns + ------- + b : ndarray + `a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``. + + Raises + ------ + MXNetError + If `a` is singular or not 'square' (in the above sense). + + See Also + -------- + tensordot, tensorsolve + + Examples + -------- + >>> a = np.eye(4*6) + >>> a.shape = (4, 6, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=2) + >>> ainv.shape + (8, 3, 4, 6) + >>> b = np.random.randn(4, 6) + >>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b)) + True + + >>> a = np.eye(4*6) + >>> a.shape = (24, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=1) + >>> ainv.shape + (8, 3, 24) + >>> b = np.random.randn(24) + >>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) + True + """ + return _npi.tensorinv(a, ind) diff --git a/python/mxnet/numpy/linalg.py b/python/mxnet/numpy/linalg.py index fbe3631eb6e6..9d9ba53fbdbe 100644 --- a/python/mxnet/numpy/linalg.py +++ b/python/mxnet/numpy/linalg.py @@ -20,7 +20,7 @@ from __future__ import absolute_import from ..ndarray import numpy as _mx_nd_np -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -370,3 +370,58 @@ def slogdet(a): (1., -1151.2925464970228) """ return _mx_nd_np.linalg.slogdet(a) + + +def tensorinv(a, ind=2): + r""" + Compute the 'inverse' of an N-dimensional array. + + The result is an inverse for `a` relative to the tensordot operation + ``tensordot(a, b, ind)``, i. e., up to floating-point accuracy, + ``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the + tensordot operation. + + Parameters + ---------- + a : array_like + Tensor to 'invert'. Its shape must be 'square', i. e., + ``prod(a.shape[:ind]) == prod(a.shape[ind:])``. + ind : int, optional + Number of first indices that are involved in the inverse sum. + Must be a positive integer, default is 2. + + Returns + ------- + b : ndarray + `a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``. + + Raises + ------ + MXNetError + If `a` is singular or not 'square' (in the above sense). 
+ + See Also + -------- + tensordot, tensorsolve + + Examples + -------- + >>> a = np.eye(4*6) + >>> a.shape = (4, 6, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=2) + >>> ainv.shape + (8, 3, 4, 6) + >>> b = np.random.randn(4, 6) + >>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b)) + True + + >>> a = np.eye(4*6) + >>> a.shape = (24, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=1) + >>> ainv.shape + (8, 3, 24) + >>> b = np.random.randn(24) + >>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) + True + """ + return _mx_nd_np.linalg.tensorinv(a, ind) diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py index 1f68ca3c522a..ba97ab4ae895 100644 --- a/python/mxnet/numpy_dispatch_protocol.py +++ b/python/mxnet/numpy_dispatch_protocol.py @@ -130,6 +130,7 @@ def _run_with_array_ufunc_proto(*args, **kwargs): 'linalg.norm', 'linalg.cholesky', 'linalg.inv', + 'linalg.tensorinv', 'shape', 'trace', 'tril', diff --git a/python/mxnet/symbol/numpy/linalg.py b/python/mxnet/symbol/numpy/linalg.py index cf33777b2637..76dec276a869 100644 --- a/python/mxnet/symbol/numpy/linalg.py +++ b/python/mxnet/symbol/numpy/linalg.py @@ -22,7 +22,7 @@ from . import _op as _mx_sym_np from . import _internal as _npi -__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet'] +__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'tensorinv'] def norm(x, ord=None, axis=None, keepdims=False): @@ -339,3 +339,58 @@ def slogdet(a): (1., -1151.2925464970228) """ return _npi.slogdet(a) + + +def tensorinv(a, ind=2): + r""" + Compute the 'inverse' of an N-dimensional array. + + The result is an inverse for `a` relative to the tensordot operation + ``tensordot(a, b, ind)``, i. e., up to floating-point accuracy, + ``tensordot(tensorinv(a), a, ind)`` is the "identity" tensor for the + tensordot operation. + + Parameters + ---------- + a : array_like + Tensor to 'invert'. Its shape must be 'square', i. e., + ``prod(a.shape[:ind]) == prod(a.shape[ind:])``. + ind : int, optional + Number of first indices that are involved in the inverse sum. + Must be a positive integer, default is 2. + + Returns + ------- + b : ndarray + `a`'s tensordot inverse, shape ``a.shape[ind:] + a.shape[:ind]``. + + Raises + ------ + MXNetError + If `a` is singular or not 'square' (in the above sense). + + See Also + -------- + tensordot, tensorsolve + + Examples + -------- + >>> a = np.eye(4*6) + >>> a.shape = (4, 6, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=2) + >>> ainv.shape + (8, 3, 4, 6) + >>> b = np.random.randn(4, 6) + >>> np.allclose(np.tensordot(ainv, b), np.linalg.tensorsolve(a, b)) + True + + >>> a = np.eye(4*6) + >>> a.shape = (24, 8, 3) + >>> ainv = np.linalg.tensorinv(a, ind=1) + >>> ainv.shape + (8, 3, 24) + >>> b = np.random.randn(24) + >>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b)) + True + """ + return _npi.tensorinv(a, ind) diff --git a/src/operator/numpy/linalg/np_tensorinv-inl.h b/src/operator/numpy/linalg/np_tensorinv-inl.h new file mode 100644 index 000000000000..0cabe824c210 --- /dev/null +++ b/src/operator/numpy/linalg/np_tensorinv-inl.h @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. 
The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_tensorinv-inl.h
+ * \brief Placeholder for tensor inverse
+ */
+#ifndef MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
+#define MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
+
+#include <mxnet/operator_util.h>
+#include <vector>
+#include "../../operator_common.h"
+#include "../../mshadow_op.h"
+#include "../../tensor/la_op.h"
+#include "../../tensor/la_op-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using namespace mshadow;
+
+struct TensorinvParam : public dmlc::Parameter<TensorinvParam> {
+  int ind;
+  DMLC_DECLARE_PARAMETER(TensorinvParam) {
+    DMLC_DECLARE_FIELD(ind)
+      .set_default(2)
+      .describe("Number of first indices that are involved in the inverse sum.");
+  }
+};
+
+template<typename xpu>
+void TensorinvOpForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<TBlob>& inputs,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 1U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  const mxnet::TBlob& a_tblob = inputs[0];
+  const mxnet::TBlob& inv_a_tblob = outputs[0];
+  const mxnet::TShape& a_shape = a_tblob.shape_;
+  CHECK_EQ(inv_a_tblob.type_flag_, a_tblob.type_flag_)
+    << "Binary function only supports input/output with the same type";
+  MSHADOW_SGL_DBL_TYPE_SWITCH(
+    outputs[0].type_flag_,
+    OType, {
+      const int ind = nnvm::get<TensorinvParam>(attrs.parsed).ind;
+      dim_t prod_front = 1, prod_back = 1;
+      if (ind < a_shape.ndim()) {
+        for (int i = 0; i < ind; ++i) {
+          prod_front *= a_shape[i];
+        }
+        for (int i = ind; i < a_shape.ndim(); ++i) {
+          prod_back *= a_shape[i];
+        }
+      } else {
+        for (int i = 0; i < a_shape.ndim(); ++i) {
+          prod_front *= a_shape[i];
+        }
+      }
+      Tensor<xpu, 3, OType> A =
+        a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
+      Tensor<xpu, 3, OType> inv_A =
+        inv_a_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, prod_back, prod_front), s);
+      inverse::op(A, inv_A, ctx, attrs);
+    });
+}
+
+template<typename xpu>
+void TensorinvOpBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<TBlob>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<TBlob>& outputs) {
+  CHECK_EQ(inputs.size(), 2U);
+  CHECK_EQ(outputs.size(), 1U);
+  CHECK_EQ(req.size(), 1U);
+
+  mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
+  // const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
+  const TBlob& out_grad = inputs[0];
+  const TBlob& inv_a = inputs[1];
+  const TBlob& grad_a = outputs[0];
+  const TShape& inv_a_shape = inv_a.shape_;
+  MSHADOW_SGL_DBL_TYPE_SWITCH(
+    outputs[0].type_flag_,
+    OType, {
+      const int axes = nnvm::get<TensorinvParam>(attrs.parsed).ind;
+      CHECK_LE(inv_a_shape.ndim(), 6U)
+        << "tensorinv backward only supports tensors of dimension <= 6";
+      if (axes < inv_a_shape.ndim()) {
+        const int axes1 = inv_a_shape.ndim() - axes, axes2 = axes;
+        TShape inv_a_transpose_shape(inv_a_shape.ndim(), -1);
+        for (int i = 0; i < axes; ++i) {
+          inv_a_transpose_shape[i] = inv_a_shape[i + inv_a_shape.ndim() - axes];
+        }
+        for (int i = axes; i < inv_a_shape.ndim(); ++i) {
+          inv_a_transpose_shape[i] = inv_a_shape[i - axes];
+        }
+        TShape temp_shape(2 * axes, -1);
+        for (int i = 0; i < axes; ++i) {
+          temp_shape[i] = inv_a_transpose_shape[i];
+          temp_shape[i + axes] = inv_a_transpose_shape[i];
+        }
+        Tensor<xpu, 1, char> workspace =
+          ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(temp_shape.Size() * sizeof(OType)),
+                                                         ctx.get_stream<xpu>());
+        TBlob temp_tblob =
+          TBlob(reinterpret_cast<OType*>(workspace.dptr_), temp_shape, xpu::kDevMask);
+        dim_t a1 = 1, a2 = 1;
+        for (int i = 0; i < axes2; ++i) {
+          a1 *= inv_a_transpose_shape[i];
+        }
+        for (int i = 0; i < axes1; ++i) {
+          a2 *= inv_a_shape[i];
+        }
+        Tensor<xpu, 3, OType> inv_a_tensor =
+          inv_a.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
+        Tensor<xpu, 3, OType> out_grad_tensor =
+          out_grad.get_with_shape<xpu, 3, OType>(Shape3(1, a2, a1), s);
+        Tensor<xpu, 3, OType> temp_tensor =
+          temp_tblob.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a1), s);
+        Tensor<xpu, 3, OType> grad_a_tensor =
+          grad_a.get_with_shape<xpu, 3, OType>(Shape3(1, a1, a2), s);
+        gemm2::op(inv_a_tensor, out_grad_tensor, temp_tensor, OType(1), true, false, s);
+        gemm2::op(temp_tensor, inv_a_tensor, grad_a_tensor, OType(-1), false, true, s);
+      } else {  // axes >= inv_a_shape.ndim()
+        dim_t a = 1;
+        for (int i = 0; i < inv_a_shape.ndim(); ++i) {
+          a *= inv_a_shape[i];
+        }
+        // check again
+        CHECK_EQ(a, 1U)
+          << "a shape must be square, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).";
+        Tensor<xpu, 1, OType> inv_a_tensor =
+          inv_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
+        Tensor<xpu, 1, OType> out_grad_tensor =
+          out_grad.get_with_shape<xpu, 1, OType>(Shape1(1), s);
+        Tensor<xpu, 1, OType> grad_a_tensor =
+          grad_a.get_with_shape<xpu, 1, OType>(Shape1(1), s);
+        ASSIGN_DISPATCH(grad_a_tensor, kWriteTo,
+                        OType(-1) * inv_a_tensor * out_grad_tensor * inv_a_tensor);
+      }
+    });
+}
+
+}  // namespace op
+}  // namespace mxnet
+
+#endif  // MXNET_OPERATOR_NUMPY_LINALG_NP_TENSORINV_INL_H_
diff --git a/src/operator/numpy/linalg/np_tensorinv.cc b/src/operator/numpy/linalg/np_tensorinv.cc
new file mode 100644
index 000000000000..2fee11c846b7
--- /dev/null
+++ b/src/operator/numpy/linalg/np_tensorinv.cc
@@ -0,0 +1,134 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2019 by Contributors
+ * \file np_tensorinv.cc
+ * \brief CPU implementation of the Tensor Inverse Operator
+ */
+#include "./np_tensorinv-inl.h"
+
+namespace mxnet {
+namespace op {
+
+inline bool TensorinvOpShape(const nnvm::NodeAttrs &attrs,
+                             std::vector<mxnet::TShape> *in_attrs,
+                             std::vector<mxnet::TShape> *out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+
+  const mxnet::TShape& a_shape = (*in_attrs)[0];
+  const int a_ndim = a_shape.ndim();
+  mxnet::TShape inv_a_shape(a_shape);
+  if (!ndim_is_known(a_shape)) {
+    return false;
+  }
+  // ind > 0, default is 2
+  int ind = 2;
+  ind = nnvm::get<TensorinvParam>(attrs.parsed).ind;
+  CHECK_GT(ind, 0) << "Invalid ind argument.";
+
+  if (a_ndim > 0 && ind < a_ndim) {
+    for (int i = 0; i < ind; ++i) {
+      inv_a_shape[a_ndim - ind + i] = a_shape[i];
+    }
+    for (int i = ind; i < a_ndim; ++i) {
+      inv_a_shape[i - ind] = a_shape[i];
+    }
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, inv_a_shape);
+  } else {  // ind >= a_ndim
+    SHAPE_ASSIGN_CHECK(*out_attrs, 0, inv_a_shape);
+  }
+  CHECK_NE(inv_a_shape.ndim(), 0)
+    << "cannot reshape array";
+
+  dim_t prod_front = 1, prod_back = 1;
+  if (ind < a_ndim) {
+    for (int i = 0; i < ind; ++i) {
+      prod_front *= a_shape[i];
+    }
+    for (int i = ind; i < a_ndim; ++i) {
+      prod_back *= a_shape[i];
+    }
+    CHECK_GT(prod_back, 0)
+      << "cannot reshape array of size 0 into shape";
+  } else {
+    for (int i = 0; i < a_ndim; ++i) {
+      prod_front *= a_shape[i];
+    }
+  }
+  // prod_back >= 1 and prod_front == prod_back
+  CHECK_EQ(prod_front, prod_back)
+    << "a shape must be square, i.e., prod(a.shape[:ind]) == prod(a.shape[ind:]).";
+  return !mxnet::op::shape_is_none(out_attrs->at(0));
+}
+
+inline bool TensorinvOpType(const nnvm::NodeAttrs& attrs,
+                            std::vector<int>* in_attrs,
+                            std::vector<int>* out_attrs) {
+  CHECK_EQ(in_attrs->size(), 1U);
+  CHECK_EQ(out_attrs->size(), 1U);
+  int a_type = in_attrs->at(0);
+  // float16 is not supported
+  CHECK_NE(a_type, mshadow::kFloat16)
+    << "array type float16 is unsupported in linalg";
+  if (mshadow::kFloat32 == a_type) {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
+  } else {
+    TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat64);
+  }
+  return out_attrs->at(0) != -1;
+}
+
+DMLC_REGISTER_PARAMETER(TensorinvParam);
+
+NNVM_REGISTER_OP(_npi_tensorinv)
+.describe(R"code()code" ADD_FILELINE)
+.set_attr_parser(mxnet::op::ParamParser<TensorinvParam>)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+  [](const NodeAttrs& attrs) {
+    return std::vector<std::string>{"a"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", TensorinvOpShape)
+.set_attr<nnvm::FInferType>("FInferType", TensorinvOpType)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
+})
+.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
+.set_attr<FCompute>("FCompute", TensorinvOpForward<cpu>)
+.set_attr<nnvm::FGradient>("FGradient", mxnet::op::ElemwiseGradUseOut{"_backward_npi_tensorinv"})
+.add_argument("a", "NDArray-or-Symbol", "First input")
+.add_arguments(TensorinvParam::__FIELDS__());
+
+NNVM_REGISTER_OP(_backward_npi_tensorinv)
+.set_attr_parser(mxnet::op::ParamParser<TensorinvParam>)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<FResourceRequest>("FResourceRequest",
+  [](const NodeAttrs& attrs) {
+    return std::vector<ResourceRequest>(1, ResourceRequest::kTempSpace);
+})
+.set_attr<nnvm::TIsBackward>("TIsBackward", true)
+.set_attr<FCompute>("FCompute", TensorinvOpBackward<cpu>);
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/src/operator/numpy/linalg/np_tensorinv.cu b/src/operator/numpy/linalg/np_tensorinv.cu
new file mode 100644
index 000000000000..8cad95f40b3a
--- /dev/null
+++ b/src/operator/numpy/linalg/np_tensorinv.cu
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file np_tensorinv.cu
+ * \brief GPU implementation of the Tensor Inverse Operator
+ */
+
+#include <mshadow/tensor.h>
+#include <mxnet/operator_util.h>
+#include "./np_tensorinv-inl.h"
+
+namespace mxnet {
+namespace op {
+
+#if MXNET_USE_CUSOLVER == 1
+
+NNVM_REGISTER_OP(_npi_tensorinv)
+.set_attr<FCompute>("FCompute", TensorinvOpForward<gpu>);
+
+NNVM_REGISTER_OP(_backward_npi_tensorinv)
+.set_attr<FCompute>("FCompute", TensorinvOpBackward<gpu>);
+
+#endif
+
+}  // namespace op
+}  // namespace mxnet
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index 930ad5260430..c474ae7f360c 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -322,6 +322,31 @@ def _add_workload_linalg_det():
     OpArgMngr.add_workload('linalg.det', np.array(_np.ones((0, 1, 1)), dtype=np.float64))
 
 
+def _add_workload_linalg_tensorinv():
+    shapes = [
+        (1, 20, 4, 5),
+        (2, 2, 10, 4, 5),
+        (2, 12, 5, 3, 4, 5),
+        (3, 2, 3, 4, 24)
+    ]
+    dtypes = (np.float32, np.float64)
+    for dtype, shape in itertools.product(dtypes, shapes):
+        ind = shape[0]
+        prod_front = 1
+        prod_back = 1
+        for k in shape[1:ind + 1]:
+            prod_front *= k
+        for k in shape[1 + ind:]:
+            prod_back *= k
+        a_shape = (prod_back, prod_front)
+        a = _np.random.randn(*a_shape)
+        if prod_back == prod_front:
+            if _np.allclose(_np.dot(a, _np.linalg.inv(a)), _np.eye(prod_front)):
+                a_shape = shape[1:]
+                a = a.reshape(a_shape)
+        OpArgMngr.add_workload('linalg.tensorinv', np.array(a, dtype=dtype), ind)
+
+
 def _add_workload_linalg_slogdet():
     OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((2, 2)), dtype=np.float32))
     OpArgMngr.add_workload('linalg.slogdet', np.array(_np.ones((0, 1, 1)), dtype=np.float64))
@@ -1353,6 +1378,7 @@ def _prepare_workloads():
     _add_workload_linalg_cholesky()
     _add_workload_linalg_inv()
     _add_workload_linalg_det()
+    _add_workload_linalg_tensorinv()
     _add_workload_linalg_slogdet()
     _add_workload_trace()
     _add_workload_tril()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 9b7f7036bcda..cb2cae5ab886 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3271,7 +3271,7 @@ def check_cholesky(L, data_np):
     ]
     dtypes = ['float32', 'float64']
     for hybridize, dtype, shape in itertools.product([True, False], dtypes, shapes):
-        atol = rtol = 1e-2
+        atol = rtol = 1e-1
         test_cholesky = TestCholesky()
 
         if hybridize:
@@ -3293,7 +3293,7 @@ def check_cholesky(L, data_np):
         if 0 in shape:
             data_np = np.ones(shape)
         else:
-            data_np_l = _np.random.uniform(-10., 10., shape)
+            data_np_l = _np.random.uniform(-1., 1., shape)
             if dtype == 'float32':
                 data_np_l_flat =
data_np_l.reshape((-1, shape[-2], shape[-1])) else: @@ -3412,6 +3412,97 @@ def check_inv(A_inv, data_np): check_inv(A_inv, data_np) +@with_seed() +@use_np +def test_np_linalg_solve(): + class TestSolve(HybridBlock): + def __init__(self): + super(TestSolve, self).__init__() + + def hybrid_forward(self, F, a, b): + return F.np.linalg.solve(a, b) + + def check_solve(x, a_np, b_np): + try: + x_expected = _np.linalg.solve(a_np, b_np) + except Exception as e: + print("a:", a_np) + print("a shape:", a_np.shape) + print("b", b_np) + print("b shape:", b_np.shape) + print(e) + else: + assert x.shape == x_expected.shape + assert_almost_equal(x.asnumpy(), x_expected, rtol=rtol, atol=atol) + + def get_grad_b(A, X): + dX = _np.ones_like(X) + A_inv = _np.linalg.inv(A) + A_inv_trans = _np.swapaxes(A_inv, -1, -2) + return _np.matmul(A_inv_trans, dX) + + shapes = [ + (0, 0), + (1, 1), + (3, 3), + (20, 20), + (3, 20, 20), + (1, 0, 0), + (0, 1, 1), + (0, 5, 3, 3), + (5, 0, 0, 0), + (2, 3, 10, 10) + ] + nrhs = (-1, 0, 1, 2, 5) + dtypes = ['float32', 'float64'] + for hybridize, shape, dtype, nrh in itertools.product([False, True], shapes, dtypes, nrhs): + rtol = 1e-3 + atol = 1e-5 + test_solve = TestSolve() + if hybridize: + test_solve.hybridize() + + if 0 in shape: + a = _np.ones(shape) + b = _np.ones(shape) + else: + shape_a = shape + a = _np.random.rand(*shape_a) + shape_b = list(shape_a) + if nrh == -1: + shape_b[-1] = 1 + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + shape_b.pop() + b = b.reshape(shape_b) + else : + shape_b[-1] = nrh + x = _np.random.rand(*shape_b) + b = _np.matmul(a, x) + a = np.array(a, dtype=dtype) + b = np.array(b, dtype=dtype) + a.attach_grad() + b.attach_grad() + with mx.autograd.record(): + mx_out = test_solve(a, b) + # check solve validity + assert mx_out.shape == b.shape + check_solve(mx_out, a, b) + + # check backward. backward does not support empty input + if 0 not in mx_out.shape: + if nrh != -1: + mx.autograd.backward(mx_out) + b_backward_expected = get_grad_b(a.asnumpy(), mx_out.asnumpy()) + a_backward_expected = -_np.matmul(b_backward_expected, _np.swapaxes(mx_out, -1, -2).asnumpy()) + assert_almost_equal(a.grad.asnumpy(), a_backward_expected, rtol=rtol, atol=atol) + assert_almost_equal(b.grad.asnumpy(), b_backward_expected, rtol=rtol, atol=atol) + + # check imperative once again + mx_out = np.linalg.solve(a, b) + check_solve(mx_out, a, b) + + @with_seed() @use_np def test_np_linalg_det():
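
For reference, the following NumPy-only sketch is not part of the patch; the variable names are illustrative. It reproduces the reduction that TensorinvOpForward performs (flatten to a square matrix, invert, reorder axes) and the gradient identity that TensorinvOpBackward implements with its two gemm2 calls, grad_a = -(B^T . out_grad . B^T) on the flattened matrices, where B is the flattened inverse.

import numpy as np

np.random.seed(0)
ind = 2
n = 6
A = np.random.randn(n, n) + n * np.eye(n)   # well-conditioned square matrix
a = A.reshape(2, 3, 3, 2)                   # prod(a.shape[:ind]) == prod(a.shape[ind:]) == n

# Forward: flatten to an n x n matrix, invert, and lay the result out
# with shape a.shape[ind:] + a.shape[:ind] (NumPy's tensorinv convention).
b = np.linalg.inv(a.reshape(n, n)).reshape(a.shape[ind:] + a.shape[:ind])
assert np.allclose(np.tensordot(b, a, ind).reshape(n, n), np.eye(n))

# Backward: on the flattened matrices, grad_a = -(B^T @ out_grad @ B^T),
# the standard gradient of the matrix inverse.
out_grad = np.random.randn(*b.shape)
B = b.reshape(n, n)
G = out_grad.reshape(n, n)
grad_a = (-(B.T @ G @ B.T)).reshape(a.shape)

# Finite-difference spot check of one gradient entry.
eps = 1e-6
a_pert = a.copy()
a_pert[0, 0, 0, 0] += eps
b_pert = np.linalg.inv(a_pert.reshape(n, n)).reshape(b.shape)
numerical = np.sum(out_grad * (b_pert - b)) / eps
assert np.allclose(numerical, grad_a[0, 0, 0, 0], rtol=1e-3, atol=1e-5)

The finite-difference assertion only spot-checks a single entry of the gradient; the unit tests in the diff above are what actually exercise the operator.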