-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1 from zheng-da/mxnet
Add the MXNet backend.
- Loading branch information
Showing
6 changed files
with
521 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,222 @@ | ||
""" | ||
Graph Attention Networks | ||
Paper: https://arxiv.org/abs/1710.10903 | ||
Code: https://github.com/PetarV-/GAT | ||
GAT with batch processing | ||
""" | ||
|
||
import argparse | ||
import numpy as np | ||
import time | ||
import mxnet as mx | ||
from mxnet import gluon | ||
import dgl | ||
from dgl import DGLGraph | ||
from dgl.data import register_data_args, load_data | ||
|
||
def elu(data):
    """Apply the ELU activation to ``data``.

    Implemented with MXNet's ``LeakyReLU`` operator using ``act_type='elu'``.
    """
    activated = mx.nd.LeakyReLU(data, act_type='elu')
    return activated
|
||
def gat_message(src, edge):
    """Build the message sent along an edge from its source node.

    Only the ``'ft'`` and ``'a2'`` entries of ``src`` are forwarded;
    ``edge`` is accepted but not used.
    """
    forwarded_keys = ('ft', 'a2')
    return {key: src[key] for key in forwarded_keys}
|
||
class GATReduce(gluon.Block):
    """Reduce step of one attention head: attention-weighted sum of messages.

    Parameters
    ----------
    attn_drop : float
        Dropout rate applied to the attention weights; ``0.0`` skips dropout.
    """

    def __init__(self, attn_drop):
        super(GATReduce, self).__init__()
        self.attn_drop = attn_drop

    def forward(self, node, msgs):
        """Combine incoming messages into a single ``'accum'`` entry.

        Reads ``node['a1']``, ``msgs['a2']`` and ``msgs['ft']`` and returns
        ``{'accum': ...}`` holding the weighted sum over the message axis.
        """
        a1 = mx.nd.expand_dims(node['a1'], 1)  # shape (B, 1, 1)
        a2 = msgs['a2']  # shape (B, deg, 1)
        ft = msgs['ft']  # shape (B, deg, D)
        # attention
        a = a1 + a2  # shape (B, deg, 1)
        # Normalise across the incoming messages (axis 1).  mx.nd.softmax
        # defaults to axis=-1, which has length 1 here and would turn every
        # attention weight into 1.0.
        e = mx.nd.softmax(mx.nd.LeakyReLU(a), axis=1)
        if self.attn_drop != 0.0:
            e = mx.nd.Dropout(e, self.attn_drop)
        return {'accum': mx.nd.sum(e * ft, axis=1)}  # shape (B, D)
|
||
class GATFinalize(gluon.Block):
    """Final per-node update of one attention head.

    Optionally adds a residual term built from the node's ``'h'`` entry, then
    applies ``activation`` and stores the result under ``'head<headid>'``.

    Parameters
    ----------
    headid : int
        Index used to build the output key ``'head%d'``.
    indim, hiddendim : int
        When ``residual`` is set and the two differ, a ``Dense(hiddendim)``
        layer is created and applied to ``'h'`` before the addition.
    activation : callable
        Applied to the (optionally residual-augmented) accumulated value.
    residual : bool
        Whether to add the residual term.
    """

    def __init__(self, headid, indim, hiddendim, activation, residual):
        super(GATFinalize, self).__init__()
        self.headid = headid
        self.activation = activation
        self.residual = residual
        # A projection is only created for a residual whose input size
        # differs from the output size.
        needs_projection = residual and indim != hiddendim
        self.residual_fc = gluon.nn.Dense(hiddendim) if needs_projection else None

    def forward(self, node):
        """Return ``{'head<headid>': activation(accum [+ residual])}``."""
        out = node['accum']
        if self.residual:
            if self.residual_fc is None:
                skip = node['h']
            else:
                skip = self.residual_fc(node['h'])
            out = skip + out
        key = 'head%d' % self.headid
        return {key: self.activation(out)}
