Skip to content

Commit

Permalink
* impl - FFi for linalg op (apache#17795)
Browse files Browse the repository at this point in the history
* fix - cpplint

* impl - benchmark ffi for ops

* rm - FFI for ops with param

* fix - makefile

* fix - not include unordered_map and use num_inputs

* ci - compiler error

* fix - change cholesky interface

Co-authored-by: Ubuntu <ubuntu@ip-172-31-10-214.us-east-2.compute.internal>
  • Loading branch information
2 people authored and Vladimir Cherepanov committed Apr 7, 2020
1 parent 886b90b commit bc1decd
Show file tree
Hide file tree
Showing 21 changed files with 440 additions and 17 deletions.
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,9 @@ endif

all: lib/libmxnet.a lib/libmxnet.so $(BIN) extra-packages extension_libs

SRC = $(wildcard src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
SRC = $(wildcard src/*/*/*/*/*.cc src/*/*/*/*.cc src/*/*/*.cc src/*/*.cc src/*.cc)
OBJ = $(patsubst %.cc, build/%.o, $(SRC))
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
CUSRC = $(wildcard src/*/*/*/*.cu src/*/*/*/*.cu src/*/*/*.cu src/*/*.cu src/*.cu)
CUOBJ = $(patsubst %.cu, build/%_gpu.o, $(CUSRC))

ifeq ($(USE_TVM_OP), 1)
Expand Down
8 changes: 8 additions & 0 deletions benchmark/python/ffi/benchmark_ffi.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,14 @@ def prepare_workloads():
OpArgMngr.add_workload("cumsum", pool['3x2'], axis=0, out=pool['3x2'])
OpArgMngr.add_workload("add", pool['2x2'], pool['2x2'])
OpArgMngr.add_workload("linalg.svd", pool['3x3'])
OpArgMngr.add_workload("linalg.cholesky", pool['1x1'])
OpArgMngr.add_workload("linalg.eigvals", pool['1x1'])
OpArgMngr.add_workload("linalg.eigvalsh", pool['1x1'], UPLO='L')
OpArgMngr.add_workload("linalg.inv", pool['1x1'])
OpArgMngr.add_workload("linalg.pinv", pool['2x3x3'], pool['1'], hermitian=False)
OpArgMngr.add_workload("linalg.solve", pool['1x1'], pool['1'])
OpArgMngr.add_workload("linalg.tensorinv", pool['1x1'], ind=2)
OpArgMngr.add_workload("linalg.tensorsolve", pool['1x1x1'], pool['1x1x1'], (2, 0, 1))
OpArgMngr.add_workload("split", pool['3x3'], (0, 1, 2), axis=1)
OpArgMngr.add_workload("argmax", pool['3x2'], axis=-1)
OpArgMngr.add_workload("argmin", pool['3x2'], axis=-1)
Expand Down
18 changes: 8 additions & 10 deletions python/mxnet/ndarray/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,9 +92,7 @@ def pinv(a, rcond=1e-15, hermitian=False):
"""
if hermitian is True:
raise NotImplementedError("hermitian is not supported yet...")
if _mx_nd_np._np.isscalar(rcond):
return _npi.pinv_scalar_rcond(a, rcond, hermitian)
return _npi.pinv(a, rcond, hermitian)
return _api_internal.pinv(a, rcond, hermitian)


# pylint: disable=too-many-return-statements
Expand Down Expand Up @@ -389,7 +387,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _npi.cholesky(a)
return _api_internal.cholesky(a, True)


def inv(a):
Expand Down Expand Up @@ -431,7 +429,7 @@ def inv(a):
[[-1.2500001 , 0.75000006],
[ 0.75000006, -0.25000003]]])
"""
return _npi.inv(a)
return _api_internal.inv(a)


def det(a):
Expand Down Expand Up @@ -595,7 +593,7 @@ def solve(a, b):
>>> np.allclose(np.dot(a, x), b)
True
"""
return _npi.solve(a, b)
return _api_internal.solve(a, b)


def tensorinv(a, ind=2):
Expand Down Expand Up @@ -650,7 +648,7 @@ def tensorinv(a, ind=2):
>>> np.allclose(np.tensordot(ainv, b, 1), np.linalg.tensorsolve(a, b))
True
"""
return _npi.tensorinv(a, ind)
return _api_internal.tensorinv(a, ind)


def tensorsolve(a, b, axes=None):
Expand Down Expand Up @@ -698,7 +696,7 @@ def tensorsolve(a, b, axes=None):
>>> np.allclose(np.tensordot(a, x, axes=3), b)
True
"""
return _npi.tensorsolve(a, b, axes)
return _api_internal.tensorsolve(a, b, axes)


def eigvals(a):
Expand Down Expand Up @@ -766,7 +764,7 @@ def eigvals(a):
>>> LA.eigvals(A)
array([ 1., -1.]) # random
"""
return _npi.eigvals(a)
return _api_internal.eigvals(a)


def eigvalsh(a, UPLO='L'):
Expand Down Expand Up @@ -825,7 +823,7 @@ def eigvalsh(a, UPLO='L'):
>>> LA.eigvalsh(a, UPLO='L')
array([-2.87381886, 5.10144682, 6.38623114]) # in ascending order
"""
return _npi.eigvalsh(a, UPLO)
return _api_internal.eigvalsh(a, UPLO)


def eig(a):
Expand Down
2 changes: 1 addition & 1 deletion python/mxnet/symbol/numpy/linalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def cholesky(a):
array([[16., 4.],
[ 4., 10.]])
"""
return _npi.cholesky(a)
return _npi.cholesky(a, True)


def inv(a):
Expand Down
61 changes: 61 additions & 0 deletions src/api/operator/numpy/linalg/np_eigvals.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_eigvals.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_eigvals.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_eigvals-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.eigvals")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_eigvals");
nnvm::NodeAttrs attrs;
attrs.op = op;
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

MXNET_REGISTER_API("_npi.eigvalsh")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_eigvalsh");
nnvm::NodeAttrs attrs;
op::EigvalshParam param;
param.UPLO = *((args[1].operator std::string()).c_str());
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::EigvalshParam>(&attrs);
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
43 changes: 43 additions & 0 deletions src/api/operator/numpy/linalg/np_inv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_inv.cc
* \brief Implementation of the API of functions in src/operator/tensor/la_op.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.inv")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_inv");
nnvm::NodeAttrs attrs;
attrs.op = op;
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
73 changes: 73 additions & 0 deletions src/api/operator/numpy/linalg/np_pinv.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_pinv.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_pinv.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_pinv-inl.h"

namespace mxnet {

inline static void _npi_pinv(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_pinv");
op::PinvParam param;
nnvm::NodeAttrs attrs;
param.hermitian = args[2].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::PinvParam>(&attrs);
int num_inputs = 2;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

inline static void _npi_pinv_scalar_rcond(runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_pinv_scalar_rcond");
op::PinvScalarRcondParam param;
nnvm::NodeAttrs attrs;
param.rcond = args[1].operator double();
param.hermitian = args[2].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::PinvScalarRcondParam>(&attrs);
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
}

MXNET_REGISTER_API("_npi.pinv")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
if (args[1].type_code() == kDLFloat || args[1].type_code() == kDLInt) {
_npi_pinv_scalar_rcond(args, ret);
} else {
_npi_pinv(args, ret);
}
});

} // namespace mxnet
48 changes: 48 additions & 0 deletions src/api/operator/numpy/linalg/np_potrf.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_potrf.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_potrf.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"
#include "../../../../operator/numpy/linalg/np_potrf-inl.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.cholesky")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_cholesky");
nnvm::NodeAttrs attrs;
op::LaCholeskyParam param;
param.lower = args[1].operator bool();
attrs.parsed = param;
attrs.op = op;
SetAttrDict<op::LaCholeskyParam>(&attrs);
int num_inputs = 1;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
43 changes: 43 additions & 0 deletions src/api/operator/numpy/linalg/np_solve.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file np_solve.cc
* \brief Implementation of the API of functions in src/operator/numpy/linalg/np_solve.cc
*/
#include <mxnet/api_registry.h>
#include <mxnet/runtime/packed_func.h>
#include "../../utils.h"

namespace mxnet {

MXNET_REGISTER_API("_npi.solve")
.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
using namespace runtime;
const nnvm::Op* op = Op::Get("_npi_solve");
nnvm::NodeAttrs attrs;
attrs.op = op;
int num_inputs = 2;
int num_outputs = 0;
NDArray* inputs[] = {args[0].operator mxnet::NDArray*(), args[1].operator mxnet::NDArray*()};
auto ndoutputs = Invoke(op, &attrs, num_inputs, inputs, &num_outputs, nullptr);
*ret = reinterpret_cast<mxnet::NDArray*>(ndoutputs[0]);
});

} // namespace mxnet
Loading

0 comments on commit bc1decd

Please sign in to comment.