[GNNLux] more layers (#463)
* move to MLDataDevices

* cg_conv

* edgeconv working

* cleanup
CarloLucibello authored Jul 28, 2024
1 parent 3b42087 commit 80c672a
Showing 12 changed files with 221 additions and 81 deletions.
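
The first bullet of the commit message replaces LuxDeviceUtils with MLDataDevices for device handling. Below is a minimal sketch of the resulting device-transfer pattern, mirroring the test changes in this commit; CUDADevice() only works on a machine with a functional CUDA setup, and gpu_device() (imported in the tests) would select the right device automatically:

using GNNGraphs
using MLDataDevices: cpu_device, gpu_device, CUDADevice

g = rand_graph(10, 30; ndata = rand(Float32, 5, 10))  # random graph with node features
dev = CUDADevice()             # the tests construct this directly; gpu_device() would pick it when available
g_gpu = g |> dev               # moves node/edge/graph data to the GPU
g_cpu = g_gpu |> cpu_device()  # and back to the CPU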
4 changes: 2 additions & 2 deletions GNNGraphs/Project.toml
@@ -10,7 +10,7 @@ Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LuxDeviceUtils = "34f89e08-e1d5-43b4-8944-0b49ac560553"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce"
@@ -35,7 +35,7 @@ Functors = "0.4.1"
Graphs = "1.4"
KrylovKit = "0.8"
LinearAlgebra = "1"
LuxDeviceUtils = "0.1.24"
MLDataDevices = "1.0"
MLDatasets = "0.7"
MLUtils = "0.4"
NNlib = "0.9"
2 changes: 1 addition & 1 deletion GNNGraphs/src/GNNGraphs.jl
@@ -14,7 +14,7 @@ using LinearAlgebra, Random, Statistics
import MLUtils
using MLUtils: getobs, numobs, ones_like, zeros_like, chunk, batch, rand_like
import Functors
using LuxDeviceUtils: get_device, cpu_device, LuxCPUDevice
using MLDataDevices: get_device, cpu_device, CPUDevice

include("chainrules.jl") # hacks for differentiability

4 changes: 2 additions & 2 deletions GNNGraphs/test/gnngraph.jl
@@ -57,7 +57,7 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
dev = LuxCUDADevice() #TODO replace with gpu_device()
dev = CUDADevice()
g_gpu = g |> dev
end

@@ -141,7 +141,7 @@ end
# core functionality
g = GNNGraph(s, t; graph_type = GRAPH_T)
if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
end

4 changes: 2 additions & 2 deletions GNNGraphs/test/query.jl
@@ -59,7 +59,7 @@ end
@test eltype(degree(g, Float32)) == Float32

if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
@@ -87,7 +87,7 @@ end
@test degree(g, edge_weight = 2 * eweight) ≈ [4.4, 2.4, 2.0, 0.0] broken = (GRAPH_T != :coo)

if TEST_GPU
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
g_gpu = g |> dev
d = degree(g)
d_gpu = degree(g_gpu)
4 changes: 2 additions & 2 deletions GNNGraphs/test/runtests.jl
@@ -13,8 +13,8 @@ using Test
using MLDatasets
using InlineStrings # not used but with the import we test #98 and #104
using SimpleWeightedGraphs
using LuxDeviceUtils: gpu_device, cpu_device, get_device
using LuxDeviceUtils: LuxCUDADevice # remove after https://github.com/LuxDL/LuxDeviceUtils.jl/pull/58
using MLDataDevices: gpu_device, cpu_device, get_device
using MLDataDevices: CUDADevice

CUDA.allowscalar(false)

2 changes: 1 addition & 1 deletion GNNGraphs/test/temporalsnapshotsgnngraph.jl
@@ -107,7 +107,7 @@ if TEST_GPU
snapshots = [rand_graph(10, 20; ndata = rand(5,10)) for i in 1:5]
tsg = TemporalSnapshotsGNNGraph(snapshots)
tsg.tgdata.x = rand(5)
dev = LuxCUDADevice() #TODO replace with `gpu_device()`
dev = CUDADevice() #TODO replace with `gpu_device()`
tsg = tsg |> dev
@test tsg.snapshots[1].ndata.x isa CuArray
@test tsg.snapshots[end].ndata.x isa CuArray
3 changes: 2 additions & 1 deletion GNNLux/Project.toml
@@ -26,10 +26,11 @@ julia = "1.10"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
LuxTestUtils = "ac9de150-d08f-4546-94fb-7472b5760531"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
ReTestItems = "817f1d60-ba6b-4fd5-9520-3cf149f6a823"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
test = ["Test", "MLDataDevices", "ComponentArrays", "Functors", "LuxTestUtils", "ReTestItems", "StableRNGs", "Zygote"]
24 changes: 21 additions & 3 deletions GNNLux/src/GNNLux.jl
@@ -1,8 +1,8 @@
module GNNLux
using ConcreteStructs: @concrete
using NNlib: NNlib
using NNlib: NNlib, sigmoid, relu
using LuxCore: LuxCore, AbstractExplicitLayer, AbstractExplicitContainerLayer
using Lux: Lux, glorot_uniform, zeros32
using Lux: Lux, Dense, glorot_uniform, zeros32, StatefulLuxLayer
using Reexport: @reexport
using Random: AbstractRNG
using GNNlib: GNNlib
@@ -14,9 +14,27 @@ export GNNLayer,
GNNChain

include("layers/conv.jl")
export GCNConv,
export AGNNConv,
CGConv,
ChebConv,
EdgeConv,
# EGNNConv,
# DConv,
# GATConv,
# GATv2Conv,
# GatedGraphConv,
GCNConv,
# GINConv,
# GMMConv,
GraphConv
# MEGNetConv,
# NNConv,
# ResGatedGraphConv,
# SAGEConv,
# SGConv,
# TAGConv,
# TransformerConv


end #module
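
With this change GNNLux exports only the layers already ported to Lux; the commented-out names mark ports still to come. A minimal sketch of what the new export list makes available, with constructor forms taken from the tests added in this commit:

using GNNLux
# exported and usable after this commit: AGNNConv, CGConv, ChebConv, EdgeConv, GCNConv, GraphConv
# still commented out above (not yet ported): EGNNConv, DConv, GATConv, GATv2Conv, GatedGraphConv,
#   GINConv, GMMConv, MEGNetConv, NNConv, ResGatedGraphConv, SAGEConv, SGConv, TAGConv, TransformerConv
l1 = ChebConv(3 => 5, 2)               # Chebyshev polynomial order k = 2
l2 = CGConv(3 => 5, residual = true)   # crystal graph conv with a residual connection
l3 = AGNNConv(init_beta = 1.0f0)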

156 changes: 121 additions & 35 deletions GNNLux/src/layers/conv.jl
@@ -1,22 +1,7 @@
# Missing Layers

# | Layer |Sparse Ops|Edge Weight|Edge Features| Heterograph | TemporalSnapshotsGNNGraphs |
# | :-------- | :---: |:---: |:---: | :---: | :---: |
# | [`AGNNConv`](@ref) | | | ✓ | | |
# | [`CGConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`EGNNConv`](@ref) | | | ✓ | | |
# | [`EdgeConv`](@ref) | | | | ✓ | |
# | [`GATConv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GATv2Conv`](@ref) | | | ✓ | ✓ | ✓ |
# | [`GatedGraphConv`](@ref) | ✓ | | | | ✓ |
# | [`GINConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`GMMConv`](@ref) | | | ✓ | | |
# | [`MEGNetConv`](@ref) | | | ✓ | | |
# | [`NNConv`](@ref) | | | ✓ | | |
# | [`ResGatedGraphConv`](@ref) | | | | ✓ | ✓ |
# | [`SAGEConv`](@ref) | ✓ | | | ✓ | ✓ |
# | [`SGConv`](@ref) | ✓ | | | | ✓ |
# | [`TransformerConv`](@ref) | | | ✓ | | |
_getbias(ps) = hasproperty(ps, :bias) ? getproperty(ps, :bias) : false
_getstate(st, name) = hasproperty(st, name) ? getproperty(st, name) : NamedTuple()
_getstate(s::StatefulLuxLayer{true}) = s.st
_getstate(s::StatefulLuxLayer{false}) = s.st_any


@concrete struct GCNConv <: GNNLayer
@@ -65,13 +50,18 @@ function Base.show(io::IO, l::GCNConv)
print(io, ")")
end

# TODO norm_fn should be keyword argument only
(l::GCNConv)(g, x, ps, st; conv_weight=nothing, edge_weight=nothing, norm_fn= d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, ps, st; conv_weight=nothing, norm_fn = d -> 1 ./ sqrt.(d)) =
l(g, x, edge_weight, norm_fn, ps, st; conv_weight)
(l::GCNConv)(g, x, edge_weight, norm_fn, ps, st; conv_weight=nothing) =
GNNlib.gcn_conv(l, g, x, edge_weight, norm_fn, conv_weight, ps), st
l(g, x, edge_weight, ps, st; conv_weight, norm_fn)

function (l::GCNConv)(g, x, edge_weight, ps, st;
norm_fn = d -> 1 ./ sqrt.(d),
conv_weight=nothing, )

m = (; ps.weight, bias = _getbias(ps),
l.add_self_loops, l.use_edge_weight, l.σ)
y = GNNlib.gcn_conv(m, g, x, edge_weight, norm_fn, conv_weight)
return y, st
end

@concrete struct ChebConv <: GNNLayer
in_dims::Int
@@ -80,17 +70,14 @@ end
k::Int
init_weight
init_bias
σ
end

function ChebConv(ch::Pair{Int, Int}, k::Int, σ = identity;
function ChebConv(ch::Pair{Int, Int}, k::Int;
init_weight = glorot_uniform,
init_bias = zeros32,
use_bias::Bool = true,
allow_fast_activation::Bool = true)
use_bias::Bool = true)
in_dims, out_dims = ch
σ = allow_fast_activation ? NNlib.fast_act(σ) : σ
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias, σ)
return ChebConv(in_dims, out_dims, use_bias, k, init_weight, init_bias)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::ChebConv)
@@ -109,13 +96,17 @@ LuxCore.statelength(d::ChebConv) = 0
LuxCore.outputsize(d::ChebConv) = (d.out_dims,)

function Base.show(io::IO, l::ChebConv)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", K=", l.K)
l.σ == identity || print(io, ", ", l.σ)
print(io, "ChebConv(", l.in_dims, " => ", l.out_dims, ", k=", l.k)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

(l::ChebConv)(g, x, ps, st) = GNNlib.cheb_conv(l, g, x, ps), st
function (l::ChebConv)(g, x, ps, st)
m = (; ps.weight, bias = _getbias(ps), l.k)
y = GNNlib.cheb_conv(m, g, x)
return y, st

end

@concrete struct GraphConv <: GNNLayer
in_dims::Int
@@ -168,4 +159,99 @@ function Base.show(io::IO, l::GraphConv)
print(io, ")")
end

(l::GraphConv)(g, x, ps, st) = GNNlib.graph_conv(l, g, x, ps), st
function (l::GraphConv)(g, x, ps, st)
m = (; ps.weight1, ps.weight2, bias = _getbias(ps),
l.σ, l.aggr)
return GNNlib.graph_conv(m, g, x), st
end


@concrete struct AGNNConv <: GNNLayer
init_beta <: AbstractVector
add_self_loops::Bool
trainable::Bool
end

function AGNNConv(; init_beta = 1.0f0, add_self_loops = true, trainable = true)
return AGNNConv([init_beta], add_self_loops, trainable)
end

function LuxCore.initialparameters(rng::AbstractRNG, l::AGNNConv)
if l.trainable
return (; β = l.init_beta)
else
return (;)
end
end

LuxCore.parameterlength(l::AGNNConv) = l.trainable ? 1 : 0
LuxCore.statelength(d::AGNNConv) = 0

function Base.show(io::IO, l::AGNNConv)
print(io, "AGNNConv(", l.init_beta)
l.add_self_loops || print(io, ", add_self_loops=false")
l.trainable || print(io, ", trainable=false")
print(io, ")")
end

function (l::AGNNConv)(g, x::AbstractMatrix, ps, st)
β = l.trainable ? ps.β : l.init_beta
m = (; β, l.add_self_loops)
return GNNlib.agnn_conv(m, g, x), st
end

@concrete struct CGConv <: GNNContainerLayer{(:dense_f, :dense_s)}
in_dims::NTuple{2, Int}
out_dims::Int
dense_f
dense_s
residual::Bool
init_weight
init_bias
end

CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)

function CGConv(ch::Pair{NTuple{2, Int}, Int}, act = identity; residual = false,
use_bias = true, init_weight = glorot_uniform, init_bias = zeros32,
allow_fast_activation = true)
(nin, ein), out = ch
dense_f = Dense(2nin + ein => out, sigmoid; use_bias, init_weight, init_bias, allow_fast_activation)
dense_s = Dense(2nin + ein => out, act; use_bias, init_weight, init_bias, allow_fast_activation)
return CGConv((nin, ein), out, dense_f, dense_s, residual, init_weight, init_bias)
end

LuxCore.outputsize(l::CGConv) = (l.out_dims,)

(l::CGConv)(g, x, ps, st) = l(g, x, nothing, ps, st)

function (l::CGConv)(g, x, e, ps, st)
dense_f = StatefulLuxLayer{true}(l.dense_f, ps.dense_f, _getstate(st, :dense_f))
dense_s = StatefulLuxLayer{true}(l.dense_s, ps.dense_s, _getstate(st, :dense_s))
m = (; dense_f, dense_s, l.residual)
return GNNlib.cg_conv(m, g, x, e), st
end

@concrete struct EdgeConv <: GNNContainerLayer{(:nn,)}
nn <: AbstractExplicitLayer
aggr
end

EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)

function Base.show(io::IO, l::EdgeConv)
print(io, "EdgeConv(", l.nn)
print(io, ", aggr=", l.aggr)
print(io, ")")
end


function (l::EdgeConv)(g::AbstractGNNGraph, x, ps, st)
nn = StatefulLuxLayer{true}(l.nn, ps, st)
m = (; nn, l.aggr)
y = GNNlib.edge_conv(m, g, x)
stnew = _getstate(nn)
return y, stnew
end
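
The refactor above moves each layer from forwarding the layer struct itself to GNNlib, to building a plain NamedTuple `m` of parameters and hyperparameters and passing it to the corresponding GNNlib function, with StatefulLuxLayer wrapping any inner Lux layers. A minimal end-to-end sketch of the resulting explicit-parameter call convention, assuming GNNGraphs' rand_graph and the feature sizes used in the tests below:

using GNNLux, Lux, Random
using GNNGraphs: rand_graph

rng = Random.default_rng()
g = rand_graph(10, 30)             # 10 nodes, 30 edges
x = rand(Float32, 3, 10)           # 3 features per node

l = CGConv(3 => 5)
ps = Lux.initialparameters(rng, l) # explicit parameters, kept separate from the layer struct
st = Lux.initialstates(rng, l)
y, st = l(g, x, ps, st)            # returns the 5×10 node-feature output and the (possibly updated) state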


45 changes: 44 additions & 1 deletion GNNLux/test/layers/conv_tests.jl
@@ -19,7 +19,7 @@
end

@testset "ChebConv" begin
l = ChebConv(3 => 5, 2, relu)
l = ChebConv(3 => 5, 2)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@@ -47,4 +47,47 @@
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end

@testset "AGNNConv" begin
l = AGNNConv(init_beta=1.0f0)
@test l isa GNNLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(ps) == 1
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)

y, _ = l(g, x, ps, st)
@test size(y) == size(x)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "EdgeConv" begin
nn = Chain(Dense(6 => 5, relu), Dense(5 => 5))
l = EdgeConv(nn, aggr = +)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true skip_reverse_diff=true
end

@testset "CGConv" begin
l = CGConv(3 => 5, residual = true)
@test l isa GNNContainerLayer
ps = Lux.initialparameters(rng, l)
st = Lux.initialstates(rng, l)
@test Lux.parameterlength(l) == Lux.parameterlength(ps)
@test Lux.statelength(l) == Lux.statelength(st)
y, st′ = l(g, x, ps, st)
@test size(y) == (5, 10)
@test Lux.outputsize(l) == (5,)
loss = (x, ps) -> sum(first(l(g, x, ps, st)))
@eval @test_gradients $loss $x $ps atol=1.0f-3 rtol=1.0f-3 skip_tracker=true
end
end