diff --git a/GNNLux/src/GNNLux.jl b/GNNLux/src/GNNLux.jl index 16d129632..0dc2802fa 100644 --- a/GNNLux/src/GNNLux.jl +++ b/GNNLux/src/GNNLux.jl @@ -1,15 +1,22 @@ module GNNLux using ConcreteStructs: @concrete using NNlib: NNlib -using LuxCore: LuxCore, AbstractExplicitLayer -using Lux: glorot_uniform, zeros32 +using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer +using Lux: Lux, glorot_uniform, zeros32 using Reexport: @reexport using Random: AbstractRNG using GNNlib: GNNlib @reexport using GNNGraphs +include("layers/basic.jl") +export GNNLayer, + GNNContainerLayer, + GNNChain + include("layers/conv.jl") -export GraphConv +export GCNConv, + ChebConv, + GraphConv end #module \ No newline at end of file diff --git a/GNNLux/src/layers/basic.jl b/GNNLux/src/layers/basic.jl new file mode 100644 index 000000000..32f33bdbb --- /dev/null +++ b/GNNLux/src/layers/basic.jl @@ -0,0 +1,61 @@ +""" + abstract type GNNLayer <: AbstractExplicitLayer end + +An abstract type from which graph neural network layers are derived. +It is derived from Lux's `AbstractExplicitLayer` type. + +See also [`GNNChain`](@ref GNNLux.GNNChain). +""" +abstract type GNNLayer <: AbstractExplicitLayer end + +abstract type GNNContainerLayer{T} <: AbstractExplicitContainerLayer{T} end + +@concrete struct GNNChain <: GNNContainerLayer{(:layers,)} + layers <: NamedTuple +end + +GNNChain(xs...) = GNNChain(; (Symbol("layer_", i) => x for (i, x) in enumerate(xs))...) + +function GNNChain(; kw...) 
+ :layers in Base.keys(kw) && + throw(ArgumentError("a GNNChain cannot have a named layer called `layers`")) + nt = NamedTuple{keys(kw)}(values(kw)) + nt = map(_wrapforchain, nt) + return GNNChain(nt) +end + +_wrapforchain(l::AbstractExplicitLayer) = l +_wrapforchain(l) = Lux.WrappedFunction(l) + +Base.keys(c::GNNChain) = Base.keys(getfield(c, :layers)) +Base.getindex(c::GNNChain, i::Int) = c.layers[i] +Base.getindex(c::GNNChain, i::AbstractVector) = GNNChain(NamedTuple{keys(c)[i]}(Tuple(c.layers)[i])) + +function Base.getproperty(c::GNNChain, name::Symbol) + hasfield(typeof(c), name) && return getfield(c, name) + layers = getfield(c, :layers) + hasfield(typeof(layers), name) && return getfield(layers, name) + throw(ArgumentError("$(typeof(c)) has no field or layer $name")) +end + +Base.length(c::GNNChain) = length(c.layers) +Base.lastindex(c::GNNChain) = lastindex(c.layers) +Base.firstindex(c::GNNChain) = firstindex(c.layers) + +LuxCore.outputsize(c::GNNChain) = LuxCore.outputsize(c.layers[end]) + +(c::GNNChain)(g::GNNGraph, x, ps, st) = _applychain(c.layers, g, x, ps, st) + +function _applychain(layers, g::GNNGraph, x, ps, st) # type-unstable path, helps compile times + newst = (;) + for (name, l) in pairs(layers) + x, s′ = _applylayer(l, g, x, getproperty(ps, name), getproperty(st, name)) + newst = merge(newst, (; name => s′)) + end + return x, newst +end + +_applylayer(l, g::GNNGraph, x, ps, st) = l(x), (;) +_applylayer(l::AbstractExplicitLayer, g::GNNGraph, x, ps, st) = l(x, ps, st) +_applylayer(l::GNNLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) +_applylayer(l::GNNContainerLayer, g::GNNGraph, x, ps, st) = l(g, x, ps, st) diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index 671e55d62..54627522e 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -1,54 +1,132 @@ +# Missing Layers + +# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs | +# | :-------- | :---: |:---: |:---: | :---: 
| :---: | +# | [`AGNNConv`](@ref) | | | ✓ | | | +# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ | +# | [`EGNNConv`](@ref) | | | ✓ | | | +# | [`EdgeConv`](@ref) | | | | ✓ | | +# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ | +# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ | +# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ | +# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ | +# | [`GMMConv`](@ref) | | | ✓ | | | +# | [`MEGNetConv`](@ref) | | | ✓ | | | +# | [`NNConv`](@ref) | | | ✓ | | | +# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ | +# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ | +# | [`SGConv`](@ref) | ✓ | | | | ✓ | +# | [`TransformerConv`](@ref) | | | ✓ | | | + + +@concrete struct GCNConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + add_self_loops::Bool + use_edge_weight::Bool + init_weight + init_bias + σ +end -@doc raw""" - GraphConv(in => out, σ=identity; aggr=+, bias=true, init=glorot_uniform) +function GCNConv(ch::Pair{Int, Int}, σ = identity; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + add_self_loops::Bool = true, + use_edge_weight::Bool = false, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? NNlib.fast_act(σ) : σ + return GCNConv(in_dims, out_dims, use_bias, add_self_loops, use_edge_weight, init_weight, init_bias, σ) +end -Graph convolution layer from Reference: [Weisfeiler and Leman Go Neural: Higher-order Graph Neural Networks](https://arxiv.org/abs/1810.02244). +function LuxCore.initialparameters(rng::AbstractRNG, l::GCNConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end -Performs: -```math -\mathbf{x}_i' = W_1 \mathbf{x}_i + \square_{j \in \mathcal{N}(i)} W_2 \mathbf{x}_j -``` +LuxCore.parameterlength(l::GCNConv) = l.use_bias ? 
l.in_dims * l.out_dims + l.out_dims : l.in_dims * l.out_dims +LuxCore.statelength(d::GCNConv) = 0 +LuxCore.outputsize(d::GCNConv) = (d.out_dims,) -where the aggregation type is selected by `aggr`. +function Base.show(io::IO, l::GCNConv) + print(io, "GCNConv(", l.in_dims, " => ", l.out_dims) + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + l.add_self_loops || print(io, ", add_self_loops=false") + !l.use_edge_weight || print(io, ", use_edge_weight=true") + print(io, ")") +end -# Arguments +# TODO norm_fn should be keyword argument only +(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) = + l(g, x, edge_weight, norm_fn, ps, st; conv_weight) +(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) = + l(g, x, edge_weight, norm_fn, ps, st; conv_weight) +(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) = + GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st -- `in`: The dimension of input features. -- `out`: The dimension of output features. -- `σ`: Activation function. -- `aggr`: Aggregation operator for the incoming messages (e.g. `+`, `*`, `max`, `min`, and `mean`). -- `bias`: Add learnable bias. -- `init`: Weights' initializer. +@concrete struct ChebConv <: GNNLayer + in_dims::Int + out_dims::Int + use_bias::Bool + k::Int + init_weight + init_bias + σ +end -# Examples +function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity; + init_weight = glorot_uniform, + init_bias = zeros32, + use_bias::Bool = true, + allow_fast_activation::Bool = true) + in_dims, out_dims = ch + σ = allow_fast_activation ? 
NNlib.fast_act(σ) : σ + return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ) +end -```julia -# create data -s = [1,1,2,3] -t = [2,3,1,1] -in_channel = 3 -out_channel = 5 -g = GNNGraph(s, t) -x = randn(Float32, 3, g.num_nodes) +function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv) + weight = l.init_weight(rng, l.out_dims, l.in_dims, l.k) + if l.use_bias + bias = l.init_bias(rng, l.out_dims) + return (; weight, bias) + else + return (; weight) + end +end + +LuxCore.parameterlength(l::ChebConv) = l.use_bias ? l.in_dims * l.out_dims * l.k + l.out_dims : + l.in_dims * l.out_dims * l.k +LuxCore.statelength(d::ChebConv) = 0 +LuxCore.outputsize(d::ChebConv) = (d.out_dims,) + +function Base.show(io::IO, l::ChebConv) + print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.k) + l.σ == identity || print(io, ", ", l.σ) + l.use_bias || print(io, ", use_bias=false") + print(io, ")") +end -# create layer -l = GraphConv(in_channel => out_channel, relu, bias = false, aggr = mean) +(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st -# forward pass -y = l(g, x) -``` -""" -@concrete struct GraphConv <: AbstractExplicitLayer +@concrete struct GraphConv <: GNNLayer in_dims::Int out_dims::Int use_bias::Bool - init_weight::Function - init_bias::Function + init_weight + init_bias σ aggr end - function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +, init_weight = glorot_uniform, @@ -65,10 +143,10 @@ function LuxCore.initialparameters(rng::AbstractRNG, l::GraphConv) weight2 = l.init_weight(rng, l.out_dims, l.in_dims) if l.use_bias bias = l.init_bias(rng, l.out_dims) + return (; weight1, weight2, bias) else - bias = false + return (; weight1, weight2) end - return (; weight1, weight2, bias) end function LuxCore.parameterlength(l::GraphConv) @@ -90,4 +168,4 @@ function Base.show(io::IO, l::GraphConv) print(io, ")") end -(l::GraphConv)(g::GNNGraph, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st +(l::GraphConv)(g, x, ps, st) = 
GNNlib.graph_conv(l, g, x, ps), st diff --git a/GNNLux/test/layers/basic_tests.jl b/GNNLux/test/layers/basic_tests.jl new file mode 100644 index 000000000..11a1d3a29 --- /dev/null +++ b/GNNLux/test/layers/basic_tests.jl @@ -0,0 +1,24 @@ +@testitem "layers/basic" setup=[SharedTestSetup] begin + rng = StableRNG(17) + g = rand_graph(10, 40, seed=17) + x = randn(rng, Float32, 3, 10) + + @testset "GNNLayer" begin + @test GNNLayer <: LuxCore.AbstractExplicitLayer + end + + @testset "GNNChain" begin + @test GNNChain <: LuxCore.AbstractExplicitContainerLayer{(:layers,)} + @test GNNChain <: GNNContainerLayer + c = GNNChain(GraphConv(3 => 5, relu), GCNConv(5 => 3)) + ps = LuxCore.initialparameters(rng, c) + st = LuxCore.initialstates(rng, c) + @test LuxCore.parameterlength(c) == LuxCore.parameterlength(ps) + @test LuxCore.statelength(c) == LuxCore.statelength(st) + y, st′ = c(g, x, ps, st) + @test LuxCore.outputsize(c) == (3,) + @test size(y) == (3, 10) + loss = (x, ps) -> sum(first(c(g, x, ps, st))) + @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + end +end diff --git a/GNNLux/test/layers/conv_tests.jl b/GNNLux/test/layers/conv_tests.jl index 962188aff..9483ee822 100644 --- a/GNNLux/test/layers/conv_tests.jl +++ b/GNNLux/test/layers/conv_tests.jl @@ -1,10 +1,41 @@ @testitem "layers/conv" setup=[SharedTestSetup] begin rng = StableRNG(1234) - g = rand_graph(10, 30, seed=1234) + g = rand_graph(10, 40, seed=1234) x = randn(rng, Float32, 3, 10) + @testset "GCNConv" begin + l = GCNConv(3 => 5, relu) + @test l isa GNNLayer + ps = Lux.initialparameters(rng, l) + st = Lux.initialstates(rng, l) + @test Lux.parameterlength(l) == Lux.parameterlength(ps) + @test Lux.statelength(l) == Lux.statelength(st) + + y, _ = l(g, x, ps, st) + @test Lux.outputsize(l) == (5,) + @test size(y) == (5, 10) + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true + end + + @testset 
"ChebConv" begin + l = ChebConv(3 => 5, 2, relu) + @test l isa GNNLayer + ps = Lux.initialparameters(rng, l) + st = Lux.initialstates(rng, l) + @test Lux.parameterlength(l) == Lux.parameterlength(ps) + @test Lux.statelength(l) == Lux.statelength(st) + + y, _ = l(g, x, ps, st) + @test Lux.outputsize(l) == (5,) + @test size(y) == (5, 10) + loss = (x, ps) -> sum(first(l(g, x, ps, st))) + @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true + end + @testset "GraphConv" begin l = GraphConv(3 => 5, relu) + @test l isa GNNLayer ps = Lux.initialparameters(rng, l) st = Lux.initialstates(rng, l) @test Lux.parameterlength(l) == Lux.parameterlength(ps) @@ -14,6 +45,6 @@ @test Lux.outputsize(l) == (5,) @test size(y) == (5, 10) loss = (x, ps) -> sum(first(l(g, x, ps, st))) - @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 + @eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true end end diff --git a/GNNlib/src/layers/conv.jl b/GNNlib/src/layers/conv.jl index a0a3ccf6a..f3fbac984 100644 --- a/GNNlib/src/layers/conv.jl +++ b/GNNlib/src/layers/conv.jl @@ -10,12 +10,17 @@ end check_gcnconv_input(g::AbstractGNNGraph, edge_weight::Nothing) = nothing -function gcn_conv(l, g::AbstractGNNGraph, x, - edge_weight::EW = nothing, - norm_fn::Function = d -> 1 ./ sqrt.(d) - ) where {EW <: Union{Nothing, AbstractVector}} - +function gcn_conv(l, g::AbstractGNNGraph, x, edge_weight::EW, norm_fn::F, conv_weight::CW, ps) where + {EW <: Union{Nothing, AbstractVector}, CW<:Union{Nothing,AbstractMatrix}, F} check_gcnconv_input(g, edge_weight) + if conv_weight === nothing + weight = ps.weight + else + weight = conv_weight + if size(weight) != size(ps.weight) + throw(ArgumentError("The weight matrix has the wrong size. 
Expected $(size(ps.weight)) but got $(size(weight))")) + end + end if l.add_self_loops g = add_self_loops(g) @@ -26,11 +31,11 @@ function gcn_conv(l, g::AbstractGNNGraph, x, @assert length(edge_weight) == g.num_edges end end - Dout, Din = size(l.weight) + Dout, Din = size(weight) if Dout < Din && !(g isa GNNHeteroGraph) # multiply before convolution if it is more convenient, otherwise multiply after # (this works only for homogenous graph) - x = l.weight * x + x = weight * x end xj, xi = expand_srcdst(g, x) # expand only after potential multiplication @@ -60,34 +65,38 @@ function gcn_conv(l, g::AbstractGNNGraph, x, end x = x .* cin' if Dout >= Din || g isa GNNHeteroGraph - x = l.weight * x + x = weight * x end - return l.σ.(x .+ l.bias) + if l.use_bias + x = x .+ ps.bias + end + return l.σ.(x) end -function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, - edge_weight::AbstractVector, norm_fn::Function) - +# when we also have edge_weight we need to convert the graph to COO +function gcn_conv(l, g::GNNGraph{<:ADJMAT_T}, x::AbstractMatrix, edge_weight::AbstractVector, norm_fn::F, conv_weight::CW, ps) where {CW<:Union{Nothing,AbstractMatrix}, F} g = GNNGraph(edge_index(g)...; g.num_nodes) # convert to COO - return gcn_conv(l, g, x, edge_weight, norm_fn) + return gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps) end - -function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}) where {T} +function cheb_conv(c, g::GNNGraph, X::AbstractMatrix{T}, ps) where {T} check_num_nodes(g, X) - @assert size(X, 1)==size(c.weight, 2) "Input feature size must match input channel size." + @assert size(X, 1) == size(ps.weight, 2) "Input feature size must match input channel size." 
L̃ = scaled_laplacian(g, eltype(X)) Z_prev = X Z = X * L̃ - Y = view(c.weight, :, :, 1) * Z_prev - Y += view(c.weight, :, :, 2) * Z + Y = view(ps.weight, :, :, 1) * Z_prev + Y = Y .+ view(ps.weight, :, :, 2) * Z for k in 3:(c.k) Z, Z_prev = 2 * Z * L̃ - Z_prev, Z - Y += view(c.weight, :, :, k) * Z + Y = Y .+ view(ps.weight, :, :, k) * Z + end + if c.use_bias + Y = Y .+ ps.bias end - return Y .+ c.bias + return Y end function graph_conv(l, g::AbstractGNNGraph, x, ps) diff --git a/docs/src/dev.md b/docs/src/dev.md index 2f9e9d83e..2a2aae370 100644 --- a/docs/src/dev.md +++ b/docs/src/dev.md @@ -2,7 +2,7 @@ ## Develop and Managing the Monorepo -### Development +### Development Environment GraphNeuralNetworks.jl is package hosted in a monorepo that contains multiple packages. The GraphNeuralNetworks.jl package depends on GNNGraphs.jl, also hosted in the same monorepo. @@ -12,6 +12,14 @@ pkg> activate . pkg> dev ./GNNGraphs ``` +### Add a New Layer + +To add a new graph convolutional layer and make it available in both the Flux-based frontend (GraphNeuralNetworks.jl) and the Lux-based frontend (GNNLux), you need to: +1. Add the functional version to GNNlib +2. Add the stateful version to GraphNeuralNetworks +3. Add the stateless version to GNNLux +4. Add the layer to the table in docs/api/conv.md + ### Versions and Tagging Each PR should update the version number in the Porject.toml file of each involved package if needed by semnatic versioning. For instance, when adding new features GNNGraphs could move from "1.17.5" to "1.18.0-DEV". The "DEV" will be removed when the package is tagged and released. Pay also attention to updating the compat bounds, e.g. GraphNeuralNetworks might require a newer version of GNNGraphs. 
diff --git a/test/layers/conv.jl b/test/layers/conv.jl index 9e3a3f424..f6846010a 100644 --- a/test/layers/conv.jl +++ b/test/layers/conv.jl @@ -73,7 +73,6 @@ test_graphs = [g1, g_single_vertex] a = rand(T, in_channel, N) g2 = GNNGraph(adj1, ndata = a) @test l(g2, g2.ndata.x, conv_weight = w) == w * a - end end