Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
hgt312 committed Oct 21, 2019
1 parent 746cbc5 commit e51d1fd
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 2 deletions.
12 changes: 11 additions & 1 deletion python/mxnet/ndarray/numpy/_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming', 'blackman', 'flip',
'around', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril', 'identity', 'take',
'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal',
'hsplit', 'rot90', 'einsum', 'true_divide']
'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory', 'may_share_memory']


@set_module('mxnet.ndarray.numpy')
Expand Down Expand Up @@ -4761,3 +4761,13 @@ def einsum(*operands, **kwargs):
subscripts = operands[0]
operands = operands[1:]
return _npi.einsum(*operands, subscripts=subscripts, out=out, optimize=int(optimize_arg))


@set_module('mxnet.ndarray.numpy')
def shares_memory(a, b, max_work=None):
return _npi.share_memory(a, b).item()


@set_module('mxnet.ndarray.numpy')
def may_share_memory(a, b, max_work=None):
return _npi.share_memory(a, b).item()
13 changes: 12 additions & 1 deletion python/mxnet/numpy/multiarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,8 @@
'swapaxes', 'clip', 'argmax', 'std', 'var', 'indices', 'copysign', 'ravel', 'hanning', 'hamming',
'blackman', 'flip', 'around', 'arctan2', 'hypot', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'tril',
'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'equal', 'not_equal', 'greater', 'less',
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide']
'greater_equal', 'less_equal', 'hsplit', 'rot90', 'einsum', 'true_divide', 'shares_memory',
'may_share_memory']

# Return code for dispatching indexing function call
_NDARRAY_UNSUPPORTED_INDEXING = -1
Expand Down Expand Up @@ -6419,3 +6420,13 @@ def einsum(*operands, **kwargs):
... np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=True)
"""
return _mx_nd_np.einsum(*operands, **kwargs)


@set_module('mxnet.numpy')
def shares_memory(a, b, max_work=None):
return _mx_nd_np.shares_memory(a, b, max_work)


@set_module('mxnet.numpy')
def may_share_memory(a, b, max_work=None):
return _mx_nd_np.may_share_memory(a, b, max_work)
62 changes: 62 additions & 0 deletions src/operator/numpy/np_memory_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_memory_op.cc
*/

#include "./np_memory_op.h"

namespace mxnet {
namespace op {

inline bool NumpyShareMemoryType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kBool);
return out_attrs->at(0) != -1;
}

inline bool NumpyShareMemoryShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
CHECK_EQ(in_attrs->size(), 2U);
CHECK_EQ(out_attrs->size(), 1U);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape(0, -1));
return true;
}

NNVM_REGISTER_OP(_npi_share_memory)
.set_num_inputs(2)
.set_num_outputs(1)
.set_attr<nnvm::FListInputNames>("FListInputNames",
[](const NodeAttrs& attrs) {
return std::vector<std::string>{"a", "b"};
})
.set_attr<mxnet::FInferShape>("FInferShape", NumpyShareMemoryShape)
.set_attr<nnvm::FInferType>("FInferType", NumpyShareMemoryType)
.set_attr<FCompute>("FCompute<cpu>", NumpyShareMemoryCompute<cpu>)
.add_argument("a", "NDArray-or-Symbol", "First input")
.add_argument("b", "NDArray-or-Symbol", "Second input");

} // namespace op
} // namespace mxnet
34 changes: 34 additions & 0 deletions src/operator/numpy/np_memory_op.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_memory_op.cu
*/

#include "./np_memory_op.h"

namespace mxnet {
namespace op {

NNVM_REGISTER_OP(_npi_share_memory)
.set_attr<FCompute>("FCompute<gpu>", NumpyShareMemoryCompute<gpu>);

} // namespace op
} // namespace mxnet
69 changes: 69 additions & 0 deletions src/operator/numpy/np_memory_op.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* Copyright (c) 2019 by Contributors
* \file np_memory_op.h
* \brief Function definition of numpy memory op
*/

#ifndef MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_

#include <vector>
#include <string>
#include "../operator_common.h"

namespace mxnet {
namespace op {

template<typename xpu>
void NumpyShareMemoryCompute(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
const std::vector<TBlob>& inputs,
const std::vector<OpReqType>& req,
const std::vector<TBlob>& outputs) {
using namespace mxnet_op;
using namespace mshadow;
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
const TBlob& a = inputs[0];
const TBlob& b = inputs[1];
const TBlob& outdata = outputs[0];

if (a.Size() == 0 || b.Size() == 0) {
*(outdata.dptr<bool>()) = false;
return;
}
uint64_t start1 = reinterpret_cast<uint64_t>(a.dptr_);
uint64_t end1 = start1 + a.Size();
uint64_t start2 = reinterpret_cast<uint64_t>(b.dptr_);
uint64_t end2 = start2 + b.Size();
if (!(start1 < end2 && start2 < end1 && start1 < end1 && start2 < end2)) {
*(outdata.dptr<bool>()) = false;
} else {
*(outdata.dptr<bool>()) = true;
}
return;
}

} // namespace op
} // namespace mxnet

#endif // MXNET_OPERATOR_NUMPY_NP_MEMORY_OP_H_

0 comments on commit e51d1fd

Please sign in to comment.