|
||
class GATPrepare(gluon.Block):
    """Compute the per-node inputs of one attention head.

    Parameters
    ----------
    indim : int
        Accepted for signature symmetry; not used by this block.
    hiddendim : int
        Output size of the feature projection.
    drop : float
        Dropout rate applied to the input features; ``0.0`` skips dropout.
    """

    def __init__(self, indim, hiddendim, drop):
        super(GATPrepare, self).__init__()
        self.fc = gluon.nn.Dense(hiddendim)
        self.drop = drop
        self.attn_l = gluon.nn.Dense(1)
        self.attn_r = gluon.nn.Dense(1)

    def forward(self, feats):
        """Return a dict with the (dropped-out) input, its projection and two scores.

        Keys: ``'h'`` (input after optional dropout), ``'ft'`` (projection of
        ``'h'``), ``'a1'`` / ``'a2'`` (outputs of ``attn_l`` / ``attn_r`` on
        ``'ft'``).
        """
        h = mx.nd.Dropout(feats, self.drop) if self.drop != 0.0 else feats
        projected = self.fc(h)
        left_score = self.attn_l(projected)
        right_score = self.attn_r(projected)
        return {'h': h, 'ft': projected, 'a1': left_score, 'a2': right_score}
|
||
class GAT(gluon.Block):
    """Multi-head graph attention network operating on a fixed graph ``g``.

    The model keeps three parallel stacks of sub-blocks (prepare / reduce /
    finalize). Each stack holds ``num_layers * num_heads`` head entries
    followed by one output-projection entry.

    Parameters
    ----------
    g : graph object
        Must provide ``set_n_repr``, ``update_all`` and ``pop_n_repr``.
    num_layers : int
        Number of multi-head layers before the output projection.
    in_dim : int
        Input feature size.
    num_hidden : int
        Per-head hidden size.
    num_classes : int
        Output size of the final projection.
    num_heads : int
        Number of attention heads per layer.
    activation : callable
        Activation passed to every ``GATFinalize``.
    in_drop, attn_drop : float
        Dropout rates for input features and attention weights.
    residual : bool
        Whether layers after the first use a residual connection.
    """

    def __init__(self,
                 g,
                 num_layers,
                 in_dim,
                 num_hidden,
                 num_classes,
                 num_heads,
                 activation,
                 in_drop,
                 attn_drop,
                 residual):
        super(GAT, self).__init__()
        self.g = g
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.prp = gluon.nn.Sequential()
        self.red = gluon.nn.Sequential()
        self.fnl = gluon.nn.Sequential()

        def add_head(head_id, indim, outdim, use_residual):
            # Append one prepare/reduce/finalize triple, in that order.
            self.prp.add(GATPrepare(indim, outdim, in_drop))
            self.red.add(GATReduce(attn_drop))
            self.fnl.add(GATFinalize(head_id, indim, outdim, activation, use_residual))

        # input projection (no residual)
        for hid in range(num_heads):
            add_head(hid, in_dim, num_hidden, False)

        # hidden layers: heads are concatenated, so the input size is
        # num_hidden * num_heads
        merged_dim = num_hidden * num_heads
        for _ in range(num_layers - 1):
            for hid in range(num_heads):
                add_head(hid, merged_dim, num_hidden, residual)

        # output projection (a single head writing to 'head0')
        add_head(0, merged_dim, num_classes, residual)

        # sanity check
        expected = self.num_layers * self.num_heads + 1
        for stack in (self.prp, self.red, self.fnl):
            assert len(stack) == expected

    def forward(self, features):
        """Run every layer on ``features`` and return the ``'head0'`` output."""
        last = features
        for layer in range(self.num_layers):
            offset = layer * self.num_heads
            for hid in range(self.num_heads):
                idx = offset + hid
                # prepare
                self.g.set_n_repr(self.prp[idx](last))
                # message passing
                self.g.update_all(gat_message, self.red[idx], self.fnl[idx],
                                  batchable=True)
            # merge all the heads
            heads = [self.g.pop_n_repr('head%d' % hid)
                     for hid in range(self.num_heads)]
            last = mx.nd.concat(*heads, dim=1)
        # output projection
        self.g.set_n_repr(self.prp[-1](last))
        self.g.update_all(gat_message, self.red[-1], self.fnl[-1], batchable=True)
        return self.g.pop_n_repr('head0')
