[Feature] Reversed Graph and Transform Module (#331)
* reverse a graph
* Reverse a graph
* Fix
* Revert "Fix" (this reverts commit 1728826)
* Fix
* Fix
* Delete vcs.xml
* Delete Project_Default.xml
* Fix
* Fix
* Fix
* Remove outdated test
* Reorg transform and update reverse (#2)
* Reorg transform and update reverse
* Fix doc and test
* Update test
* Resolve conflict
* CI oriented fix
* Remove outdated import
* Fix import
* Fix import
* define __all__ for wildcard imports
* Fix import
* Address circular imports
* Fix
* Fix test case
* Fix
* Fix
* Remove unused import
* Fix
* Fix
* Fix
1 parent: 4bd4d6e
Commit: 24bbdb7
Showing 9 changed files with 230 additions and 58 deletions.
@@ -13,3 +13,4 @@ API Reference
    udf
    sampler
    data
+   transform
@@ -0,0 +1,12 @@
.. _apigraph:

Transform -- Graph Transformation
=================================

.. automodule:: dgl.transform

.. autosummary::
    :toctree: ../../generated/

    line_graph
    reverse
@@ -0,0 +1,105 @@
"""Module for graph transformation methods."""
from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph

__all__ = ['line_graph', 'reverse']


def line_graph(g, backtracking=True, shared=False):
    """Return the line graph of this graph.

    Parameters
    ----------
    g : dgl.DGLGraph
        The input graph.
    backtracking : bool, optional
        Whether the returned line graph is backtracking.
    shared : bool, optional
        Whether the returned line graph shares representations with the
        original graph ``g``.

    Returns
    -------
    DGLGraph
        The line graph of this graph.
    """
    graph_data = g._graph.line_graph(backtracking)
    node_frame = g._edge_frame if shared else None
    return DGLGraph(graph_data, node_frame)


def reverse(g, share_ndata=False, share_edata=False):
    """Return the reverse of a graph.

    The reverse (also called converse or transpose) of a directed graph is
    another directed graph on the same nodes with the direction of every
    edge reversed.

    Given a :class:`DGLGraph` object, we return another :class:`DGLGraph`
    object representing its reverse.

    Notes
    -----
    * This function does not support :class:`~dgl.BatchedDGLGraph` objects.
    * We do not dynamically update the topology of a graph once that of its
      reverse changes. This can be particularly problematic when the node/edge
      attrs are shared: if the topology of the original graph and that of its
      reverse are changed independently, the node/edge features may end up
      mismatched.

    Parameters
    ----------
    g : dgl.DGLGraph
        The input graph.
    share_ndata : bool, optional
        If True, the original graph and the reversed graph share memory for
        node attributes. Otherwise the reversed graph will not be initialized
        with node attributes.
    share_edata : bool, optional
        If True, the original graph and the reversed graph share memory for
        edge attributes. Otherwise the reversed graph will not have edge
        attributes.

    Returns
    -------
    DGLGraph
        The reversed graph.

    Examples
    --------
    Create a graph to reverse.

    >>> import dgl
    >>> import torch as th
    >>> g = dgl.DGLGraph()
    >>> g.add_nodes(3)
    >>> g.add_edges([0, 1, 2], [1, 2, 0])
    >>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
    >>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])

    Reverse the graph and examine its structure.

    >>> rg = g.reverse(share_ndata=True, share_edata=True)
    >>> print(rg)
    DGLGraph with 3 nodes and 3 edges.
    Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
    Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)}

    The edges are reversed now.

    >>> rg.has_edges_between([1, 2, 0], [0, 1, 2])
    tensor([1, 1, 1])

    Reversed edges have the same features as the original ones.

    >>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h']
    tensor([[1],
            [1]], dtype=torch.uint8)

    The node/edge features of the reversed graph share memory with the
    original graph, which is helpful for both forward computation and
    backpropagation.

    >>> g.ndata['h'] = g.ndata['h'] + 1
    >>> rg.ndata['h']
    tensor([[1.],
            [2.],
            [3.]])
    """
    assert not isinstance(g, BatchedDGLGraph), \
        'reverse is not supported for a BatchedDGLGraph object'
    g_reversed = DGLGraph(multigraph=g.is_multigraph)
    g_reversed.add_nodes(g.number_of_nodes())
    g_edges = g.edges()
    g_reversed.add_edges(g_edges[1], g_edges[0])
    if share_ndata:
        g_reversed._node_frame = g._node_frame
    if share_edata:
        g_reversed._edge_frame = g._edge_frame
    return g_reversed
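For reference, the docstring examples above go through the DGLGraph.reverse method; the functions added in this module can also be called directly. A minimal usage sketch, assuming the PyTorch backend and that dgl.transform is importable as introduced by this commit:

import torch as th
import dgl
from dgl.transform import reverse  # module added in this commit

# Build a small directed triangle with node and edge features.
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 0])
g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
g.edata['h'] = th.tensor([[3.], [4.], [5.]])

# Reverse it, sharing node attributes but not edge attributes.
rg = reverse(g, share_ndata=True, share_edata=False)
print(rg)                                          # node data 'h' should be shared; no edge data expected
print(rg.has_edges_between([1, 2, 0], [0, 1, 2]))  # every original edge appears with reversed direction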
This file was deleted.
@@ -0,0 +1,101 @@
import torch as th
import networkx as nx
import numpy as np
import dgl
import dgl.function as fn
import utils as U

D = 5

# line graph related
def test_line_graph():
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    G.edata['h'] = th.randn((2 * N, D))
    n_edges = G.number_of_edges()
    L = G.line_graph(shared=True)
    assert L.number_of_nodes() == 2 * N
    L.ndata['h'] = th.randn((2 * N, D))
    # updating node features on the line graph should be reflected in the
    # edge features of the original graph.
    u = [0, 0, 2, 3]
    v = [1, 2, 0, 0]
    eid = G.edge_ids(u, v)
    L.nodes[eid].data['h'] = th.zeros((4, D))
    assert U.allclose(G.edges[u, v].data['h'], th.zeros((4, D)))

    # adding a new node feature on the line graph should also show up as a
    # new edge feature on the original graph
    data = th.randn(n_edges, D)
    L.ndata['w'] = data
    assert U.allclose(G.edata['w'], data)

def test_no_backtracking():
    N = 5
    G = dgl.DGLGraph(nx.star_graph(N))
    L = G.line_graph(backtracking=False)
    assert L.number_of_nodes() == 2 * N
    for i in range(1, N):
        e1 = G.edge_id(0, i)
        e2 = G.edge_id(i, 0)
        assert not L.has_edge_between(e1, e2)
        assert not L.has_edge_between(e2, e1)

# reverse graph related
def test_reverse():
    g = dgl.DGLGraph()
    g.add_nodes(5)
    # The graph need not be completely connected.
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = th.tensor([[0.], [1.], [2.], [3.], [4.]])
    g.edata['h'] = th.tensor([[5.], [6.], [7.]])
    rg = g.reverse()

    assert g.is_multigraph == rg.is_multigraph

    assert g.number_of_nodes() == rg.number_of_nodes()
    assert g.number_of_edges() == rg.number_of_edges()
    assert U.allclose(rg.has_edges_between([1, 2, 1], [0, 1, 2]).float(), th.ones(3))
    assert g.edge_id(0, 1) == rg.edge_id(1, 0)
    assert g.edge_id(1, 2) == rg.edge_id(2, 1)
    assert g.edge_id(2, 1) == rg.edge_id(1, 2)

def test_reverse_shared_frames():
    g = dgl.DGLGraph()
    g.add_nodes(3)
    g.add_edges([0, 1, 2], [1, 2, 1])
    g.ndata['h'] = th.tensor([[0.], [1.], [2.]], requires_grad=True)
    g.edata['h'] = th.tensor([[3.], [4.], [5.]], requires_grad=True)

    rg = g.reverse(share_ndata=True, share_edata=True)
    assert U.allclose(g.ndata['h'], rg.ndata['h'])
    assert U.allclose(g.edata['h'], rg.edata['h'])
    assert U.allclose(g.edges[[0, 2], [1, 1]].data['h'],
                      rg.edges[[1, 1], [0, 2]].data['h'])

    rg.ndata['h'] = rg.ndata['h'] + 1
    assert U.allclose(rg.ndata['h'], g.ndata['h'])

    g.edata['h'] = g.edata['h'] - 1
    assert U.allclose(rg.edata['h'], g.edata['h'])

    src_msg = fn.copy_src(src='h', out='m')
    sum_reduce = fn.sum(msg='m', out='h')

    rg.update_all(src_msg, sum_reduce)
    assert U.allclose(g.ndata['h'], rg.ndata['h'])

    # Grad check
    g.ndata['h'].retain_grad()
    rg.ndata['h'].retain_grad()
    loss_func = th.nn.MSELoss()
    target = th.zeros(3, 1)
    loss = loss_func(rg.ndata['h'], target)
    loss.backward()
    assert U.allclose(g.ndata['h'].grad, rg.ndata['h'].grad)

if __name__ == '__main__':
    test_line_graph()
    test_no_backtracking()
    test_reverse()
    test_reverse_shared_frames()
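The tests above cover feature sharing between a graph and its reverse; the caveat in the Notes section of reverse (the reversed graph is built once and its topology is not kept in sync afterwards) is not exercised here. A minimal sketch of what that means in practice, under the same PyTorch setup the tests assume:

import torch as th
import dgl
from dgl.transform import reverse

g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1], [1, 2])
g.ndata['h'] = th.tensor([[0.], [1.], [2.]])

rg = reverse(g, share_ndata=True)

# Mutating the original graph afterwards is not reflected in the reversed graph.
g.add_edges([2], [0])
print(g.number_of_edges())   # 3
print(rg.number_of_edges())  # still 2 -- the reverse is a snapshot, not a live view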