Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] FindEdge/FindEdges for Immutable Graph #404

Merged
merged 26 commits into from
Feb 27, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 48 additions & 8 deletions include/dgl/immutable_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,30 @@ class ImmutableGraph: public GraphInterface {
dgl_id_t edge_id;
};

// Edge list indexed by edge id;
struct EdgeList {
typedef std::shared_ptr<EdgeList> Ptr;
std::vector<dgl_id_t> src_points;
std::vector<dgl_id_t> dst_points;

EdgeList(int64_t len, dgl_id_t val) {
src_points.resize(len, val);
dst_points.resize(len, val);
}

void register_edge(dgl_id_t eid, dgl_id_t src, dgl_id_t dst) {
CHECK_LT(eid, src_points.size()) << "Invalid edge id " << eid;
src_points[eid] = src;
dst_points[eid] = dst;
}

static EdgeList::Ptr FromCSR(
const std::vector<int64_t>& indptr,
const std::vector<dgl_id_t>& indices,
const std::vector<dgl_id_t>& edge_ids,
bool in_csr);
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you make the input arguments const reference? input pointers usually mean the vectors will be modified by the method.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

updated

};

struct CSR {
typedef std::shared_ptr<CSR> Ptr;
std::vector<int64_t> indptr;
Expand Down Expand Up @@ -79,6 +103,7 @@ class ImmutableGraph: public GraphInterface {
void ReadAllEdges(std::vector<Edge> *edges) const;
CSR::Ptr Transpose() const;
std::pair<CSR::Ptr, IdArray> VertexSubgraph(IdArray vids) const;
std::pair<CSR::Ptr, IdArray> EdgeSubgraph(IdArray eids, EdgeList::Ptr edge_list) const;
/*
* Construct a CSR from a list of edges.
*
Expand Down Expand Up @@ -261,20 +286,14 @@ class ImmutableGraph: public GraphInterface {
* \param eid The edge ID
* \return a pair whose first element is the source and the second the destination.
*/
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const {
LOG(FATAL) << "FindEdge isn't supported in ImmutableGraph";
return std::pair<dgl_id_t, dgl_id_t>();
}
std::pair<dgl_id_t, dgl_id_t> FindEdge(dgl_id_t eid) const;

/*!
* \brief Find the edge IDs and return their source and target node IDs.
* \param eids The edge ID array.
* \return EdgeArray containing all edges with id in eid. The order is preserved.
*/
EdgeArray FindEdges(IdArray eids) const {
LOG(FATAL) << "FindEdges isn't supported in ImmutableGraph";
return EdgeArray();
}
EdgeArray FindEdges(IdArray eids) const;

/*!
* \brief Get the in edges of the vertex.
Expand Down Expand Up @@ -496,6 +515,25 @@ class ImmutableGraph: public GraphInterface {
}
}

/*
* The edge list is required for FindEdge/FindEdges/EdgeSubgraph, if no such function is called, we would not create edge list.
* if such function is called the first time, we create a edge list from one of the graph's csr representations,
* if we have called such function before, we get the one cached in the structure.
*/
EdgeList::Ptr GetEdgeList() const {
if (edge_list_)
return edge_list_;
if (in_csr_) {
const_cast<ImmutableGraph *>(this)->edge_list_ =\
EdgeList::FromCSR(in_csr_->indptr, in_csr_->indices, in_csr_->edge_ids, true);
} else {
CHECK(out_csr_ != nullptr) << "one of the CSRs must exist";
const_cast<ImmutableGraph *>(this)->edge_list_ =\
EdgeList::FromCSR(out_csr_->indptr, out_csr_->indices, out_csr_->edge_ids, false);
}
return edge_list_;
}

protected:
DGLIdIters GetInEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
DGLIdIters GetOutEdgeIdRef(dgl_id_t src, dgl_id_t dst) const;
Expand Down Expand Up @@ -525,6 +563,8 @@ class ImmutableGraph: public GraphInterface {
CSR::Ptr in_csr_;
// Store the out-edges.
CSR::Ptr out_csr_;
// Store the edge list indexed by edge id
EdgeList::Ptr edge_list_;
/*!
* \brief Whether if this is a multigraph.
*
Expand Down
6 changes: 2 additions & 4 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3070,10 +3070,8 @@ def readonly(self, readonly_state=True):
>>> G.number_of_nodes()
8
"""
if readonly_state == self._graph.is_readonly():
return self
self._graph.readonly(readonly_state)
return self
if readonly_state != self.is_readonly:
self._graph.readonly(readonly_state)

def __repr__(self):
ret = ('DGLGraph(num_nodes={node}, num_edges={edge},\n'
Expand Down
3 changes: 2 additions & 1 deletion src/graph/graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -252,6 +252,7 @@ Graph::EdgeArray Graph::EdgeIds(IdArray src_ids, IdArray dst_ids) const {
}

Graph::EdgeArray Graph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
int64_t len = eids->shape[0];

IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
Expand Down Expand Up @@ -472,7 +473,7 @@ Subgraph Graph::VertexSubgraph(IdArray vids) const {
}

Subgraph Graph::EdgeSubgraph(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid vertex id array.";
CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";

const auto len = eids->shape[0];
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
Expand Down
100 changes: 97 additions & 3 deletions src/graph/immutable_graph.cc
Original file line number Diff line number Diff line change
Expand Up @@ -158,14 +158,33 @@ class HashTableChecker {
}
};

ImmutableGraph::EdgeList::Ptr ImmutableGraph::EdgeList::FromCSR(
const std::vector<int64_t>& indptr,
const std::vector<dgl_id_t>& indices,
const std::vector<dgl_id_t>& edge_ids,
bool in_csr) {
const auto n = indptr.size() - 1;
const auto len = edge_ids.size();
auto t = std::make_shared<EdgeList>(len, n);
for (size_t i = 0; i < indptr.size() - 1; i++) {
for (int64_t j = indptr[i]; j < indptr[i + 1]; j++) {
dgl_id_t row = i, col = indices[j];
if (in_csr)
std::swap(row, col);
t->register_edge(edge_ids[j], row, col);
}
}
return t;
}

std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph(
IdArray vids) const {
CHECK(IsValidIdArray(vids)) << "Invalid vertex id array.";
const dgl_id_t* vid_data = static_cast<dgl_id_t*>(vids->data);
const int64_t len = vids->shape[0];

HashTableChecker def_check(vid_data, len);
// check if varr is sorted.
// check if vid_data is sorted.
CHECK(std::is_sorted(vid_data, vid_data + len)) << "The input vertex list has to be sorted";

// Collect the non-zero entries in from the original graph.
Expand Down Expand Up @@ -197,6 +216,42 @@ std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::VertexSubgraph
return std::pair<ImmutableGraph::CSR::Ptr, IdArray>(sub_csr, rst_eids);
}

std::pair<ImmutableGraph::CSR::Ptr, IdArray> ImmutableGraph::CSR::EdgeSubgraph(
IdArray eids, EdgeList::Ptr edge_list) const {
// Return sub_csr and vids array.
CHECK(IsValidIdArray(eids)) << "Invalid edge id array.";
const dgl_id_t* eid_data = static_cast<dgl_id_t*>(eids->data);
const int64_t len = eids->shape[0];
std::vector<dgl_id_t> nodes;
std::unordered_map<dgl_id_t, dgl_id_t> oldv2newv;
std::vector<Edge> edges;

for (int64_t i = 0; i < len; i++) {
dgl_id_t src_id = edge_list->src_points[eid_data[i]];
dgl_id_t dst_id = edge_list->dst_points[eid_data[i]];

// pair<iterator, bool>, the second indicates whether the insertion is successful or not.
auto src_pair = oldv2newv.insert(std::make_pair(src_id, oldv2newv.size()));
auto dst_pair = oldv2newv.insert(std::make_pair(dst_id, oldv2newv.size()));
if (src_pair.second)
nodes.push_back(src_id);
if (dst_pair.second)
nodes.push_back(dst_id);
edges.push_back(Edge{src_pair.first->second, dst_pair.first->second, static_cast<dgl_id_t>(i)});
}

const size_t n = oldv2newv.size();
auto sub_csr = CSR::FromEdges(&edges, 0, n);

IdArray rst_vids = IdArray::Empty({static_cast<int64_t>(nodes.size())},
DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
dgl_id_t* vid_data = static_cast<dgl_id_t*>(rst_vids->data);
std::copy(nodes.begin(), nodes.end(), vid_data);

return std::make_pair(sub_csr, rst_vids);
}


ImmutableGraph::CSR::Ptr ImmutableGraph::CSR::FromEdges(std::vector<Edge> *edges,
int sort_on, uint64_t num_nodes) {
CHECK(sort_on == 0 || sort_on == 1) << "we must sort on the first or the second vector";
Expand Down Expand Up @@ -449,6 +504,34 @@ ImmutableGraph::EdgeArray ImmutableGraph::EdgeIds(IdArray src_ids, IdArray dst_i
return ImmutableGraph::EdgeArray{rst_src, rst_dst, rst_eid};
}

std::pair<dgl_id_t, dgl_id_t> ImmutableGraph::FindEdge(dgl_id_t eid) const {
dgl_id_t row = 0, col = 0;
auto edge_list = GetEdgeList();
CHECK(eid < NumEdges()) << "Invalid edge id " << eid;
row = edge_list->src_points[eid];
col = edge_list->dst_points[eid];
CHECK(row < NumVertices() && col < NumVertices()) << "Invalid edge id " << eid;
return std::pair<dgl_id_t, dgl_id_t>(row, col);
}

ImmutableGraph::EdgeArray ImmutableGraph::FindEdges(IdArray eids) const {
CHECK(IsValidIdArray(eids)) << "Invalid edge id array";
dgl_id_t* eid_data = static_cast<dgl_id_t*>(eids->data);
int64_t len = eids->shape[0];
IdArray rst_src = IdArray::Empty({len}, eids->dtype, eids->ctx);
IdArray rst_dst = IdArray::Empty({len}, eids->dtype, eids->ctx);
dgl_id_t* rst_src_data = static_cast<dgl_id_t*>(rst_src->data);
dgl_id_t* rst_dst_data = static_cast<dgl_id_t*>(rst_dst->data);

for (int64_t i = 0; i < len; i++) {
auto edge = ImmutableGraph::FindEdge(eid_data[i]);
rst_src_data[i] = edge.first;
rst_dst_data[i] = edge.second;
}

return ImmutableGraph::EdgeArray{rst_src, rst_dst, eids};
}

ImmutableGraph::EdgeArray ImmutableGraph::Edges(const std::string &order) const {
int64_t rstlen = NumEdges();
IdArray rst_src = IdArray::Empty({rstlen}, DLDataType{kDLInt, 64, 1}, DLContext{kDLCPU, 0});
Expand Down Expand Up @@ -506,8 +589,19 @@ Subgraph ImmutableGraph::VertexSubgraph(IdArray vids) const {
}

Subgraph ImmutableGraph::EdgeSubgraph(IdArray eids) const {
LOG(FATAL) << "EdgeSubgraph isn't implemented in immutable graph";
return Subgraph();
Subgraph subg;
std::pair<CSR::Ptr, IdArray> ret;
auto edge_list = GetEdgeList();
if (out_csr_) {
ret = out_csr_->EdgeSubgraph(eids, edge_list);
subg.graph = GraphPtr(new ImmutableGraph(nullptr, ret.first, IsMultigraph()));
} else {
ret = in_csr_->EdgeSubgraph(eids, edge_list);
subg.graph = GraphPtr(new ImmutableGraph(ret.first, nullptr, IsMultigraph()));
}
subg.induced_edges = eids;
subg.induced_vertices = ret.second;
return subg;
}

ImmutableGraph::CSRArray ImmutableGraph::GetInCSRArray() const {
Expand Down
30 changes: 30 additions & 0 deletions tests/compute/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -184,10 +184,40 @@ def test_readonly():
assert g.number_of_edges() == 14
assert F.shape(g.edata['x']) == (14, 4)

def test_find_edges():
g = dgl.DGLGraph()
g.add_nodes(10)
g.add_edges(range(9), range(1, 10))
e = g.find_edges([1, 3, 2, 4])
assert e[0][0] == 1 and e[0][1] == 3 and e[0][2] == 2 and e[0][3] == 4
assert e[1][0] == 2 and e[1][1] == 4 and e[1][2] == 3 and e[1][3] == 5

try:
g.find_edges([10])
fail = False
except DGLError:
fail = True
finally:
assert fail

g.readonly()
e = g.find_edges([1, 3, 2, 4])
assert e[0][0] == 1 and e[0][1] == 3 and e[0][2] == 2 and e[0][3] == 4
assert e[1][0] == 2 and e[1][1] == 4 and e[1][2] == 3 and e[1][3] == 5

try:
g.find_edges([10])
fail = False
except DGLError:
fail = True
finally:
assert fail

if __name__ == '__main__':
test_graph_creation()
test_create_from_elist()
test_adjmat_cache()
test_incmat()
test_incmat_cache()
test_readonly()
test_find_edges()
zheng-da marked this conversation as resolved.
Show resolved Hide resolved
17 changes: 17 additions & 0 deletions tests/graph_index/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,24 @@ def test_edge_subgraph():
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])

def test_immutable_edge_subgraph():
gi = create_graph_index()
gi.add_nodes(4)
gi.add_edge(0, 1)
gi.add_edge(0, 1)
gi.add_edge(0, 2)
gi.add_edge(2, 3)
gi.readonly() # Make the graph readonly

sub2par_edgemap = [3, 2]
sgi = gi.edge_subgraph(toindex(sub2par_edgemap))

for s, d, e in zip(*sgi.edges()):
assert sgi.induced_edges[e] in gi.edge_id(
sgi.induced_nodes[s], sgi.induced_nodes[d])


if __name__ == '__main__':
test_node_subgraph()
test_edge_subgraph()
test_immutable_edge_subgraph()