diff --git a/src/GraphNeuralNetworks.jl b/src/GraphNeuralNetworks.jl
index ae4440f22..2f2f94839 100644
--- a/src/GraphNeuralNetworks.jl
+++ b/src/GraphNeuralNetworks.jl
@@ -73,6 +73,7 @@ export
     # layers/temporalconv
     TGCN,
     A3TGCN,
+    GConvGRU,
 
     # layers/pool
     GlobalPool,
diff --git a/src/layers/temporalconv.jl b/src/layers/temporalconv.jl
index 628c3adbe..3717374c6 100644
--- a/src/layers/temporalconv.jl
+++ b/src/layers/temporalconv.jl
@@ -187,6 +187,98 @@ function Base.show(io::IO, a3tgcn::A3TGCN)
     print(io, "A3TGCN($(a3tgcn.in) => $(a3tgcn.out))")
 end
 
+struct GConvGRUCell <: GNNLayer
+    conv_x_r::ChebConv
+    conv_h_r::ChebConv
+    conv_x_z::ChebConv
+    conv_h_z::ChebConv
+    conv_x_h::ChebConv
+    conv_h_h::ChebConv
+    k::Int
+    state0
+    in::Int
+    out::Int
+end
+
+Flux.@functor GConvGRUCell
+
+function GConvGRUCell(ch::Pair{Int, Int}, k::Int, n::Int;
+                      bias::Bool = true,
+                      init = Flux.glorot_uniform,
+                      init_state = Flux.zeros32)
+    in, out = ch
+    # reset gate
+    conv_x_r = ChebConv(in => out, k; bias, init)
+    conv_h_r = ChebConv(out => out, k; bias, init)
+    # update gate
+    conv_x_z = ChebConv(in => out, k; bias, init)
+    conv_h_z = ChebConv(out => out, k; bias, init)
+    # new gate
+    conv_x_h = ChebConv(in => out, k; bias, init)
+    conv_h_h = ChebConv(out => out, k; bias, init)
+    state0 = init_state(out, n)
+    return GConvGRUCell(conv_x_r, conv_h_r, conv_x_z, conv_h_z, conv_x_h, conv_h_h, k, state0, in, out)
+end
+
+function (ggru::GConvGRUCell)(h, g::GNNGraph, x)
+    r = ggru.conv_x_r(g, x) .+ ggru.conv_h_r(g, h)
+    r = Flux.sigmoid_fast(r)
+    z = ggru.conv_x_z(g, x) .+ ggru.conv_h_z(g, h)
+    z = Flux.sigmoid_fast(z)
+    h̃ = ggru.conv_x_h(g, x) .+ ggru.conv_h_h(g, r .* h)
+    h̃ = Flux.tanh_fast(h̃)
+    h = (1 .- z) .* h̃ .+ z .* h
+    return h, h
+end
+
+function Base.show(io::IO, ggru::GConvGRUCell)
+    print(io, "GConvGRUCell($(ggru.in) => $(ggru.out))")
+end
+
+"""
+    GConvGRU(in => out, k, n; [bias, init, init_state])
+
+Graph Convolutional Gated Recurrent Unit (GConvGRU) recurrent layer from the paper [Structured Sequence Modeling with Graph Convolutional Recurrent Networks](https://arxiv.org/pdf/1612.07659).
+
+Performs a layer of ChebConv to model spatial dependencies, followed by a Gated Recurrent Unit (GRU) cell to model temporal dependencies.
+
+# Arguments
+
+- `in`: Number of input features.
+- `out`: Number of output features.
+- `k`: Chebyshev polynomial order.
+- `n`: Number of nodes in the graph.
+- `bias`: Add learnable bias. Default `true`.
+- `init`: Weights' initializer. Default `glorot_uniform`.
+- `init_state`: Initial state of the hidden state of the GRU layer. Default `zeros32`.
+
+# Examples
+
+```jldoctest
+julia> g1, x1 = rand_graph(5, 10), rand(Float32, 2, 5);
+
+julia> ggru = GConvGRU(2 => 5, 2, g1.num_nodes);
+
+julia> y = ggru(g1, x1);
+
+julia> size(y)
+(5, 5)
+
+julia> g2, x2 = rand_graph(5, 10), rand(Float32, 2, 5, 30);
+
+julia> z = ggru(g2, x2);
+
+julia> size(z)
+(5, 5, 30)
+```
+"""
+GConvGRU(ch, k, n; kwargs...) = Flux.Recur(GConvGRUCell(ch, k, n; kwargs...))
+Flux.Recur(ggru::GConvGRUCell) = Flux.Recur(ggru, ggru.state0)
+
+(l::Flux.Recur{GConvGRUCell})(g::GNNGraph) = GNNGraph(g, ndata = l(g, node_features(g)))
+_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph, x) = l(g, x)
+_applylayer(l::Flux.Recur{GConvGRUCell}, g::GNNGraph) = l(g)
+
 function (l::GINConv)(tg::TemporalSnapshotsGNNGraph, x::AbstractVector)
     return l.(tg.snapshots, x)
 end
diff --git a/test/layers/temporalconv.jl b/test/layers/temporalconv.jl
index 28d5cf863..a947a7e41 100644
--- a/test/layers/temporalconv.jl
+++ b/test/layers/temporalconv.jl
@@ -34,6 +34,20 @@ end
     @test model(g1) isa GNNGraph
 end
 
+@testset "GConvGRUCell" begin
+    gconvgru = GraphNeuralNetworks.GConvGRUCell(in_channel => out_channel, 2, g1.num_nodes)
+    _, h = gconvgru(gconvgru.state0, g1, g1.ndata.x)
+    @test size(h) == (out_channel, N)
+end
+
+@testset "GConvGRU" begin
+    gconvgru = GConvGRU(in_channel => out_channel, 2, g1.num_nodes)
+    @test size(Flux.gradient(x -> sum(gconvgru(g1, x)), g1.ndata.x)[1]) == (in_channel, N)
+    model = GNNChain(GConvGRU(in_channel => out_channel, 2, g1.num_nodes), Dense(out_channel, 1))
+    @test size(model(g1, g1.ndata.x)) == (1, N)
+    @test model(g1) isa GNNGraph
+end
+
 @testset "GINConv" begin
     ginconv = GINConv(Dense(in_channel => out_channel),0.3)
     @test length(ginconv(tg, tg.ndata.x)) == S