|
||
def main(args):
    """Train a GAT model on the dataset selected by ``args`` and log progress.

    Reads ``args.gpu``, ``args.num_layers``, ``args.num_hidden``,
    ``args.num_heads``, ``args.in_drop``, ``args.attn_drop``,
    ``args.residual``, ``args.lr`` and ``args.epochs``; ``args`` is also
    passed to ``load_data``. Prints one line per epoch with the loss, the mean
    epoch time and the edge throughput. Timing ignores the first three epochs.
    """
    # load and preprocess dataset
    data = load_data(args)

    # Pick the MXNet device up front so that the data and the model parameters
    # are created on the same context (CPU when --gpu is negative).
    if args.gpu < 0:
        ctx = mx.cpu()
    else:
        ctx = mx.gpu(args.gpu)

    features = mx.nd.array(data.features, ctx=ctx)
    labels = mx.nd.array(data.labels, ctx=ctx)
    # NOTE(review): the training mask is loaded but never applied to the loss
    # below, so the loss covers every node -- confirm whether that is intended.
    mask = mx.nd.array(data.train_mask, ctx=ctx)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    # create the graph
    g = DGLGraph(data.graph)

    # create model
    model = GAT(g,
                args.num_layers,
                in_feats,
                args.num_hidden,
                n_classes,
                args.num_heads,
                elu,
                args.in_drop,
                args.attn_drop,
                args.residual)
    model.initialize(ctx=ctx)

    # use optimizer
    trainer = gluon.Trainer(model.collect_params(), 'adam',
                            {'learning_rate': args.lr})

    # per-epoch wall-clock durations (first three epochs are excluded)
    dur = []
    for epoch in range(args.epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        with mx.autograd.record():
            logits = model(features)
            loss = mx.nd.softmax_cross_entropy(logits, labels)

        loss.backward()
        trainer.step(features.shape[0])

        if epoch >= 3:
            dur.append(time.time() - t0)
        # Until the first timed epoch there is nothing to average; report NaN
        # explicitly instead of letting np.mean([]) emit a RuntimeWarning.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, loss.asnumpy()[0], mean_dur, n_edges / mean_dur / 1000))
|
||
if __name__ == '__main__':
    # Command-line entry point: parse the options, echo them, then train.
    parser = argparse.ArgumentParser(description='GAT')
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1,
                        help="Which GPU to use. Set -1 to use CPU.")
    parser.add_argument("--epochs", type=int, default=20,
                        help="number of training epochs")
    parser.add_argument("--num-heads", type=int, default=3,
                        help="number of attentional heads to use")
    parser.add_argument("--num-layers", type=int, default=1,
                        help="number of hidden layers")
    parser.add_argument("--num-hidden", type=int, default=8,
                        help="size of hidden units")
    # store_false: args.residual is True unless the flag is given, so the
    # help text must describe the flag as turning the connection OFF.
    parser.add_argument("--residual", action="store_false",
                        help="disable the residual connection (enabled by default)")
    parser.add_argument("--in-drop", type=float, default=.6,
                        help="input feature dropout")
    parser.add_argument("--attn-drop", type=float, default=.6,
                        help="attention dropout")
    parser.add_argument("--lr", type=float, default=0.005,
                        help="learning rate")
    args = parser.parse_args()
    print(args)

    main(args)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
