Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
* impl - FFi for linalg op
Browse files Browse the repository at this point in the history
* fix - cpplint

* impl - benchmark ffi for ops

* rm - FFI for ops with param

* fix - makefile

* fix - not include unordered_map
  • Loading branch information
Ubuntu committed Mar 12, 2020
1 parent bd6e917 commit 2206d8b
Show file tree
Hide file tree
Showing 16 changed files with 330 additions and 15 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,9 @@ endif

all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs

SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
SRC = $(wildcard src/*/*/*/*/*.cc src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))

ifeq ($(USE_TVM_OP), 1)
Expand Down
5 changes: 5 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,11 @@ def prepare_workloads():
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("random.uniform", low=0, high=1, size=1)
OpArgMngr.add_workload("linalg.cholesky", pool['1x1'])
OpArgMngr.add_workload("linalg.eigvalsh", pool['1x1'], UPLO='L')
OpArgMngr.add_workload("linalg.pinv", pool['2x3x3'], pool['1'], hermitian=False)
OpArgMngr.add_workload("linalg.tensorinv", pool['1x1'], ind=2)
OpArgMngr.add_workload("linalg.tensorsolve", pool['1x1x1'], pool['1x1x1'], (2, 0, 1))


def benchmark_helper(f, *args, **kwargs):
Expand Down
15 changes: 7 additions & 8 deletions python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from . import _op as _mx_nd_np
from . import _internal as _npi
from . import _api_internal

__all__ = ['norm', 'svd', 'cholesky', 'inv', 'det', 'slogdet', 'solve', 'tensorinv', 'tensorsolve', 'pinv',
'eigvals', 'eig', 'eigvalsh', 'eigh']
Expand Down Expand Up @@ -91,9 +92,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
"""
if hermitian is True:
raise NotImplementedError("hermitian is not supported yet...")
if _mx_nd_np._np.isscalar(rcond):
return _npi.pinv_scalar_rcond(a, rcond, hermitian)
return _npi.pinv(a, rcond, hermitian)
return _api_internal.pinv(a, rcond, hermitian)


# pylint: disable=too-many-return-statements
Expand Down Expand Up @@ -332,7 +331,7 @@ def svd(a):
return tuple(_npi.svd(a))


def cholesky(a):
def cholesky(a, lower=True):
r"""
Cholesky decomposition.
Expand Down Expand Up @@ -388,7 +387,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _npi.cholesky(a)
return _api_internal.cholesky(a, lower)


def inv(a):
Expand Down Expand Up @@ -649,7 +648,7 @@ def tensorinv(a, ind=2):
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
"""
return _npi.tensorinv(a, ind)
return _api_internal.tensorinv(a, ind)


def tensorsolve(a, b, axes=None):
Expand Down Expand Up @@ -697,7 +696,7 @@ def tensorsolve(a, b, axes=None):
>>> np.allclose(np.tensordot(a, x, axes=3), b)
True
"""
return _npi.tensorsolve(a, b, axes)
return _api_internal.tensorsolve(a, b, axes)


def eigvals(a):
Expand Down Expand Up @@ -824,7 +823,7 @@ def eigvalsh(a, UPLO='L'):
>>> LA.eigvalsh(a, UPLO='L')
array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order
"""
return _npi.eigvalsh(a, UPLO)
return _api_internal.eigvalsh(a, UPLO)


def eig(a):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def svd(a):
return _mx_nd_np.linalg.svd(a)


def cholesky(a):
def cholesky(a, lower=True):
r"""
Cholesky decomposition.
Expand Down Expand Up @@ -288,7 +288,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _mx_nd_np.linalg.cholesky(a)
return _mx_nd_np.linalg.cholesky(a, lower)


def inv(a):
Expand Down
4 changes: 2 additions & 2 deletions python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def svd(a):
return _npi.svd(a)


def cholesky(a):
def cholesky(a, lower=True):
r"""
Cholesky decomposition.
Expand Down Expand Up @@ -378,7 +378,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _npi.cholesky(a)
return _npi.cholesky(a, lower)


def inv(a):
Expand Down
48 changes: 48 additions & 0 deletions src/api/operator/numpy/linalg/np_eigvals.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_eigvals.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_eigvals.cc
*/

#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_eigvals-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.eigvalsh")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_eigvalsh");
nnvm::NodeAttrs attrs;
op::EigvalshParam param;
param.UPLO = *((args[1].operator std::string()).c_str());
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::EigvalshParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
72 changes: 72 additions & 0 deletions src/api/operator/numpy/linalg/np_pinv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_pinv.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_pinv.cc
*/

#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_pinv-inl.h"

namespace mxnet {

inline static void _npi_pinv(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_pinv");
op::PinvParam param;
nnvm::NodeAttrs attrs;
param.hermitian = args[2].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::PinvParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 2, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_pinv_scalar_rcond");
op::PinvScalarRcondParam param;
nnvm::NodeAttrs attrs;
param.rcond = args[1].operator double();
param.hermitian = args[2].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::PinvScalarRcondParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

MXNET_REGISTER_API("_npi.pinv")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) {
_npi_pinv_scalar_rcond(args, ret);
} else {
_npi_pinv(args, ret);
}
});

} // namespace mxnet
48 changes: 48 additions & 0 deletions src/api/operator/numpy/linalg/np_potrf.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_potrf.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_potrf.cc
*/

#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_potrf-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.cholesky")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_cholesky");
nnvm::NodeAttrs attrs;
op::LaCholeskyParam param;
param.lower = args[1].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::LaCholeskyParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
48 changes: 48 additions & 0 deletions src/api/operator/numpy/linalg/np_tensorinv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_tensorinv.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_tensorinv.cc
*/

#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_tensorinv-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.tensorinv")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_tensorinv");
nnvm::NodeAttrs attrs;
op::TensorinvParam param;
param.ind = args[1].operator int();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::TensorinvParam>(&attrs);
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, 1, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
Loading

0 comments on commit 2206d8b

Please sign in to comment.