From b5512db8957a19cb770d44a68996db4408fc364c Mon Sep 17 00:00:00 2001
From: Carlo Lucibello
Date: Thu, 9 Sep 2021 08:43:50 +0200
Subject: [PATCH] change ordering in propagate/update

---
 src/layers/conv.jl   | 42 +++++++++++------------
 src/msgpass.jl       | 65 ++++++++++++++++++-----------------
 test/cuda/msgpass.jl | 44 ++++++++++++------------
 test/msgpass.jl      | 80 ++++++++++++++++++++++----------------------
 4 files changed, 115 insertions(+), 116 deletions(-)

diff --git a/src/layers/conv.jl b/src/layers/conv.jl
index 17b04b582..b1044aba6 100644
--- a/src/layers/conv.jl
+++ b/src/layers/conv.jl
@@ -46,13 +46,13 @@ function (l::GCNConv)(g::GNNGraph, x::AbstractMatrix{T}) where T
 end

 message(l::GCNConv, xi, xj) = xj
-update(l::GCNConv, m, x) = m
+update(l::GCNConv, x, m) = m

 function (l::GCNConv)(g::GNNGraph, x::CuMatrix{T}) where T
     g = add_self_loops(g)
     c = 1 ./ sqrt.(degree(g, T, dir=:in))
     x = x .* c'
-    _, x = propagate(l, g, nothing, x, nothing, +)
+    x, _ = propagate(l, g, +, x)
     x = x .* c'
     return l.σ.(l.weight * x .+ l.bias)
 end
@@ -176,12 +176,12 @@ function GraphConv(ch::Pair{Int,Int}, σ=identity, aggr=+;
     GraphConv(W1, W2, b, σ, aggr)
 end

-message(gc::GraphConv, x_i, x_j, e_ij) = x_j
-update(gc::GraphConv, m, x) = gc.σ.(gc.weight1 * x .+ gc.weight2 * m .+ gc.bias)
+message(l::GraphConv, x_i, x_j, e_ij) = x_j
+update(l::GraphConv, x, m) = l.σ.(l.weight1 * x .+ l.weight2 * m .+ l.bias)

-function (gc::GraphConv)(g::GNNGraph, x::AbstractMatrix)
+function (l::GraphConv)(g::GNNGraph, x::AbstractMatrix)
     check_num_nodes(g, x)
-    _, x = propagate(gc, g, nothing, x, nothing, +)
+    x, _ = propagate(l, g, +, x)
     x
 end

@@ -325,23 +325,23 @@ end

 message(l::GatedGraphConv, x_i, x_j, e_ij) = x_j

-update(l::GatedGraphConv, m, x) = m
+update(l::GatedGraphConv, x, m) = m

 # remove after https://github.com/JuliaDiff/ChainRules.jl/pull/521
 @non_differentiable fill!(x...)

-function (ggc::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
+function (l::GatedGraphConv)(g::GNNGraph, H::AbstractMatrix{S}) where {T<:AbstractVector,S<:Real}
     check_num_nodes(g, H)
     m, n = size(H)
-    @assert (m <= ggc.out_ch) "number of input features must less or equals to output features."
-    if m < ggc.out_ch
-        Hpad = similar(H, S, ggc.out_ch - m, n)
+    @assert (m <= l.out_ch) "number of input features must be less than or equal to the number of output features."
+    if m < l.out_ch
+        Hpad = similar(H, S, l.out_ch - m, n)
         H = vcat(H, fill!(Hpad, 0))
     end
-    for i = 1:ggc.num_layers
-        M = view(ggc.weight, :, :, i) * H
-        _, M = propagate(ggc, g, nothing, M, nothing, +)
-        H, _ = ggc.gru(H, M)
+    for i = 1:l.num_layers
+        M = view(l.weight, :, :, i) * H
+        M, _ = propagate(l, g, +, M)
+        H, _ = l.gru(H, M)
     end
     H
 end
@@ -379,13 +379,13 @@ end

 EdgeConv(nn; aggr=max) = EdgeConv(nn, aggr)

-message(ec::EdgeConv, x_i, x_j, e_ij) = ec.nn(vcat(x_i, x_j .- x_i))
+message(l::EdgeConv, x_i, x_j, e_ij) = l.nn(vcat(x_i, x_j .- x_i))

-update(ec::EdgeConv, m, x) = m
+update(l::EdgeConv, x, m) = m

-function (ec::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
+function (l::EdgeConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    _, X = propagate(ec, g, nothing, X, nothing, ec.aggr)
+    X, _ = propagate(l, g, l.aggr, X)
     X
 end

@@ -425,10 +425,10 @@ function GINConv(nn; eps=0f0)
 end

 message(l::GINConv, x_i, x_j) = x_j
-update(l::GINConv, m, x) = l.nn((1 + l.eps) * x + m)
+update(l::GINConv, x, m) = l.nn((1 + l.eps) * x + m)

 function (l::GINConv)(g::GNNGraph, X::AbstractMatrix)
     check_num_nodes(g, X)
-    _, X = propagate(l, g, nothing, X, nothing, +)
+    X, _ = propagate(l, g, +, X)
     X
 end
diff --git a/src/msgpass.jl b/src/msgpass.jl
index ef52a2567..5a868ffb1 100644
--- a/src/msgpass.jl
+++ b/src/msgpass.jl
@@ -2,10 +2,11 @@
 # "Relational inductive biases, deep learning, and graph networks"

 """
-    propagate(mp, g, X, E, U, aggr)
+    propagate(mp, g, aggr, [X, E, U]) -> X′, E′, U′
+    propagate(mp, g, aggr) -> g′

 Perform the sequence of operations implementing the message-passing scheme
-on graph `g` with convolution layer `mp`.
+of the GNN layer `mp` on graph `g`.
 Updates the node, edge, and global features `X`, `E`, and `U` respectively.

 The computation involved is the following:
@@ -14,7 +15,7 @@
 M = compute_batch_message(mp, g, X, E, U)
-M̄ = aggregate_neighbors(mp, aggr, g, M)
+M̄ = aggregate_neighbors(mp, g, aggr, M)
 X′ = update(mp, X, M̄, U)
-E′ = update_edge(mp, M, E, U)
+E′ = update_edge(mp, E, M, U)
 U′ = update_global(mp, U, X′, E′)
 ```

@@ -25,7 +26,7 @@
 this method in the forward pass:
 ```julia
 function (l::MyLayer)(g, X)
     ... some prepocessing if needed ...
-    propagate(l, g, X, E, U, +)
+    propagate(l, g, +, X, E, U)
 end
 ```
@@ -34,20 +35,20 @@ See also [`message`](@ref) and [`update`](@ref).
 """
 function propagate end

 function propagate(mp, g::GNNGraph, aggr)
-    X, E, U = propagate(mp, g,
-                        node_features(g), edge_features(g), global_features(g),
-                        aggr)
-    GNNGraph(g, ndata=X, edata=E, gdata=U)
+    X, E, U = propagate(mp, g, aggr,
+                        node_features(g), edge_features(g), global_features(g))
+
+    return GNNGraph(g, ndata=X, edata=E, gdata=U)
 end

-function propagate(mp, g::GNNGraph, E, X, U, aggr)
+function propagate(mp, g::GNNGraph, aggr, X, E=nothing, U=nothing)
     # TODO consider g.graph_indicator in propagating U
-    M = compute_batch_message(mp, g, E, X, U)
-    E = update_edge(mp, M, E, U)
-    M̄ = aggregate_neighbors(mp, aggr, g, M)
-    X = update(mp, M̄, X, U)
-    U = update_global(mp, E, X, U)
-    return E, X, U
+    M = compute_batch_message(mp, g, X, E, U)
+    M̄ = aggregate_neighbors(mp, g, aggr, M)
+    X′ = update(mp, X, M̄, U)
+    E′ = update_edge(mp, E, M, U)
+    U′ = update_global(mp, U, X′, E′)
+    return X′, E′, U′
 end

@@ -55,7 +56,7 @@
     message(mp, x_i, x_j, [e_ij, u])

 Message function for the message-passing scheme,
 returning the message from node `j` to node `i` .
-In the message-passing scheme. the incoming messages
+In the message-passing scheme, the incoming messages
 from the neighborhood of `i` will later be aggregated
 in order to [`update`](@ref) the features of node `i`.
@@ -75,7 +76,7 @@ See also [`update`](@ref) and [`propagate`](@ref).
 function message end

 """
-    update(mp, m̄, x, [u])
+    update(mp, x, m̄, [u])

 Update function for the message-passing scheme,
 returning a new set of node features `x′` based on old
@@ -96,47 +97,45 @@ See also [`message`](@ref) and [`propagate`](@ref).
 """
 function update end
-

 _gather(x, i) = NNlib.gather(x, i)
 _gather(x::Nothing, i) = nothing

 ## Step 1.

-function compute_batch_message(mp, g, E, X, u)
+function compute_batch_message(mp, g, X, E, U)
     s, t = edge_index(g)
     Xi = _gather(X, t)
     Xj = _gather(X, s)
-    M = message(mp, Xi, Xj, E, u)
+    M = message(mp, Xi, Xj, E, U)
     return M
 end

-# @inline message(mp, i, j, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij, u) # TODO add in the future
 @inline message(mp, x_i, x_j, e_ij, u) = message(mp, x_i, x_j, e_ij)
 @inline message(mp, x_i, x_j, e_ij) = message(mp, x_i, x_j)
 @inline message(mp, x_i, x_j) = x_j

-## Step 2
-
-@inline update_edge(mp, M, E, u) = update_edge(mp, M, E)
-@inline update_edge(mp, M, E) = E
-
-## Step 3
+## Step 2

-function aggregate_neighbors(mp, aggr, g, E)
+function aggregate_neighbors(mp, g, aggr, E)
     s, t = edge_index(g)
     NNlib.scatter(aggr, E, t)
 end

-aggregate_neighbors(mp, aggr::Nothing, g, E) = nothing
+aggregate_neighbors(mp, g, aggr::Nothing, E) = nothing
+
+## Step 3
+
+@inline update(mp, x, m̄, u) = update(mp, x, m̄)
+@inline update(mp, x, m̄) = m̄

 ## Step 4

-# @inline update(mp, i, m̄, x, u) = update(mp, m, x, u)
-@inline update(mp, m̄, x, u) = update(mp, m̄, x)
-@inline update(mp, m̄, x) = m̄
+@inline update_edge(mp, E, M, U) = update_edge(mp, E, M)
+@inline update_edge(mp, E, M) = E

 ## Step 5

-@inline update_global(mp, E, X, u) = u
+@inline update_global(mp, U, X, E) = update_global(mp, U, X)
+@inline update_global(mp, U, X) = U

 ### end steps ###
diff --git a/test/cuda/msgpass.jl b/test/cuda/msgpass.jl
index 30e323a82..601e6cea0 100644
--- a/test/cuda/msgpass.jl
+++ b/test/cuda/msgpass.jl
@@ -1,29 +1,29 @@
-in_channel = 10
-out_channel = 5
-N = 6
-T = Float32
-adj = [0 1 0 0 0 0
-       1 0 0 1 1 1
-       0 0 0 0 0 1
-       0 1 0 0 1 0
-       0 1 0 1 0 1
-       0 1 1 0 1 0]
+@testset "cuda/msgpass" begin
+    in_channel = 10
+    out_channel = 5
+    N = 6
+    T = Float32
+    adj = [0 1 0 0 0 0
+           1 0 0 1 1 1
+           0 0 0 0 0 1
+           0 1 0 0 1 0
+           0 1 0 1 0 1
+           0 1 1 0 1 0]

-struct NewCudaLayer
-    weight
-end
-NewCudaLayer(m, n) = NewCudaLayer(randn(T, m,n))
-@functor NewCudaLayer
+    struct NewCudaLayer
+        weight
+    end
+    NewCudaLayer(m, n) = NewCudaLayer(randn(T, m, n))
+    @functor NewCudaLayer

-(l::NewCudaLayer)(X) = GraphNeuralNetworks.propagate(l, X, +)
-GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
-GraphNeuralNetworks.update(::NewCudaLayer, m, x) = m
+    (l::NewCudaLayer)(g) = GraphNeuralNetworks.propagate(l, g, +)
+    GraphNeuralNetworks.message(n::NewCudaLayer, x_i, x_j, e_ij) = n.weight * x_j
+    GraphNeuralNetworks.update(::NewCudaLayer, x, m) = m

-X = rand(T, in_channel, N) |> gpu
-g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-l = NewCudaLayer(out_channel, in_channel) |> gpu
+    X = rand(T, in_channel, N) |> gpu
+    g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
+    l = NewCudaLayer(out_channel, in_channel) |> gpu

-@testset "cuda/msgpass" begin
     g_ = l(g)
     @test size(node_features(g_)) == (out_channel, N)
 end
diff --git a/test/msgpass.jl b/test/msgpass.jl
index d6dea5e05..cf6db009d 100644
--- a/test/msgpass.jl
+++ b/test/msgpass.jl
@@ -16,7 +16,7 @@

     X = rand(T, in_channel, num_V)
     E = rand(T, in_channel, num_E)
-    u = rand(T, in_channel)
+    U = rand(T, in_channel)


     @testset "no aggregation" begin
@@ -24,70 +24,70 @@
         (l::NewLayer{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, nothing)

         g = GNNGraph(adj, ndata=X, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g_ = l(g)

-        @test adjacency_matrix(fg_) == adj
-        @test node_features(fg_) === nothing
-        @test edge_features(fg_) === nothing
-        @test global_features(fg_) === nothing
+        @test adjacency_matrix(g_) == adj
+        @test node_features(g_) === nothing
+        @test edge_features(g_) === nothing
+        @test global_features(g_) === nothing
     end

     @testset "neighbor aggregation (+)" begin
         l = NewLayer{GRAPH_T}()
         (l::NewLayer{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

-        g = GNNGraph(adj, ndata=X, edata=E, gdata=u, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+        g_ = l(g)

-        @test adjacency_matrix(fg_) == adj
-        @test size(node_features(fg_)) == (in_channel, num_V)
-        @test edge_features(fg_) ≈ E
-        @test global_features(fg_) ≈ u
+        @test adjacency_matrix(g_) == adj
+        @test size(node_features(g_)) == (in_channel, num_V)
+        @test edge_features(g_) ≈ E
+        @test global_features(g_) ≈ U
     end

-    GraphNeuralNetworks.message(l::NewLayer{GRAPH_T}, xi, xj, e, u) = ones(T, out_channel, size(e,2))
+    GraphNeuralNetworks.message(l::NewLayer{GRAPH_T}, xi, xj, e, U) = ones(T, out_channel, size(e,2))

     @testset "custom message and neighbor aggregation" begin
         l = NewLayer{GRAPH_T}()
         (l::NewLayer{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

-        g = GNNGraph(adj, ndata=X, edata=E, gdata=u, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+        g_ = l(g)

-        @test adjacency_matrix(fg_) == adj
-        @test size(node_features(fg_)) == (out_channel, num_V)
-        @test edge_features(fg_) ≈ edge_features(g)
-        @test global_features(fg_) ≈ global_features(g)
+        @test adjacency_matrix(g_) == adj
+        @test size(node_features(g_)) == (out_channel, num_V)
+        @test edge_features(g_) ≈ edge_features(g)
+        @test global_features(g_) ≈ global_features(g)
     end

-    GraphNeuralNetworks.update_edge(l::NewLayer{GRAPH_T}, m, e) = m
+    GraphNeuralNetworks.update_edge(l::NewLayer{GRAPH_T}, e, m) = m

     @testset "update_edge" begin
         l = NewLayer{GRAPH_T}()
         (l::NewLayer{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

-        g = GNNGraph(adj, ndata=X, edata=E, gdata=u, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+        g_ = l(g)

-        @test adjacency_matrix(fg_) == adj
-        @test size(node_features(fg_)) == (out_channel, num_V)
-        @test size(edge_features(fg_)) == (out_channel, num_E)
-        @test global_features(fg_) ≈ global_features(g)
+        @test adjacency_matrix(g_) == adj
+        @test size(node_features(g_)) == (out_channel, num_V)
+        @test size(edge_features(g_)) == (out_channel, num_E)
+        @test global_features(g_) ≈ global_features(g)
     end

-    GraphNeuralNetworks.update(l::NewLayer{GRAPH_T}, m̄, xi, u) = rand(T, 2*out_channel, size(xi, 2))
+    GraphNeuralNetworks.update(l::NewLayer{GRAPH_T}, xi, m̄, U) = rand(T, 2*out_channel, size(xi, 2))

     @testset "update edge/vertex" begin
         l = NewLayer{GRAPH_T}()
         (l::NewLayer{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

-        g = GNNGraph(adj, ndata=X, edata=E, gdata=u, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+        g_ = l(g)

-        @test all(adjacency_matrix(fg_) .== adj)
-        @test size(node_features(fg_)) == (2*out_channel, num_V)
-        @test size(edge_features(fg_)) == (out_channel, num_E)
-        @test size(global_features(fg_)) == (in_channel,)
+        @test all(adjacency_matrix(g_) .== adj)
+        @test size(node_features(g_)) == (2*out_channel, num_V)
+        @test size(edge_features(g_)) == (out_channel, num_E)
+        @test size(global_features(g_)) == (in_channel,)
     end

     struct NewLayerW{G}
@@ -97,18 +97,18 @@
     NewLayerW(in, out) = NewLayerW{GRAPH_T}(randn(T, out, in))

     GraphNeuralNetworks.message(l::NewLayerW{GRAPH_T}, x_i, x_j, e_ij) = l.weight * x_j
-    GraphNeuralNetworks.update(l::NewLayerW{GRAPH_T}, m, x) = l.weight * x + m
+    GraphNeuralNetworks.update(l::NewLayerW{GRAPH_T}, x, m) = l.weight * x + m

     @testset "message and update with weights" begin
         l = NewLayerW(in_channel, out_channel)
         (l::NewLayerW{GRAPH_T})(g) = GraphNeuralNetworks.propagate(l, g, +)

-        g = GNNGraph(adj, ndata=X, edata=E, gdata=u, graph_type=GRAPH_T)
-        fg_ = l(g)
+        g = GNNGraph(adj, ndata=X, edata=E, gdata=U, graph_type=GRAPH_T)
+        g_ = l(g)

-        @test adjacency_matrix(fg_) == adj
-        @test size(node_features(fg_)) == (out_channel, num_V)
-        @test edge_features(fg_) === E
-        @test global_features(fg_) === u
+        @test adjacency_matrix(g_) == adj
+        @test size(node_features(g_)) == (out_channel, num_V)
+        @test edge_features(g_) === E
+        @test global_features(g_) === U
     end
 end
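Below is a minimal usage sketch (not part of the patch) of the new argument ordering introduced here, written against the propagate/message/update API shown in the diff above; the layer name `MySumConv` and its field are hypothetical:

```julia
using Flux, GraphNeuralNetworks

# Hypothetical layer: sums neighbor features and applies a dense transform.
struct MySumConv
    dense::Dense
end
Flux.@functor MySumConv

# New ordering: message receives (x_i, x_j), update receives (x, m̄).
GraphNeuralNetworks.message(l::MySumConv, x_i, x_j) = x_j
GraphNeuralNetworks.update(l::MySumConv, x, m̄) = l.dense(m̄)

# New ordering in propagate: the aggregator follows the graph, features come last,
# and the updated node features are returned first.
function (l::MySumConv)(g::GNNGraph, x::AbstractMatrix)
    x′, _ = propagate(l, g, +, x)
    return x′
end

# g = GNNGraph(adj, ndata=rand(Float32, 4, 10))
# y = MySumConv(Dense(4, 8))(g, node_features(g))   # 8×10 node feature matrix
```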