[Layer] Update layers #249

Merged · 1 commit merged on Jun 25, 2021
35 changes: 33 additions & 2 deletions cogdl/layers/__init__.py
@@ -1,5 +1,36 @@
-from .maggregator import MeanAggregator, SumAggregator
+from .gcn_layer import GCNLayer
+from .sage_layer import MeanAggregator, SumAggregator, SAGELayer
+from .gat_layer import GATLayer
+from .gin_layer import GINLayer
 from .se_layer import SELayer
+from .deepergcn_layer import GENConv, DeepGCNLayer
+from .disengcn_layer import DisenGCNLayer
+from .gcnii_layer import GCNIILayer
+from .mlp_layer import MLPLayer
+from .saint_layer import SAINTLayer
+from .han_layer import HANLayer
+from .pprgo_layer import PPRGoLayer
+from .rgcn_layer import RGCNLayer
+from .sgc_layer import SGCLayer
 from .mixhop_layer import MixHopLayer

-__all__ = ["SELayer", "MeanAggregator", "SumAggregator", "MixHopLayer"]
+__all__ = [
+    "GCNLayer",
+    "MeanAggregator",
+    "SumAggregator",
+    "SAGELayer",
+    "GATLayer",
+    "GINLayer",
+    "SELayer",
+    "GENConv",
+    "DeepGCNLayer",
+    "DisenGCNLayer",
+    "GCNIILayer",
+    "MLPLayer",
+    "SAINTLayer",
+    "HANLayer",
+    "PPRGoLayer",
+    "RGCNLayer",
+    "SGCLayer",
+    "MixHopLayer",
+]
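With the expanded __all__, the new layers resolve directly from the package namespace. A minimal sketch based only on the export list above:

# All names listed in __all__ above become importable from cogdl.layers once this PR is merged.
from cogdl.layers import GATLayer, GENConv, DeepGCNLayer, DisenGCNLayer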
140 changes: 140 additions & 0 deletions cogdl/layers/deepergcn_layer.py
@@ -0,0 +1,140 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from cogdl.utils import get_activation, mul_edge_softmax
from torch.utils.checkpoint import checkpoint


class GENConv(nn.Module):
def __init__(
self,
in_feat,
out_feat,
aggr="softmax_sg",
beta=1.0,
p=1.0,
learn_beta=False,
learn_p=False,
use_msg_norm=False,
learn_msg_scale=True,
):
super(GENConv, self).__init__()
self.use_msg_norm = use_msg_norm
self.mlp = nn.Linear(in_feat, out_feat)

self.message_encoder = torch.nn.ReLU()

self.aggr = aggr
if aggr == "softmax_sg":
self.beta = torch.nn.Parameter(
torch.Tensor(
[
beta,
]
),
requires_grad=learn_beta,
)
else:
self.register_buffer("beta", None)
if aggr == "powermean":
self.p = torch.nn.Parameter(
torch.Tensor(
[
p,
]
),
requires_grad=learn_p,
)
else:
self.register_buffer("p", None)
self.eps = 1e-7

self.s = torch.nn.Parameter(torch.Tensor([1.0]), requires_grad=learn_msg_scale)
self.act = nn.ReLU()

def message_norm(self, x, msg):
x_norm = torch.norm(x, dim=1, p=2)
msg_norm = F.normalize(msg, p=2, dim=1)
msg_norm = msg_norm * x_norm.unsqueeze(-1)
return x + self.s * msg_norm

def forward(self, graph, x):
edge_index = graph.edge_index
dim = x.shape[1]
edge_msg = x[edge_index[1]] # if edge_attr is None else x[edge_index[1]] + edge_attr
edge_msg = self.act(edge_msg) + self.eps

if self.aggr == "softmax_sg":
h = mul_edge_softmax(graph, self.beta * edge_msg)
h = edge_msg * h
elif self.aggr == "softmax":
h = mul_edge_softmax(graph, edge_msg)
h = edge_msg * h
elif self.aggr == "powermean":
deg = graph.degrees()
            h = edge_msg.pow(self.p) / deg[edge_index[0]].unsqueeze(-1)
else:
raise NotImplementedError

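        # Sum the weighted per-edge messages into the nodes given by edge_index[0]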
h = torch.zeros_like(x).scatter_add_(dim=0, index=edge_index[0].unsqueeze(-1).repeat(1, dim), src=h)
if self.aggr == "powermean":
h = h.pow(1.0 / self.p)
if self.use_msg_norm:
h = self.message_norm(x, h)
h = self.mlp(h)
return h


class DeepGCNLayer(nn.Module):
"""
    Implementation of DeeperGCN from the paper `"DeeperGCN: All You Need to Train Deeper GCNs"` <https://arxiv.org/abs/2006.07739>

Parameters
-----------
in_feat : int
Size of each input sample
out_feat : int
Size of each output sample
conv : class
Base convolution layer.
connection : str
Residual connection type, `res` or `res+`.
    activation : str
        Name of the activation function applied in the `res+` branch.
    dropout : float
        Dropout rate.
    checkpoint_grad : bool
        If True, apply gradient checkpointing to the wrapped convolution to save memory.
"""

def __init__(
self,
in_feat,
out_feat,
conv,
connection="res",
activation="relu",
dropout=0.0,
checkpoint_grad=False,
):
super(DeepGCNLayer, self).__init__()
self.conv = conv
self.activation = get_activation(activation)
self.dropout = dropout
self.connection = connection
self.norm = nn.BatchNorm1d(out_feat, affine=True)
self.checkpoint_grad = checkpoint_grad

def forward(self, graph, x):
if self.connection == "res+":
h = self.norm(x)
h = self.activation(h)
h = F.dropout(h, p=self.dropout, training=self.training)
if self.checkpoint_grad:
h = checkpoint(self.conv, graph, h)
else:
h = self.conv(graph, h)
elif self.connection == "res":
h = self.conv(graph, x)
h = self.norm(h)
h = self.activation(h)
else:
raise NotImplementedError
return x + h
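For reference, a minimal usage sketch of GENConv and DeepGCNLayer as defined above. It assumes a cogdl Graph object g and a node-feature tensor x of shape (num_nodes, 64); the sizes are illustrative, and in_feat must match out_feat here because forward adds the input back onto the convolved output.

conv = GENConv(in_feat=64, out_feat=64, aggr="softmax_sg", beta=1.0, learn_beta=True)
layer = DeepGCNLayer(64, 64, conv=conv, connection="res+", activation="relu", dropout=0.1)
out = layer(g, x)  # (num_nodes, 64); "res+" runs norm -> activation -> dropout -> conv, then adds x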
91 changes: 91 additions & 0 deletions cogdl/layers/disengcn_layer.py
@@ -0,0 +1,91 @@
import torch
import torch.nn as nn

from cogdl.utils import mul_edge_softmax


class DisenGCNLayer(nn.Module):
"""
Implementation of "Disentangled Graph Convolutional Networks" <http://proceedings.mlr.press/v97/ma19a.html>.
"""

def __init__(self, in_feats, out_feats, K, iterations, tau=1.0, activation="leaky_relu"):
super(DisenGCNLayer, self).__init__()
self.K = K
self.tau = tau
self.iterations = iterations
self.factor_dim = int(out_feats / K)

self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
self.bias = nn.Parameter(torch.Tensor(out_feats))
self.reset_parameters()

if activation == "leaky_relu":
self.activation = nn.LeakyReLU()
elif activation == "sigmoid":
self.activation = nn.Sigmoid()
elif activation == "tanh":
self.activation = nn.Tanh()
elif activation == "prelu":
self.activation = nn.PReLU()
elif activation == "relu":
self.activation = nn.ReLU()
else:
raise NotImplementedError

def reset_parameters(self):
nn.init.xavier_normal_(self.weight.data, gain=1.414)
nn.init.zeros_(self.bias.data)

def forward(self, graph, x):
num_nodes = x.shape[0]
device = x.device

h = self.activation(torch.matmul(x, self.weight) + self.bias)

h = h.split(self.factor_dim, dim=-1)
h = torch.cat([dt.unsqueeze(0) for dt in h], dim=0)
norm = h.pow(2).sum(dim=-1).sqrt().unsqueeze(-1)

# multi-channel softmax: faster
h_normed = h / norm # (K, N, d)
h_src = h_dst = h_normed.permute(1, 0, 2) # (N, K, d)
add_shape = h.shape # (K, N, d)

edge_index = graph.edge_index
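        # Neighborhood routing: for `iterations` rounds, compute per-factor edge weights
        # with an edge softmax and re-aggregate the factor-wise node representations.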
for _ in range(self.iterations):
src_edge_attr = h_dst[edge_index[0]] * h_src[edge_index[1]]
            src_edge_attr = src_edge_attr.sum(dim=-1) / self.tau  # shape: (E, K)
            edge_attr_softmax = mul_edge_softmax(graph, src_edge_attr).T  # shape: (K, E)
            edge_attr_softmax = edge_attr_softmax.unsqueeze(-1)  # shape: (K, E, 1)

dst_edge_attr = h_src.index_select(0, edge_index[1]).permute(1, 0, 2) # shape: (E, K, d) -> (K, E, d)
dst_edge_attr = dst_edge_attr * edge_attr_softmax
edge_index_ = edge_index[0].unsqueeze(-1).unsqueeze(0).repeat(self.K, 1, h.shape[-1])
node_attr = torch.zeros(add_shape).to(device).scatter_add_(1, edge_index_, dst_edge_attr) # (K, N, d)
node_attr = node_attr + h_normed
node_attr_norm = node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1) # shape: (K, N, 1)
node_attr = (node_attr / node_attr_norm).permute(1, 0, 2) # shape: (N, K, d)
h_dst = node_attr

h_dst = h_dst.reshape(num_nodes, -1)

# Calculate the softmax of each channel separately
# h_src = h_dst = h / norm # (K, N, d)
#
# for _ in range(self.iterations):
# for i in range(self.K):
# h_attr = h_dst[i]
# edge_attr = h_attr[edge_index[0]] * h_src[i][edge_index[1]]
#
# edge_attr = edge_attr.sum(-1)/self.tau
# edge_attr = edge_softmax(edge_index, edge_attr, shape=(num_nodes, num_nodes))
#
# node_attr = spmm(edge_index, edge_attr, h_src[i])
#
# node_attr = node_attr + h_src[i]
# h_src[i] = node_attr / node_attr.pow(2).sum(-1).sqrt().unsqueeze(-1)
#
# h_dst = h_dst.permute(1, 0, 2).reshape(num_nodes, -1)

return h_dst
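A hedged usage sketch of DisenGCNLayer as defined above, assuming a cogdl Graph g and features x of shape (num_nodes, 64). out_feats is split into K equal-sized channels (factor_dim = out_feats / K), so it should be divisible by K:

layer = DisenGCNLayer(in_feats=64, out_feats=64, K=4, iterations=3, tau=1.0, activation="leaky_relu")
h = layer(g, x)  # (num_nodes, 64): four disentangled 16-dimensional channels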
89 changes: 89 additions & 0 deletions cogdl/layers/gat_layer.py
@@ -0,0 +1,89 @@
import math

import torch
import torch.nn as nn
from cogdl.utils import check_mh_spmm, mh_spmm, mul_edge_softmax, spmm


class GATLayer(nn.Module):
"""
    Sparse version of the GAT layer, similar to https://arxiv.org/abs/1710.10903
"""

def __init__(self, in_features, out_features, nhead=1, alpha=0.2, dropout=0.6, concat=True, residual=False):
super(GATLayer, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.alpha = alpha
self.concat = concat
self.nhead = nhead

self.W = nn.Parameter(torch.FloatTensor(in_features, out_features * nhead))

self.a_l = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))
self.a_r = nn.Parameter(torch.zeros(size=(1, nhead, out_features)))

self.dropout = nn.Dropout(dropout)
self.leakyrelu = nn.LeakyReLU(self.alpha)

if residual:
out_features = out_features * nhead if concat else out_features
self.residual = nn.Linear(in_features, out_features)
else:
self.register_buffer("residual", None)
self.reset_parameters()

def reset_parameters(self):
def reset(tensor):
stdv = math.sqrt(6.0 / (tensor.size(-2) + tensor.size(-1)))
tensor.data.uniform_(-stdv, stdv)

reset(self.a_l)
reset(self.a_r)
reset(self.W)

# nn.init.xavier_uniform_(self.W.data, gain=1.414)
# nn.init.xavier_uniform_(self.a_r.data, gain=1.414)
# nn.init.xavier_uniform_(self.a_l.data, gain=1.414)

def forward(self, graph, x):
h = torch.matmul(x, self.W).view(-1, self.nhead, self.out_features)
h[torch.isnan(h)] = 0.0

row, col = graph.edge_index
# Self-attention on the nodes - Shared attention mechanism
h_l = (self.a_l * h).sum(dim=-1)[row]
h_r = (self.a_r * h).sum(dim=-1)[col]
edge_attention = self.leakyrelu(h_l + h_r)
# edge_attention: E * H
edge_attention = mul_edge_softmax(graph, edge_attention)
edge_attention = self.dropout(edge_attention)

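        # Fast path: fused multi-head sparse matmul on GPU when available; otherwise fall back to per-head spmm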
if check_mh_spmm() and next(self.parameters()).device.type != "cpu":
if self.nhead > 1:
h_prime = mh_spmm(graph, edge_attention, h)
out = h_prime.view(h_prime.shape[0], -1)
else:
edge_weight = edge_attention.view(-1)
with graph.local_graph():
graph.edge_weight = edge_weight
out = spmm(graph, h.squeeze(1))
else:
with graph.local_graph():
h_prime = []
h = h.permute(1, 0, 2).contiguous()
for i in range(self.nhead):
edge_weight = edge_attention[:, i]
graph.edge_weight = edge_weight
hidden = h[i]
assert not torch.isnan(hidden).any()
h_prime.append(spmm(graph, hidden))
out = torch.cat(h_prime, dim=1)

if self.residual:
res = self.residual(x)
out += res
return out

def __repr__(self):
return self.__class__.__name__ + " (" + str(self.in_features) + " -> " + str(self.out_features) + ")"
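A minimal usage sketch of GATLayer as defined above, assuming a cogdl Graph g and features x of shape (num_nodes, 64). The attention heads are concatenated, so the output has out_features * nhead columns; with residual=True a learned linear projection of x is added on top:

layer = GATLayer(in_features=64, out_features=8, nhead=8, alpha=0.2, dropout=0.6, concat=True, residual=True)
h = layer(g, x)  # (num_nodes, 64) = 8 heads * 8 features per head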