Graph Convolutional Networks (GCN) | ||
============ | ||
|
||
Paper link: [https://arxiv.org/abs/1609.02907](https://arxiv.org/abs/1609.02907) | ||
Author's code repo: [https://github.com/tkipf/gcn](https://github.com/tkipf/gcn) | ||
|
||
The folder contains three different implementations using DGL. | ||
|
||
Naive GCN (gcn.py) | ||
------- | ||
The model is defined in the finest granularity (aka on *one* edge and *one* node). | ||
|
||
* The message function `gcn_msg` computes the message for one edge. It simply returns the `h` representation of the source node. | ||
```python | ||
def gcn_msg(src, edge): | ||
# src['h'] is a tensor of shape (D,). D is the feature length. | ||
return src['h'] | ||
``` | ||
* The reduce function `gcn_reduce` accumulates the incoming messages for one node. The `msgs` argument is a list of all the messages. In GCN, the incoming messages are summed up. | ||
```python | ||
def gcn_reduce(node, msgs): | ||
# msgs is a list of in-coming messages. | ||
return sum(msgs) | ||
``` | ||
* The update function `NodeUpdateModule` computes the new node representation `h` by applying a non-linear transformation to the reduced messages. | ||
```python | ||
class NodeUpdateModule(nn.Module): | ||
def __init__(self, in_feats, out_feats, activation=None): | ||
super(NodeUpdateModule, self).__init__() | ||
self.linear = nn.Linear(in_feats, out_feats) | ||
self.activation = activation | ||
|
||
def forward(self, node, accum): | ||
# accum is a tensor of shape (D,). | ||
h = self.linear(accum) | ||
if self.activation: | ||
h = self.activation(h) | ||
return {'h' : h} | ||
``` | ||
|
||
After defining the functions on each node/edge, the message passing is triggered by calling `update_all` on the DGLGraph object (in GCN module). | ||
|
||
Batched GCN (gcn_batch.py) | ||
----------- | ||
Defining the model on only one node and edge makes it hard to fully utilize GPUs. As a result, we allow users to define the model on a *batch of* nodes and edges. | ||
|
||
* The message function `gcn_msg` computes the message for a batch of edges. Here, the `src` argument is the batched representation of the source endpoints of the edges. The function simply returns the source node representations. | ||
```python | ||
def gcn_msg(src, edge): | ||
# src is a tensor of shape (B, D). B is the number of edges being batched. | ||
return src | ||
``` | ||
* The reduce function `gcn_reduce` also accumulates messages for a batch of nodes. We batch the messages on the second dimension of the `msgs` argument: | ||
```python | ||
def gcn_reduce(node, msgs): | ||
# The msgs is a tensor of shape (B, deg, D). B is the number of nodes in the batch; | ||
# deg is the number of messages; D is the message tensor dimension. DGL guarantees | ||
# that all the nodes in a batch have the same in-degrees (through "degree-bucketing"). | ||
# Reduce on the second dimension is equal to sum up all the in-coming messages. | ||
return torch.sum(msgs, 1) | ||
``` | ||
* The update module is similar. The first dimension of each tensor is the batch dimension. Since PyTorch operations are usually aware of the batch dimension, the code is the same as the naive GCN. | ||
|
||
Triggering message passing is also similar. Users need to set `batchable=True` to indicate that the functions all support batching. | ||
```python | ||
self.g.update_all(gcn_msg, gcn_reduce, layer, batchable=True) | ||
``` | ||
|
||
Batched GCN with spMV optimization (gcn_spmv.py) | ||
----------- | ||
Batched computation is much more efficient than the naive vertex-centric approach, but is still not ideal. For example, the batched message function needs to look up source node data and save it on edges. Such lookups are very common and incur extra memory copy operations. In fact, the message and reduce phases of the GCN model can be fused into one sparse-matrix-vector multiplication (spMV). Therefore, DGL provides many built-in message/reduce functions so that it can detect such optimization opportunities. In gcn_spmv.py, the user only needs to write the update module and trigger the message passing as follows: | ||
```python | ||
self.g.update_all('from_src', 'sum', layer, batchable=True) | ||
``` | ||
Here, `'from_src'` and `'sum'` are the builtin message and reduce function. |
Oops, something went wrong.