Skip to content

Commit c109eb6

Browse files
reminisceeric-haibin-lin
authored and committed
sparse_retain op (#66)
* Initial checkin * Fix bugs * Add unit test for sparse_retain * Add example and modify test
1 parent 5c0685b commit c109eb6

File tree

6 files changed

+301
-20
lines changed

6 files changed

+301
-20
lines changed

include/mxnet/ndarray.h

+11-11
Original file line numberDiff line numberDiff line change
@@ -115,44 +115,44 @@ class NDArray {
115115
}
116116
/*! \brief constructor for NDArray with storage type
117117
*/
118-
NDArray(const NDArrayStorageType storage_type, const TShape &shape, Context ctx,
118+
NDArray(const NDArrayStorageType stype, const TShape &shape, Context ctx,
119119
bool delay_alloc = true, int dtype = mshadow::default_type_flag,
120120
std::vector<int> aux_types = {}, std::vector<TShape> aux_shapes = {},
121121
TShape storage_shape = TShape(mshadow::Shape1(0)))
122122
: shape_(shape), offset_(0), dtype_(dtype), entry_({nullptr, 0, 0}) {
123123
// Assign default aux types if not given
124124
if (aux_types.size() == 0) {
125-
if (storage_type == kRowSparseStorage) {
125+
if (stype == kRowSparseStorage) {
126126
aux_types = {ROW_SPARSE_IDX_TYPE};
127-
} else if (storage_type == kCSRStorage) {
127+
} else if (stype == kCSRStorage) {
128128
aux_types = {CSR_IND_PTR_TYPE, CSR_IDX_DTYPE};
129129
} else {
130-
LOG(FATAL) << "Unknown storage type" << storage_type;
130+
LOG(FATAL) << "Unknown storage type " << stype;
131131
}
132132
}
133133
// Assign default shapes if not given
134134
// unknown shapes are initialized as {0} such that Size() would return 0
135135
if (aux_shapes.size() == 0) {
136-
if (storage_type == kRowSparseStorage) {
136+
if (stype == kRowSparseStorage) {
137137
aux_shapes = {TShape(mshadow::Shape1(0))};
138-
} else if (storage_type == kCSRStorage) {
138+
} else if (stype == kCSRStorage) {
139139
// aux shapes for indptr and indices
140140
aux_shapes = {TShape(mshadow::Shape1(0)), TShape(mshadow::Shape1(0))};
141141
} else {
142-
LOG(FATAL) << "Unknown storage type" << storage_type;
142+
LOG(FATAL) << "Unknown storage type " << stype;
143143
}
144144
}
145145
if (storage_shape.Size() == 0) {
146-
if (storage_type == kRowSparseStorage) {
146+
if (stype == kRowSparseStorage) {
147147
storage_shape = shape;
148148
storage_shape[0] = aux_shapes[rowsparse::kIdx][0];
149-
} else if (storage_type == kCSRStorage) {
149+
} else if (stype == kCSRStorage) {
150150
storage_shape = aux_shapes[csr::kIdx];
151151
} else {
152-
LOG(FATAL) << "Unknown storage type" << storage_type;
152+
LOG(FATAL) << "Unknown storage type " << stype;
153153
}
154154
}
155-
ptr_ = std::make_shared<Chunk>(storage_type, storage_shape, ctx, delay_alloc,
155+
ptr_ = std::make_shared<Chunk>(stype, storage_shape, ctx, delay_alloc,
156156
dtype, aux_types, aux_shapes);
157157
#if MKL_EXPERIMENTAL == 1
158158
Mkl_mem_ = std::make_shared<MKLMemHolder>();

python/mxnet/test_utils.py

+21-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import numpy.random as rnd
1717
import mxnet as mx
1818
from .context import Context
19-
from .ndarray import array
19+
from .ndarray import array, _STORAGE_TYPE_STR_TO_ID
2020
from .symbol import Symbol
2121
try:
2222
import requests
@@ -67,6 +67,15 @@ def random_arrays(*shapes):
6767
return arrays[0]
6868
return arrays
6969

70+
71+
def random_sample(population, k):
    """Return a k length list of the elements chosen from the population sequence.

    Parameters
    ----------
    population : sequence
        Elements to draw from. It is never modified.
    k : int
        Number of elements to return; must satisfy ``0 <= k <= len(population)``.

    Returns
    -------
    list
        ``k`` distinct positions of ``population`` in random order.
    """
    assert 0 <= k <= len(population)
    # Copy into a fresh list before shuffling: slicing a numpy array yields a
    # view (shuffling it would reorder the caller's data), and slicing a tuple or
    # range yields an immutable object that cannot be shuffled in place.
    population_copy = list(population)
    np.random.shuffle(population_copy)
    return population_copy[0:k]
77+
78+
7079
# TODO(haibin) also include types in arguments
7180
def rand_sparse_ndarray(shape, storage_type, density=None):
7281
"""Generate a random sparse ndarray. Returns the ndarray, value(np) and indices(np) """
@@ -457,7 +466,8 @@ def numeric_grad(executor, location, aux_states=None, eps=1e-4, use_forward_trai
457466

458467

459468
def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rtol=1e-2,
460-
atol=None, grad_nodes=None, use_forward_train=True, ctx=None):
469+
atol=None, grad_nodes=None, use_forward_train=True, ctx=None,
470+
grad_stype_dict=None):
461471
"""Verify an operation by checking backward pass via finite difference method.
462472
463473
Based on Theano's `theano.gradient.verify_grad` [1]
@@ -474,7 +484,7 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
474484
- if type is dict of str -> numpy.ndarray
475485
maps the name of arguments to the corresponding numpy.ndarray.
476486
*In either case, value of all the arguments must be provided.*
477-
aux_states : ist or tuple or dict, optional
487+
aux_states : list or tuple or dict, optional
478488
The auxiliary states required when generating the executor for the symbol.
479489
numeric_eps : float, optional
480490
Delta for the finite difference method that approximates the gradient.
@@ -486,6 +496,8 @@ def check_numeric_gradient(sym, location, aux_states=None, numeric_eps=1e-3, rto
486496
Whether to use is_train=True when computing the finite-difference.
487497
ctx : Context, optional
488498
Check the gradient computation on the specified device.
499+
grad_stype_dict : dict of str->str, optional
500+
Storage type dictionary for gradient ndarrays.
489501
References
490502
---------
491503
..[1] https://github.com/Theano/Theano/blob/master/theano/gradient.py
@@ -509,7 +521,7 @@ def random_projection(shape):
509521
location_npy = {k:v.asnumpy() for k, v in location.items()}
510522
aux_states = _parse_aux_states(sym=sym, aux_states=aux_states, ctx=ctx)
511523
if aux_states is not None:
512-
aux_states_npy = {k:v.asnumpy() for k, v in aux_states.items()}
524+
aux_states_npy = {k: v.asnumpy() for k, v in aux_states.items()}
513525
else:
514526
aux_states_npy = None
515527
if grad_nodes is None:
@@ -536,6 +548,11 @@ def random_projection(shape):
536548
+ [("__random_proj", _rng.normal(0, 0.01, size=out_shape[0]))])
537549

538550
args_grad = {k: mx.nd.array(v, ctx=ctx) for k, v in args_grad_npy.items()}
551+
if grad_stype_dict is not None:
552+
assert isinstance(grad_stype_dict, dict), "grad_stype_dict must be a dict"
553+
for k, v in grad_stype_dict.items():
554+
if k in args_grad and v in _STORAGE_TYPE_STR_TO_ID and v != 'default':
555+
args_grad[k] = mx.nd.cast_storage(args_grad[k], storage_type=v)
539556

540557
executor = out.bind(ctx, grad_req=grad_req,
541558
args=location, args_grad=args_grad, aux_states=aux_states)

src/operator/tensor/indexing_op.cc

+41
Original file line numberDiff line numberDiff line change
@@ -264,5 +264,46 @@ Examples::
264264
.add_argument("indices", "NDArray-or-Symbol", "array of locations where to set on_value")
265265
.add_arguments(OneHotParam::__FIELDS__());
266266

267+
NNVM_REGISTER_OP(sparse_retain)
268+
.describe(R"code(pick rows specified by user input index array from a row sparse matrix
269+
and save them in the output sparse matrix.
270+
271+
Example::
272+
273+
data = [[1, 2], [3, 4], [5, 6]]
274+
indices = [0, 1, 3]
275+
shape = (4, 2)
276+
rsp_in = row_sparse(data, indices)
277+
to_retain = [0, 3]
278+
rsp_out = sparse_retain(rsp_in, to_retain)
279+
rsp_out.values = [[1, 2], [5, 6]]
280+
rsp_out.indices = [0, 3]
281+
282+
)code" ADD_FILELINE)
283+
.set_num_inputs(2)
284+
.set_num_outputs(1)
285+
.set_attr<nnvm::FListInputNames>("FListInputNames",
286+
[](const NodeAttrs& attrs) {
287+
return std::vector<std::string>{"data", "indices"};
288+
})
289+
.set_attr<nnvm::FInferShape>("FInferShape", SparseRetainOpShape)
290+
.set_attr<nnvm::FInferType>("FInferType", SparseRetainOpType)
291+
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseRetainForwardInferStorageType)
292+
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseRetainOpForwardEx<cpu>)
293+
.set_attr<nnvm::FGradient>("FGradient",
294+
[](const nnvm::NodePtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
295+
return MakeNonlossGradNode("_backward_sparse_retain", n, ograds,
296+
{n->inputs[sr::kIdx]}, n->attrs.dict);
297+
})
298+
.add_argument("data", "NDArray-or-Symbol", "The input array for sparse_retain operator.")
299+
.add_argument("indices", "NDArray-or-Symbol", "The index array of rows ids that will be retained.");
300+
301+
NNVM_REGISTER_OP(_backward_sparse_retain)
302+
.set_num_inputs(2)
303+
.set_num_outputs(2)
304+
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
305+
.set_attr<nnvm::FInferStorageType>("FInferStorageType", SparseRetainBackwardInferStorageType)
306+
.set_attr<FComputeEx>("FComputeEx<cpu>", SparseRetainOpBackwardEx<cpu>);
307+
267308
} // namespace op
268309
} // namespace mxnet

src/operator/tensor/indexing_op.cu

+6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,12 @@ NNVM_REGISTER_OP(batch_take)
2626
NNVM_REGISTER_OP(one_hot)
2727
.set_attr<FCompute>("FCompute<gpu>", OneHotOpForward<gpu>);
2828

29+
NNVM_REGISTER_OP(sparse_retain)
30+
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseRetainOpForwardEx<gpu>);
31+
32+
NNVM_REGISTER_OP(_backward_sparse_retain)
33+
.set_attr<FComputeEx>("FComputeEx<gpu>", SparseRetainOpBackwardEx<gpu>);
34+
2935
} // namespace op
3036
} // namespace mxnet
3137

src/operator/tensor/indexing_op.h

+193
Original file line numberDiff line numberDiff line change
@@ -795,6 +795,199 @@ void OneHotOpForward(const nnvm::NodeAttrs& attrs,
795795
});
796796
}
797797

798+
/*!
 * \brief positional slots of the sparse_retain operator's inputs and outputs
 */
namespace sr {
/*! \brief input slots: kArr is the array rows are picked from, kIdx the row ids to retain */
enum SparseRetainOpInputs {
  kArr,
  kIdx
};
/*! \brief output slots: kOut is the single output */
enum SparseRetainOpOutputs {
  kOut
};
}  // namespace sr
805+
806+
/*!
 * \brief Shape inference for sparse_retain: the data input and the output share one shape.
 * \param attrs node attributes (unused)
 * \param in_attrs shapes of the two inputs (data, indices); the data shape may be refined
 * \param out_attrs shape of the single output; may be refined
 * \return always true once the assignment checks have passed
 */
inline bool SparseRetainOpShape(const nnvm::NodeAttrs& attrs,
                                std::vector<TShape> *in_attrs,
                                std::vector<TShape> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U)
    << "sparse_retain operator takes 2 arguments (" << in_attrs->size() << " given)";
  CHECK_EQ(out_attrs->size(), 1U);

  // Start from the data shape, fold in whatever is already known about the
  // output, then write the merged shape back to both sides.
  TShape merged_shape = (*in_attrs)[sr::kArr];
  shape_assign(&merged_shape, (*out_attrs)[sr::kOut]);
  SHAPE_ASSIGN_CHECK(*in_attrs, sr::kArr, merged_shape);
  SHAPE_ASSIGN_CHECK(*out_attrs, sr::kOut, merged_shape);
  return true;
}
819+
820+
/*!
 * \brief Type inference for sparse_retain: the output takes the data input's dtype.
 * \param attrs node attributes (unused)
 * \param in_attrs type flags of the two inputs (data, indices); the index type must be set
 * \param out_attrs type flag of the single output
 * \return true when the data type is known (not -1) after propagation
 */
inline bool SparseRetainOpType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *in_attrs,
                               std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  CHECK_NE((*in_attrs)[sr::kIdx], -1) << "Index type must be set for sparse_retain operator";

  // Propagate in both directions so either side can supply the dtype.
  TYPE_ASSIGN_CHECK(*out_attrs, sr::kOut, (*in_attrs)[sr::kArr]);
  TYPE_ASSIGN_CHECK(*in_attrs, sr::kArr, (*out_attrs)[sr::kOut]);
  const bool data_type_known = (*in_attrs)[sr::kArr] != -1;
  return data_type_known;
}
831+
832+
/*!
 * \brief Storage type inference for the forward pass of sparse_retain.
 *
 * The output is marked row sparse when the data input is row sparse;
 * otherwise the output storage type is left as it was.
 * \param attrs node attributes (unused)
 * \param in_attrs storage types of the two inputs (data, indices)
 * \param out_attrs storage type of the single output
 * \return always true
 */
inline bool SparseRetainForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                std::vector<int> *in_attrs,
                                                std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 1U);
  const bool input_is_row_sparse = (in_attrs->at(sr::kArr) == kRowSparseStorage);
  if (input_is_row_sparse) {
    out_attrs->at(sr::kOut) = kRowSparseStorage;
  }
  return true;
}
842+
843+
/*!
 * \brief Storage type inference for the backward pass of sparse_retain.
 *
 * The gradient w.r.t. the data input is always row sparse and the gradient
 * slot for the indices is always default storage.
 * \param attrs node attributes (unused)
 * \param in_attrs storage types of the two backward inputs (only the count is checked)
 * \param out_attrs storage types of the two backward outputs; both are overwritten
 * \return always true
 */
inline bool SparseRetainBackwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                 std::vector<int> *in_attrs,
                                                 std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 2U);
  CHECK_EQ(out_attrs->size(), 2U);
  int& data_grad_stype = out_attrs->at(sr::kArr);
  int& index_grad_stype = out_attrs->at(sr::kIdx);
  data_grad_stype = kRowSparseStorage;
  index_grad_stype = kDefaultStorage;
  return true;
}
852+
853+
struct SparseRetainRspForward {
854+
template<typename DType, typename RType, typename IType>
855+
MSHADOW_XINLINE static void Map(int i, DType* out_data, RType* out_idx,
856+
const DType* in_data, const RType* in_idx,
857+
const IType* idx, const size_t nnr,
858+
const size_t num_cols) {
859+
const RType irow = idx[i];
860+
int j = -1, left = 0, right = nnr - 1;
861+
while (left <= right) {
862+
int m = left + (right - left) / 2;
863+
const auto in_idx_m = in_idx[m];
864+
if (in_idx_m == irow) {
865+
j = m;
866+
break;
867+
} else if (in_idx_m < irow) {
868+
left = m + 1;
869+
} else {
870+
right = m - 1;
871+
}
872+
}
873+
out_idx[i] = idx[i];
874+
if (j >= 0) {
875+
const size_t in_offset = j * num_cols;
876+
const size_t out_offset = i * num_cols;
877+
for (size_t k = 0; k < num_cols; ++k) {
878+
out_data[out_offset+k] = in_data[in_offset+k];
879+
}
880+
}
881+
}
882+
};
883+
884+
template<typename xpu>
885+
void SparseRetainOpForwardEx(const nnvm::NodeAttrs& attrs,
886+
const OpContext& ctx,
887+
const std::vector<NDArray>& inputs,
888+
const std::vector<OpReqType>& req,
889+
const std::vector<NDArray>& outputs) {
890+
CHECK_EQ(inputs.size(), 2U);
891+
CHECK_EQ(outputs.size(), 1U);
892+
CHECK_EQ(req.size(), 1U);
893+
CHECK_EQ(req[sr::kOut], kWriteTo) << "sparse_retain only supports req=\'write\'";
894+
895+
CHECK_EQ(inputs[sr::kArr].storage_type(), kRowSparseStorage)
896+
<< "sparse_retain operator only takes row sparse NDArray as input";
897+
CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
898+
<< "sparse_retain operator only takes default NDArray as its index array";
899+
CHECK_EQ(outputs[sr::kOut].storage_type(), kRowSparseStorage)
900+
<< "sparse_retain operator only outputs row sparse NDArray";
901+
902+
const NDArray& input_nd = inputs[sr::kArr];
903+
const TBlob idx_data = inputs[sr::kIdx].data();
904+
905+
if (req[sr::kOut] == kNullOp
906+
|| !input_nd.storage_initialized()
907+
|| idx_data.Size() == 0U) return;
908+
909+
const TBlob input_data = input_nd.data();
910+
if (input_data.shape_[0] == 0) return;
911+
const TBlob input_idx = input_nd.aux_data(rowsparse::kIdx);
912+
913+
NDArray output_nd = outputs[sr::kOut];
914+
output_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
915+
TBlob output_data = output_nd.data();
916+
TBlob output_idx = output_nd.aux_data(rowsparse::kIdx);
917+
918+
using namespace mxnet_op;
919+
Stream<xpu> *s = ctx.get_stream<xpu>();
920+
MSHADOW_TYPE_SWITCH(output_data.type_flag_, DType, { // output data type
921+
MSHADOW_INT_TYPE_SWITCH(output_idx.type_flag_, RType, { // row index data type
922+
MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type
923+
Kernel<set_zero, xpu>::Launch(s, output_data.Size(), output_data.dptr<DType>());
924+
Kernel<SparseRetainRspForward, xpu>::Launch(s, idx_data.Size(), output_data.dptr<DType>(),
925+
output_idx.dptr<RType>(), input_data.dptr<DType>(), input_idx.dptr<RType>(),
926+
idx_data.dptr<IType>(), input_data.shape_[0], input_data.shape_[1]);
927+
});
928+
});
929+
});
930+
}
931+
932+
template<int req>
933+
struct SparseRetainRspBackward {
934+
template<typename DType, typename RType, typename IType>
935+
MSHADOW_XINLINE static void Map(int i, DType* in_grad, RType* in_grad_idx,
936+
const DType* out_grad, const IType* idx,
937+
const size_t num_cols) {
938+
const RType irow = idx[i];
939+
in_grad_idx[i] = irow;
940+
const size_t out_offset = irow * num_cols;
941+
const size_t in_offset = i * num_cols;
942+
for (size_t j = 0; j < num_cols; ++j) {
943+
KERNEL_ASSIGN(in_grad[in_offset+j], req, out_grad[out_offset+j]);
944+
}
945+
}
946+
};
947+
948+
template<typename xpu>
949+
void SparseRetainOpBackwardEx(const nnvm::NodeAttrs& attrs,
950+
const OpContext& ctx,
951+
const std::vector<NDArray>& inputs,
952+
const std::vector<OpReqType>& req,
953+
const std::vector<NDArray>& outputs) {
954+
CHECK_EQ(inputs.size(), 2U);
955+
CHECK_EQ(outputs.size(), 2U);
956+
CHECK_EQ(req.size(), 2U);
957+
CHECK_NE(req[sr::kArr], kWriteInplace);
958+
CHECK_EQ(req[sr::kIdx], kNullOp)
959+
<< "sparse_retain does not support calculating gradients of indices";
960+
961+
CHECK_EQ(inputs[sr::kOut].storage_type(), kDefaultStorage)
962+
<< "sparse_retain backward only takes default NDArray as ograd";
963+
CHECK_EQ(inputs[sr::kIdx].storage_type(), kDefaultStorage)
964+
<< "sparse_retain backward only takes default NDArray as its index array";
965+
CHECK_EQ(outputs[sr::kArr].storage_type(), kRowSparseStorage)
966+
<< "sparse_retain backward only outputs row sparse NDArray as grad of input";
967+
968+
const TBlob out_grad_data = inputs[sr::kOut].data();
969+
const TBlob idx_data = inputs[sr::kIdx].data();
970+
971+
NDArray in_grad_nd = outputs[sr::kArr];
972+
in_grad_nd.CheckAndAlloc({mshadow::Shape1(idx_data.Size())});
973+
TBlob in_grad_data = in_grad_nd.data();
974+
TBlob in_grad_idx = in_grad_nd.aux_data(rowsparse::kIdx);
975+
976+
using namespace mxnet_op;
977+
Stream<xpu> *s = ctx.get_stream<xpu>();
978+
MSHADOW_TYPE_SWITCH(out_grad_data.type_flag_, DType, { // output data type
979+
MSHADOW_INT_TYPE_SWITCH(in_grad_idx.type_flag_, RType, { // row index data type
980+
MSHADOW_TYPE_SWITCH(idx_data.type_flag_, IType, { // index array data type
981+
MXNET_ASSIGN_REQ_SWITCH(req[sr::kArr], req_type, {
982+
Kernel<SparseRetainRspBackward<req_type>, xpu>::Launch(
983+
s, in_grad_idx.Size(), in_grad_data.dptr<DType>(), in_grad_idx.dptr<RType>(),
984+
out_grad_data.dptr<DType>(), idx_data.dptr<IType>(), out_grad_data.shape_[1]);
985+
});
986+
});
987+
});
988+
});
989+
}
990+
798991
} // namespace op
799992
} // namespace mxnet
800993
#ifdef __CUDACC__

0 commit comments

Comments
 (0)