change ordering in propagate/update
CarloLucibello committed Sep 9, 2021
1 parent 4deb30e commit b5512db
Showing 4 changed files with 115 additions and 116 deletions.
42 changes: 21 additions & 21 deletions src/layers/conv.jl
@@ -46,13 +46,13 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
end

message(l::GCNConv, xi, xj) = xj
-update(l::GCNConv, m, x) = m
+update(l::GCNConv, x, m) = m

function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
g = add_self_loops(g)
c = 1 ./ sqrt.(degree(g, T, dir=:in))
x = x .* c'
-_, x = propagate(l, g, nothing, x, nothing, +)
+x, _ = propagate(l, g, +, x)
x = x .* c'
return l.σ.(l.weight * x .+ l.bias)
end
@@ -176,12 +176,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
GraphConv(W1, W2, b, σ, aggr)
end

-message(gc::GraphConv, x_i, x_j, e_ij) = x_j
-update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias)
+message(l::GraphConv, x_i, x_j, e_ij) = x_j
+update(l::GraphConv, x, m) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)

-function (gc::GraphConv)(g::GNNGraph, x::AbstractMatrix)
+function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
check_num_nodes(g, x)
-_, x = propagate(gc, g, nothing, x, nothing, +)
+x, _ = propagate(l, g, +, x)
x
end

@@ -325,23 +325,23 @@ end


message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j
-update(l::GatedGraphConv, m, x) = m
+update(l::GatedGraphConv, x, m) = m

# remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
@non_differentiable fill!(x...)

-function (ggc::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
check_num_nodes(g, H)
m, n = size(H)
-@assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
-if m < ggc.out_ch
-Hpad = similar(H, S, ggc.out_ch - m, n)
+@assert (m <= l.out_ch) "number of input features must less or equals to output features."
+if m < l.out_ch
+Hpad = similar(H, S, l.out_ch - m, n)
H = vcat(H, fill!(Hpad, 0))
end
-for i = 1:ggc.num_layers
-M = view(ggc.weight, :, :, i) * H
-_, M = propagate(ggc, g, nothing, M, nothing, +)
-H, _ = ggc.gru(H, M)
+for i = 1:l.num_layers
+M = view(l.weight, :, :, i) * H
+M, _ = propagate(l, g, +, M)
+H, _ = l.gru(H, M)
end
H
end
@@ -379,13 +379,13 @@ end

EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)

-message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
+message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))

-update(ec::EdgeConv, m, x) = m
+update(l::EdgeConv, x, m) = m

-function (ec::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
+function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
check_num_nodes(g, X)
-_, X = propagate(ec, g, nothing, X, nothing, ec.aggr)
+X, _ = propagate(l, g, +, X)
X
end

@@ -425,10 +425,10 @@ function GINConv(nn; eps=0f0)
end

message(l::GINConv, x_i, x_j) = x_j
-update(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
+update(l::GINConv, x, m) = l.nn((1 + l.eps) * x + m)

function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
check_num_nodes(g, X)
-_, X = propagate(l, g, nothing, X, nothing, +)
+X, _ = propagate(l, g, +, X)
X
end
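The same pattern repeats across every layer touched above: `propagate` now takes the aggregation operator right after the graph and returns the node features first, and `update` receives the current node features before the aggregated messages. Below is a minimal sketch of a user-defined layer written against the new convention; `MyConv`, its fields, and the toy data are made up for illustration and are not part of the package.

```julia
using GraphNeuralNetworks
import GraphNeuralNetworks: propagate, message, update

# Hypothetical layer, used only to illustrate the new argument order.
struct MyConv
    weight::Matrix{Float32}
end
MyConv(in::Int, out::Int) = MyConv(randn(Float32, out, in))

# message keeps its signature: the contribution of neighbor j to node i.
message(l::MyConv, x_i, x_j) = l.weight * x_j

# update now takes the node features first, then the aggregated messages.
update(l::MyConv, x, m̄) = m̄

function (l::MyConv)(g::GNNGraph, x::AbstractMatrix)
    # aggregation operator right after the graph; node features come back first
    x, _ = propagate(l, g, +, x)
    return x
end

# toy usage: 3 nodes, 4 input features, 8 output features
adj = [0 1 0; 1 0 1; 0 1 0]
l = MyConv(4, 8)
y = l(GNNGraph(adj), rand(Float32, 4, 3))   # 8×3 matrix
```

Spelling out `update(l, x, m̄) = m̄` only marks where the reordering lands; it matches the default fallback in `src/msgpass.jl`.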
65 changes: 32 additions & 33 deletions src/msgpass.jl
@@ -2,10 +2,11 @@
# "Relational inductive biases, deep learning, and graph networks"

"""
-propagate(mp, g, X, E, U, aggr)
+propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
+propagate(mp, g, aggr) -> g′
Perform the sequence of operations implementing the message-passing scheme
-on graph `g` with convolution layer `mp`.
+of gnn layer `mp` on graph `g`.
Updates the node, edge, and global features `X`, `E`, and `U` respectively.
The computation involved is the following:
@@ -14,7 +15,7 @@ The computation involved is the following:
M = compute_batch_message(mp, g, X, E, U)
M̄ = aggregate_neighbors(mp, aggr, g, M)
X′ = update(mp, X, M̄, U)
-E′ = update_edge(mp, M, E, U)
+E′ = update_edge(mp, E, M, U)
U′ = update_global(mp, U, X′, E′)
```
@@ -25,7 +26,7 @@ this method in the forward pass:
```julia
function (l::MyLayer)(g, X)
... some preprocessing if needed ...
-propagate(l, g, X, E, U, +)
+propagate(l, g, +, X, E, U)
end
```
@@ -34,28 +35,28 @@ See also [`message`](@ref) and [`update`](@ref).
function propagate end

function propagate(mp, g::GNNGraph, aggr)
-X, E, U = propagate(mp, g,
-node_features(g), edge_features(g), global_features(g),
-aggr)
-GNNGraph(g, ndata=X, edata=E, gdata=U)
+X, E, U = propagate(mp, g, aggr,
+node_features(g), edge_features(g), global_features(g))

+return GNNGraph(g, ndata=X, edata=E, gdata=U)
end

-function propagate(mp, g::GNNGraph, E, X, U, aggr)
+function propagate(mp, g::GNNGraph, aggr, X, E=nothing, U=nothing)
# TODO consider g.graph_indicator in propagating U
-M = compute_batch_message(mp, g, E, X, U)
-E = update_edge(mp, M, E, U)
-M̄ = aggregate_neighbors(mp, aggr, g, M)
-X = update(mp, M̄, X, U)
-U = update_global(mp, E, X, U)
-return E, X, U
+M = compute_batch_message(mp, g, X, E, U)
+M̄ = aggregate_neighbors(mp, g, aggr, M)
+X′ = update(mp, X, M̄, U)
+E′ = update_edge(mp, E, M, U)
+U = update_global(mp, U, X′, E′)
+return X′, E′, U
end
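Taken together, the two methods above give callers a choice: pass the features explicitly and get the updated `(X′, E′, U′)` tuple back, or pass only the graph and get a new `GNNGraph` carrying the updated features. A small usage sketch under stated assumptions: `IdentityMP` is a made-up struct that leans entirely on the default `message`/`update` fallbacks defined further down in this file, and the graph and feature sizes are arbitrary.

```julia
using GraphNeuralNetworks
import GraphNeuralNetworks: propagate

struct IdentityMP end                  # hypothetical: no custom message/update

adj = [0 1 0; 1 0 1; 0 1 0]
x = rand(Float32, 4, 3)
g = GNNGraph(adj, ndata=x)             # node features stored on the graph

# tuple form: features passed explicitly, node features returned first
X′, _, _ = propagate(IdentityMP(), g, +, x)

# graph form: features are read from `g` and written to the returned graph
g′ = propagate(IdentityMP(), g, +)
node_features(g′) == X′                # true
```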

"""
message(mp, x_i, x_j, [e_ij, u])
Message function for the message-passing scheme,
returning the message from node `j` to node `i` .
-In the message-passing scheme. the incoming messages
+In the message-passing scheme, the incoming messages
from the neighborhood of `i` will later be aggregated
in order to [`update`](@ref) the features of node `i`.
@@ -75,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
function message end
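As a concrete illustration of the signature documented above, here is a hedged sketch of a custom `message` that uses the optional edge features; the `WeightedSum` struct and the one-row edge-weight layout are assumptions, not package code.

```julia
import GraphNeuralNetworks: message

# Hypothetical layer: scale each neighbor's features by a scalar edge weight.
struct WeightedSum end

# e_ij carries the edge features for the whole batch of edges,
# here assumed to be a 1×num_edges row of weights.
message(::WeightedSum, x_i, x_j, e_ij) = x_j .* e_ij
```

The fallbacks in the steps below reduce `message(mp, x_i, x_j, e_ij, u)` to this three-argument form, so unused arguments never have to be declared.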

"""
-update(mp, m̄, x, [u])
+update(mp, x, m̄, [u])
Update function for the message-passing scheme,
returning a new set of node features `x′` based on old
@@ -96,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
"""
function update end
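Under the new ordering the node features come first and the aggregated message `m̄` second. A hedged sketch of a residual-style `update`; the `ResidualUpdate` struct is made up and assumes `m̄` has the same size as `x`, which holds with the default `message`.

```julia
import GraphNeuralNetworks: update

# Hypothetical: keep the old node state and add the aggregated neighborhood message.
struct ResidualUpdate end

update(::ResidualUpdate, x, m̄) = x .+ m̄
```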


_gather(x, i) = NNlib.gather(x, i)
_gather(x::Nothing, i) = nothing

## Step 1.

-function compute_batch_message(mp, g, E, X, u)
+function compute_batch_message(mp, g, X, E, U)
s, t = edge_index(g)
Xi = _gather(X, t)
Xj = _gather(X, s)
-M = message(mp, Xi, Xj, E, u)
+M = message(mp, Xi, Xj, E, U)
return M
end

# @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
@inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
@inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
@inline message(mp, x_i, x_j) = x_j

-## Step 2

-@inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
-@inline update_edge(mp, M, E) = E

-## Step 3
+## Step 2

-function aggregate_neighbors(mp, aggr, g, E)
+function aggregate_neighbors(mp, g, aggr, E)
s, t = edge_index(g)
NNlib.scatter(aggr, E, t)
end

-aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing
+aggregate_neighbors(mp, g, aggr::Nothing, E) = nothing

+## Step 3

+@inline update(mp, x, m̄, u) = update(mp, x, m̄)
+@inline update(mp, x, m̄) = m̄

## Step 4

-# @inline update(mp, i, m̄, x, u) = update(mp, m̄, x, u)
-@inline update(mp, m̄, x, u) = update(mp, m̄, x)
-@inline update(mp, m̄, x) = m̄
+@inline update_edge(mp, E, M, U) = update_edge(mp, E, M)
+@inline update_edge(mp, E, M) = E

## Step 5

-@inline update_global(mp, E, X, u) = u
+@inline update_global(mp, U, X, E) = update_global(mp, U, X)
+@inline update_global(mp, U, X) = U

### end steps ###
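Because every step has a permissive fallback, a bare struct already runs the whole pipeline: the default `message` forwards `x_j`, the default `update` returns the aggregated messages, and edge and global features pass through unchanged. A small sketch under those assumptions; `PlainSum` is a made-up name and the tiny graph is only there to make the numbers checkable.

```julia
using GraphNeuralNetworks
import GraphNeuralNetworks: propagate

struct PlainSum end   # no message/update methods: rely on the defaults above

g = GNNGraph([0 1 1; 1 0 0; 1 0 0])   # node 1 linked to nodes 2 and 3
x = Float32[1 2 3]                     # one scalar feature per node

x′, _ = propagate(PlainSum(), g, +, x)
x′   # 1×3 Matrix{Float32}: [5.0 1.0 1.0], node 1 sums its neighbors 2 and 3
```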
44 changes: 22 additions & 22 deletions test/cuda/msgpass.jl
@@ -1,29 +1,29 @@
-in_channel = 10
-out_channel = 5
-N = 6
-T = Float32
-adj = [0 1 0 0 0 0
-       1 0 0 1 1 1
-       0 0 0 0 0 1
-       0 1 0 0 1 0
-       0 1 0 1 0 1
-       0 1 1 0 1 0]
+@testset "cuda/msgpass" begin
+    in_channel = 10
+    out_channel = 5
+    N = 6
+    T = Float32
+    adj = [0 1 0 0 0 0
+           1 0 0 1 1 1
+           0 0 0 0 0 1
+           0 1 0 0 1 0
+           0 1 0 1 0 1
+           0 1 1 0 1 0]

-struct NewCudaLayer
-    weight
-end
-NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))
-@functor NewCudaLayer
+    struct NewCudaLayer
+        weight
+    end
+    NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
+    @functor NewCudaLayer

-(l::NewCudaLayer)(X) = GraphNeuralNetworks.propagate(l, X, +)
-GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-GraphNeuralNetworks.update(::NewCudaLayer, m, x) = m
+    (l::NewCudaLayer)(g, X) = GraphNeuralNetworks.propagate(l, g, +, X)
+    GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
+    GraphNeuralNetworks.update(::NewCudaLayer, x, m) = m

-X = rand(T, in_channel, N) |> gpu
-g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-l = NewCudaLayer(out_channel, in_channel) |> gpu
+    X = rand(T, in_channel, N) |> gpu
+    g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
+    l = NewCudaLayer(out_channel, in_channel) |> gpu

-@testset "cuda/msgpass" begin
    g_ = l(g)
    @test size(node_features(g_)) == (out_channel, N)
end