Skip to content

Commit

Permalink
Update csr sample op (#39)
Browse files Browse the repository at this point in the history
* add csr_neighborhood op

* update neighborhood sample

* Update csr_neighborhood_sample-inl.h

* Update csr_neighborhood_sample-inl.h

* Update csr_neighborhood_sample.cc

* Update csr_neighborhood_sample-inl.h

* Update csr_neighborhood_sample.cc

* Update csr_neighborhood_sample-inl.h
  • Loading branch information
aksnzhy authored and zheng-da committed Nov 10, 2018
1 parent 2d92dd8 commit f966cd9
Show file tree
Hide file tree
Showing 2 changed files with 153 additions and 34 deletions.
150 changes: 123 additions & 27 deletions src/operator/contrib/csr_neighborhood_sample-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@
#include <vector>
#include <cstdlib>
#include <ctime>
#include <map>
#include <unordered_map>
#include <algorithm>
#include <queue>
Expand All @@ -46,6 +47,7 @@ namespace op {

typedef int64_t dgl_id_t;

// Input
//------------------------------------------------------------------------------
// input[0]: Graph
// input[1]: seed_vertices
Expand All @@ -54,6 +56,12 @@ typedef int64_t dgl_id_t;
// args[2]: max_num_vertices
//------------------------------------------------------------------------------

// Output
//------------------------------------------------------------------------------
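// output[0]: sampled_vertices (max_num_vertices entries; unused slots are -1)
// output[1]: sampled_csr_graph (CSR; max_num_vertices rows, same columns as input[0])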
//------------------------------------------------------------------------------

// For BFS traversal
struct ver_node {
dgl_id_t vertex_id;
Expand Down Expand Up @@ -82,15 +90,17 @@ static bool CSRNeighborSampleStorageType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 2);

CHECK_EQ(in_attrs->at(0), mxnet::kCSRStorage);

CHECK_EQ(in_attrs->at(1), mxnet::kDefaultStorage);

bool success = true;
if (!type_assign(&(*out_attrs)[0], mxnet::kDefaultStorage)) {
success = false;
success = false;
}
if (!type_assign(&(*out_attrs)[1], mxnet::kCSRStorage)) {
success = false;
}

*dispatch_mode = DispatchMode::kFComputeEx;
Expand All @@ -102,7 +112,7 @@ static bool CSRNeighborSampleShape(const nnvm::NodeAttrs& attrs,
std::vector<TShape> *in_attrs,
std::vector<TShape> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
CHECK_EQ(out_attrs->size(), 2);

CHECK_EQ(in_attrs->at(0).ndim(), 2U);
CHECK_EQ(in_attrs->at(1).ndim(), 1U);
Expand All @@ -116,37 +126,52 @@ static bool CSRNeighborSampleShape(const nnvm::NodeAttrs& attrs,
out_shape[0] = params.max_num_vertices;
SHAPE_ASSIGN_CHECK(*out_attrs, 0, out_shape);

TShape out_csr_shape(2);
out_csr_shape[0] = params.max_num_vertices;
out_csr_shape[1] = in_attrs->at(0)[1];
SHAPE_ASSIGN_CHECK(*out_attrs, 1, out_csr_shape);

return out_attrs->at(0).ndim() != 0U &&
out_attrs->at(0).Size() != 0U;
out_attrs->at(0).Size() != 0U &&
out_attrs->at(1).ndim() != 0U &&
out_attrs->at(1).Size() != 0U;
}

static bool CSRNeighborSampleType(const nnvm::NodeAttrs& attrs,
std::vector<int> *in_attrs,
std::vector<int> *out_attrs) {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(out_attrs->size(), 1);
out_attrs->at(0) = in_attrs->at(0);
CHECK_EQ(out_attrs->size(), 2);

TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(1));
TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(0));

return out_attrs->at(0) != -1;
}

static void GetSrcList(const dgl_id_t* col_list,
static void GetSrcList(const dgl_id_t* val_list,
const dgl_id_t* col_list,
const dgl_id_t* indptr,
const dgl_id_t dst_id,
std::vector<dgl_id_t>& src_list) {
std::vector<dgl_id_t>& src_list,
std::vector<dgl_id_t>& edge_list) {
for (dgl_id_t i = *(indptr+dst_id); i < *(indptr+dst_id+1); ++i) {
src_list.push_back(col_list[i]);
edge_list.push_back(val_list[i]);
}
}

static void GetSample(std::vector<dgl_id_t>& ver_list,
std::vector<dgl_id_t>& edge_list,
const size_t max_num_neighbor,
std::vector<dgl_id_t>& out) {
std::vector<dgl_id_t>& out,
std::vector<dgl_id_t>& out_edge) {
CHECK_EQ(ver_list.size(), edge_list.size());
// Not more candidates than requested: copy ver_list (and its edges) to the output unsampled
if (ver_list.size() <= max_num_neighbor) {
for (size_t i = 0; i < ver_list.size(); ++i) {
out.push_back(ver_list[i]);
out_edge.push_back(edge_list[i]);
}
return;
}
Expand All @@ -163,6 +188,7 @@ static void GetSample(std::vector<dgl_id_t>& ver_list,
}
mp[rand_num] = true;
out.push_back(ver_list[rand_num]);
out_edge.push_back(edge_list[rand_num]);
sample_count++;
if (sample_count == max_num_neighbor) {
break;
Expand All @@ -176,7 +202,7 @@ static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
const std::vector<OpReqType>& req,
const std::vector<NDArray>& outputs) {
CHECK_EQ(inputs.size(), 2U);
CHECK_EQ(outputs.size(), 1U);
CHECK_EQ(outputs.size(), 2U);

const NeighborSampleParam& params =
nnvm::get<NeighborSampleParam>(attrs.parsed);
Expand All @@ -192,15 +218,16 @@ static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs,

CHECK_GE(max_num_vertices, seed_num);

const dgl_id_t* col_list = inputs[0].aux_data(1).dptr<dgl_id_t>();
const dgl_id_t* indptr = inputs[0].aux_data(0).dptr<dgl_id_t>();
const dgl_id_t* val_list = inputs[0].data().dptr<dgl_id_t>();
const dgl_id_t* col_list = inputs[0].aux_data(csr::kIdx).dptr<dgl_id_t>();
const dgl_id_t* indptr = inputs[0].aux_data(csr::kIndPtr).dptr<dgl_id_t>();
const dgl_id_t* seed = inputs[1].data().dptr<dgl_id_t>();

dgl_id_t* out = outputs[0].data().dptr<dgl_id_t>();

// Traverse the graph in BFS order, sampling neighbors of each visited vertex
dgl_id_t sub_vertices_count = 0;
std::unordered_map<dgl_id_t, bool> sub_ver_mp;
std::map<dgl_id_t, bool> sub_ver_mp;
std::queue<ver_node> node_queue;
// add seed vertices
for (size_t i = 0; i < seed_num; ++i) {
Expand All @@ -213,35 +240,60 @@ static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
}

std::vector<dgl_id_t> tmp_src_list;
std::vector<dgl_id_t> tmp_sampled_list;
std::vector<dgl_id_t> tmp_edge_list;
std::vector<dgl_id_t> tmp_sampled_src_list;
std::vector<dgl_id_t> tmp_sampled_edge_list;

std::map<dgl_id_t, std::vector<dgl_id_t> > ver_mp;
std::map<dgl_id_t, std::vector<dgl_id_t> > edge_mp;

while (!node_queue.empty() && sub_vertices_count < max_num_vertices) {
while (!node_queue.empty()) {
ver_node& cur_node = node_queue.front();
if (cur_node.level < num_hops) {

dgl_id_t dst_id = cur_node.vertex_id;
tmp_src_list.clear();
tmp_sampled_list.clear();
GetSrcList(col_list, indptr, dst_id, tmp_src_list);
GetSample(tmp_src_list, num_neighbor, tmp_sampled_list);
for (size_t i = 0; i < tmp_sampled_list.size(); ++i) {
auto got = sub_ver_mp.find(tmp_sampled_list[i]);
tmp_edge_list.clear();
tmp_sampled_src_list.clear();
tmp_sampled_edge_list.clear();

GetSrcList(val_list,
col_list,
indptr,
dst_id,
tmp_src_list,
tmp_edge_list);

GetSample(tmp_src_list,
tmp_edge_list,
num_neighbor,
tmp_sampled_src_list,
tmp_sampled_edge_list);

ver_mp[dst_id] = tmp_sampled_src_list;
edge_mp[dst_id] = tmp_sampled_edge_list;

sub_vertices_count++;
if (sub_vertices_count == max_num_vertices) {
break;
}

for (size_t i = 0; i < tmp_sampled_src_list.size(); ++i) {
auto got = sub_ver_mp.find(tmp_sampled_src_list[i]);
if (got == sub_ver_mp.end()) {
sub_ver_mp[tmp_sampled_src_list[i]] = true;
sub_vertices_count++;
sub_ver_mp[tmp_sampled_list[i]] = true;
ver_node new_node;
new_node.vertex_id = tmp_sampled_list[i];
new_node.vertex_id = tmp_sampled_src_list[i];
new_node.level = cur_node.level + 1;
node_queue.push(new_node);
}
if (sub_vertices_count >= max_num_vertices) {
break;
}
}
}
node_queue.pop();
}

// Copy sub_ver_mp to output
// Copy sub_ver_mp to output[0]
dgl_id_t idx = 0;
for (auto& data: sub_ver_mp) {
if (data.second) {
Expand All @@ -253,6 +305,50 @@ static void CSRNeighborSampleComputeExCPU(const nnvm::NodeAttrs& attrs,
for (dgl_id_t i = idx; i < max_num_vertices; ++i) {
*(out+i) = -1;
}

// Construct sub_csr_graph
std::vector<dgl_id_t> sub_val;
std::vector<dgl_id_t> sub_col_list;
std::vector<dgl_id_t> sub_indptr(max_num_vertices+1, 0);

size_t index = 1;
for (auto& data: sub_ver_mp) {
dgl_id_t dst_id = data.first;
auto edge = edge_mp.find(dst_id);
auto vert = ver_mp.find(dst_id);
if (edge != edge_mp.end() && vert != ver_mp.end()) {
CHECK_EQ(edge->second.size(), vert->second.size());
for (auto& val : edge->second) {
sub_val.push_back(val);
}
for (auto& val : vert->second) {
sub_col_list.push_back(val);
}
sub_indptr[index] = sub_indptr[index-1] + edge->second.size();
} else {
sub_indptr[index] = sub_indptr[index-1];
}
index++;
}

// Copy sub_csr_graph to output[1]
const NDArray& sub_csr = outputs[1];
TShape shape_1(1);
TShape shape_2(1);
shape_1[0] = sub_val.size();
shape_2[0] = sub_indptr.size();
sub_csr.CheckAndAllocData(shape_1);
sub_csr.CheckAndAllocAuxData(csr::kIdx, shape_1);
sub_csr.CheckAndAllocAuxData(csr::kIndPtr, shape_2);

dgl_id_t* val_list_out = sub_csr.data().dptr<dgl_id_t>();
dgl_id_t* col_list_out = sub_csr.aux_data(1).dptr<dgl_id_t>();
dgl_id_t* indptr_out = sub_csr.aux_data(0).dptr<dgl_id_t>();


std::copy(sub_val.begin(), sub_val.end(), val_list_out);
std::copy(sub_col_list.begin(), sub_col_list.end(), col_list_out);
std::copy(sub_indptr.begin(), sub_indptr.end(), indptr_out);
}

} // op
Expand Down
37 changes: 30 additions & 7 deletions src/operator/contrib/csr_neighborhood_sample.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ namespace mxnet {
namespace op {

/*
Usage:
import mxnet as mx
Expand All @@ -40,16 +39,40 @@ indptr_np = np.array([0, 4,8,12,16,20], dtype=np.int64)
a = mx.nd.sparse.csr_matrix((data_np, indices_np, indptr_np), shape=shape)
a.asnumpy()
seed = mx.nd.array([0,1,2,3,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out[0]
out[1].asnumpy()
seed = mx.nd.array([0], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=1, max_num_vertices=4)
out[0]
out[1].asnumpy()
seed = mx.nd.array([0], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=1, max_num_vertices=5)
out.asnumpy()
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=2, num_neighbor=1, max_num_vertices=4)
out[0]
out[1].asnumpy()
seed = mx.nd.array([0,2,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out.asnumpy()
out[0]
out[1].asnumpy()
seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out[0]
out[1].asnumpy()
seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=2, num_neighbor=2, max_num_vertices=5)
out[0]
out[1].asnumpy()
seed = mx.nd.array([0,4], dtype=np.int64)
out = mx.nd.contrib.neighbor_sample(a, seed, num_hops=1, num_neighbor=2, max_num_vertices=5)
out.asnumpy()
out[0]
out[1].asnumpy()
*/

Expand All @@ -59,7 +82,7 @@ NNVM_REGISTER_OP(_contrib_neighbor_sample)
.MXNET_DESCRIBE("")
.set_attr_parser(ParamParser<NeighborSampleParam>)
.set_num_inputs(2)
.set_num_outputs(1)
.set_num_outputs(2)
.set_attr<FInferStorageType>("FInferStorageType", CSRNeighborSampleStorageType)
.set_attr<nnvm::FInferShape>("FInferShape", CSRNeighborSampleShape)
.set_attr<nnvm::FInferType>("FInferType", CSRNeighborSampleType)
Expand All @@ -69,4 +92,4 @@ NNVM_REGISTER_OP(_contrib_neighbor_sample)
.add_arguments(NeighborSampleParam::__FIELDS__());

} // op
} // mxnet
} // mxnet

0 comments on commit f966cd9

Please sign in to comment.