Skip to content

Commit

Permalink
[Feature] Reversed Graph and Transform Module (#331)
Browse files Browse the repository at this point in the history
* reverse a graph

* Reverse a graph

* Fix

* Revert "Fix"

This reverts commit 1728826.

* Fix

* Fix

* Delete vcs.xml

* Delete Project_Default.xml

* Fix

* Fix

* Fix

* Remove outdated test

* Reorg transform and update reverse (#2)

* Reorg transform and update reverse

* Fix doc and test

* Update test

* Resolve conflict

* CI oriented fix

* Remove outdated import

* Fix import

* Fix import

* define __all__ for wildcard imports

* Fix import

* Address circular imports

* Fix

* Fix test case

* Fix

* Fix

* Remove unused import

* Fix

* Fix

* Fix
  • Loading branch information
mufeili authored and jermainewang committed Jan 4, 2019
1 parent 4bd4d6e commit 24bbdb7
Show file tree
Hide file tree
Showing 9 changed files with 230 additions and 58 deletions.
1 change: 1 addition & 0 deletions docs/source/api/python/graph.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Transforming graph
DGLGraph.subgraphs
DGLGraph.edge_subgraph
DGLGraph.line_graph
DGLGraph.reverse

Converting from/to other format
-------------------------------
Expand Down
1 change: 1 addition & 0 deletions docs/source/api/python/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ API Reference
udf
sampler
data
transform
12 changes: 12 additions & 0 deletions docs/source/api/python/transform.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
.. _apigraph:

Transform -- Graph Transformation
=================================

.. automodule:: dgl.transform

.. autosummary::
:toctree: ../../generated/

line_graph
reverse
2 changes: 1 addition & 1 deletion docs/source/install/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ The backend is controlled by ``DGLBACKEND`` environment variable, which defaults
| | | `official website <https://pytorch.org>`_ |
+---------+---------+--------------------------------------------------+
| mxnet | MXNet | Requires nightly build; run the following |
| | | command to install (TODO): |
| | | command to install: |
| | | |
| | | .. code:: bash |
| | | |
Expand Down
1 change: 1 addition & 0 deletions python/dgl/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,4 @@
from .traversal import *
from .propagate import *
from .udf import NodeBatch, EdgeBatch
from .transform import *
21 changes: 8 additions & 13 deletions python/dgl/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

from collections import defaultdict

import dgl
from .base import ALL, is_all, DGLError
from . import backend as F
from . import init
Expand Down Expand Up @@ -2760,22 +2761,16 @@ def incidence_matrix(self, typestr, ctx=F.cpu()):
def line_graph(self, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters
----------
backtracking : bool, optional
Whether the returned line graph is backtracking.
See :func:`~dgl.transform.line_graph`.
"""
return dgl.line_graph(self, backtracking, shared)

shared : bool, optional
Whether the returned line graph shares representations with `self`.
def reverse(self, share_ndata=False, share_edata=False):
"""Return the reverse of this graph.
Returns
-------
DGLGraph
The line graph of this graph.
See :func:`~dgl.transform.reverse`.
"""
graph_data = self._graph.line_graph(backtracking)
node_frame = self._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)
return dgl.reverse(self, share_ndata, share_edata)

def filter_nodes(self, predicate, nodes=ALL):
"""Return a tensor of node IDs that satisfy the given predicate.
Expand Down
105 changes: 105 additions & 0 deletions python/dgl/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
"""Module for graph transformation methods."""
from .graph import DGLGraph
from .batched_graph import BatchedDGLGraph

__all__ = ['line_graph', 'reverse']


def line_graph(g, backtracking=True, shared=False):
"""Return the line graph of this graph.
Parameters
----------
g : dgl.DGLGraph
backtracking : bool, optional
Whether the returned line graph is backtracking.
shared : bool, optional
Whether the returned line graph shares representations with `self`.
Returns
-------
DGLGraph
The line graph of this graph.
"""
graph_data = g._graph.line_graph(backtracking)
node_frame = g._edge_frame if shared else None
return DGLGraph(graph_data, node_frame)

def reverse(g, share_ndata=False, share_edata=False):
"""Return the reverse of a graph
The reverse (also called converse, transpose) of a directed graph is another directed
graph on the same nodes with edges reversed in terms of direction.
Given a :class:`DGLGraph` object, we return another :class:`DGLGraph` object
representing its reverse.
Notes
-----
* This function does not support :class:`~dgl.BatchedDGLGraph` objects.
* We do not dynamically update the topology of a graph once that of its reverse changes.
This can be particularly problematic when the node/edge attrs are shared. For example,
if the topology of both the original graph and its reverse get changed independently,
you can get a mismatched node/edge feature.
Parameters
----------
g : dgl.DGLGraph
share_ndata: bool, optional
If True, the original graph and the reversed graph share memory for node attributes.
Otherwise the reversed graph will not be initialized with node attributes.
share_edata: bool, optional
If True, the original graph and the reversed graph share memory for edge attributes.
Otherwise the reversed graph will not have edge attributes.
Examples
--------
Create a graph to reverse.
>>> import dgl
>>> import torch as th
>>> g = dgl.DGLGraph()
>>> g.add_nodes(3)
>>> g.add_edges([0, 1, 2], [1, 2, 0])
>>> g.ndata['h'] = th.tensor([[0.], [1.], [2.]])
>>> g.edata['h'] = th.tensor([[3.], [4.], [5.]])
Reverse the graph and examine its structure.
>>> rg = g.reverse(share_ndata=True, share_edata=True)
>>> print(rg)
DGLGraph with 3 nodes and 3 edges.
Node data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
Edge data: {'h': Scheme(shape=(1,), dtype=torch.float32)}
The edges are reversed now.
>>> rg.has_edges_between([1, 2, 0], [0, 1, 2])
tensor([1, 1, 1])
Reversed edges have the same feature as the original ones.
>>> g.edges[[0, 2], [1, 0]].data['h'] == rg.edges[[1, 0], [0, 2]].data['h']
tensor([[1],
[1]], dtype=torch.uint8)
The node/edge features of the reversed graph share memory with the original
graph, which is helpful for both forward computation and back propagation.
>>> g.ndata['h'] = g.ndata['h'] + 1
>>> rg.ndata['h']
tensor([[1.],
[2.],
[3.]])
"""
assert not isinstance(g, BatchedDGLGraph), \
'reverse is not supported for a BatchedDGLGraph object'
g_reversed = DGLGraph(multigraph=g.is_multigraph)
g_reversed.add_nodes(g.number_of_nodes())
g_edges = g.edges()
g_reversed.add_edges(g_edges[1], g_edges[0])
if share_ndata:
g_reversed._node_frame = g._node_frame
if share_edata:
g_reversed._edge_frame = g._edge_frame
return g_reversed
44 changes: 0 additions & 44 deletions tests/pytorch/test_line_graph.py

This file was deleted.

101 changes: 101 additions & 0 deletions tests/pytorch/test_transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch as th
import networkx as nx
import numpy as np
import dgl
import dgl.function as fn
import utils as U

D = 5

# line graph related
def test_line_graph():
N = 5
G = dgl.DGLGraph(nx.star_graph(N))
G.edata['h'] = th.randn((2 * N, D))
n_edges = G.number_of_edges()
L = G.line_graph(shared=True)
assert L.number_of_nodes() == 2 * N
L.ndata['h'] = th.randn((2 * N, D))
# update node features on line graph should reflect to edge features on
# original graph.
u = [0, 0, 2, 3]
v = [1, 2, 0, 0]
eid = G.edge_ids(u, v)
L.nodes[eid].data['h'] = th.zeros((4, D))
assert U.allclose(G.edges[u, v].data['h'], th.zeros((4, D)))

# adding a new node feature on line graph should also reflect to a new
# edge feature on original graph
data = th.randn(n_edges, D)
L.ndata['w'] = data
assert U.allclose(G.edata['w'], data)

def test_no_backtracking():
N = 5
G = dgl.DGLGraph(nx.star_graph(N))
L = G.line_graph(backtracking=False)
assert L.number_of_nodes() == 2 * N
for i in range(1, N):
e1 = G.edge_id(0, i)
e2 = G.edge_id(i, 0)
assert not L.has_edge_between(e1, e2)
assert not L.has_edge_between(e2, e1)

# reverse graph related
def test_reverse():
g = dgl.DGLGraph()
g.add_nodes(5)
# The graph need not to be completely connected.
g.add_edges([0, 1, 2], [1, 2, 1])
g.ndata['h'] = th.tensor([[0.], [1.], [2.], [3.], [4.]])
g.edata['h'] = th.tensor([[5.], [6.], [7.]])
rg = g.reverse()

assert g.is_multigraph == rg.is_multigraph

assert g.number_of_nodes() == rg.number_of_nodes()
assert g.number_of_edges() == rg.number_of_edges()
assert U.allclose(rg.has_edges_between([1, 2, 1], [0, 1, 2]).float(), th.ones(3))
assert g.edge_id(0, 1) == rg.edge_id(1, 0)
assert g.edge_id(1, 2) == rg.edge_id(2, 1)
assert g.edge_id(2, 1) == rg.edge_id(1, 2)

def test_reverse_shared_frames():
g = dgl.DGLGraph()
g.add_nodes(3)
g.add_edges([0, 1, 2], [1, 2, 1])
g.ndata['h'] = th.tensor([[0.], [1.], [2.]], requires_grad=True)
g.edata['h'] = th.tensor([[3.], [4.], [5.]], requires_grad=True)

rg = g.reverse(share_ndata=True, share_edata=True)
assert U.allclose(g.ndata['h'], rg.ndata['h'])
assert U.allclose(g.edata['h'], rg.edata['h'])
assert U.allclose(g.edges[[0, 2], [1, 1]].data['h'],
rg.edges[[1, 1], [0, 2]].data['h'])

rg.ndata['h'] = rg.ndata['h'] + 1
assert U.allclose(rg.ndata['h'], g.ndata['h'])

g.edata['h'] = g.edata['h'] - 1
assert U.allclose(rg.edata['h'], g.edata['h'])

src_msg = fn.copy_src(src='h', out='m')
sum_reduce = fn.sum(msg='m', out='h')

rg.update_all(src_msg, sum_reduce)
assert U.allclose(g.ndata['h'], rg.ndata['h'])

# Grad check
g.ndata['h'].retain_grad()
rg.ndata['h'].retain_grad()
loss_func = th.nn.MSELoss()
target = th.zeros(3, 1)
loss = loss_func(rg.ndata['h'], target)
loss.backward()
assert U.allclose(g.ndata['h'].grad, rg.ndata['h'].grad)

if __name__ == '__main__':
test_line_graph()
test_no_backtracking()
test_reverse()
test_reverse_shared_frames()

0 comments on commit 24bbdb7

Please sign in to comment.