Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [GNNLux] Adding TransformerConv Layer #501

Draft
wants to merge 16 commits into
base: master
Choose a base branch
from
6 changes: 3 additions & 3 deletions GNNLux/src/GNNLux.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,9 @@ export AGNNConv,
NNConv,
ResGatedGraphConv,
# SAGEConv,
SGConv
SGConv,
# TAGConv,
# TransformerConv
TransformerConv

include("layers/temporalconv.jl")
export TGCN,
Expand All @@ -49,4 +49,4 @@ export TGCN,
EvolveGCNO

end #module


117 changes: 117 additions & 0 deletions GNNLux/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -844,3 +844,120 @@ function Base.show(io::IO, l::ResGatedGraphConv)
l.use_bias || print(io, ", use_bias=false")
print(io, ")")
end

@concrete struct TransformerConv <: GNNContainerLayer{(:W1, :W2, :W3, :W4, :W5, :W6, :FF, :BN1, :BN2)}
in_dims::NTuple{2, Int}
out_dims::Int
heads::Int
add_self_loops::Bool
concat::Bool
skip_connection::Bool
sqrt_out::Float32
W1
W2
W3
W4
W5
W6
FF
BN1
BN2
end

function TransformerConv(ch::Pair{Int, Int}, args...; kws...)
return TransformerConv((ch[1], 0) => ch[2], args...; kws...)
end

function TransformerConv(ch::Pair{NTuple{2, Int}, Int};
heads::Int = 1,
concat::Bool = true,
init_weight = glorot_uniform,
init_bias = zeros32,
add_self_loops::Bool = false,
bias_qkv = true,
bias_root::Bool = true,
root_weight::Bool = true,
gating::Bool = false,
skip_connection::Bool = false,
batch_norm::Bool = false,
ff_channels::Int = 0)
(in, ein), out = ch

if add_self_loops
@assert iszero(ein) "Using edge features and setting add_self_loops=true at the same time is not yet supported."
end

if skip_connection
@assert in == (concat ? out * heads : out) "In-channels must correspond to out-channels * heads (or just out_channels if concat=false) if skip_connection is used"
end

W1 = root_weight ? Dense(in => out * (concat ? heads : 1); use_bias=bias_root, init_weight, init_bias) : nothing
W2 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
W3 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
W4 = Dense(in => out * heads; use_bias=bias_qkv, init_weight, init_bias)
out_mha = out * (concat ? heads : 1)
W5 = gating ? Dense(3 * out_mha => 1, sigmoid; use_bias=false, init_weight, init_bias) : nothing
W6 = ein > 0 ? Dense(ein => out * heads; use_bias=bias_qkv, init_weight, init_bias) : nothing
FF = ff_channels > 0 ?
Chain(Dense(out_mha => ff_channels, relu; init_weight, init_bias),
Dense(ff_channels => out_mha; init_weight, init_bias)) : nothing
BN1 = batch_norm ? BatchNorm(out_mha) : nothing
BN2 = (batch_norm && ff_channels > 0) ? BatchNorm(out_mha) : nothing

return TransformerConv((in, ein), out, heads, add_self_loops, concat,
skip_connection, Float32(√out), W1, W2, W3, W4, W5, W6, FF, BN1, BN2)
end

LuxCore.outputsize(l::TransformerConv) = (l.concat ? l.out_dims * l.heads : l.out_dims,)

function (l::TransformerConv)(g, x, ps, st)
return l(g, x, nothing, ps, st)
end

function (l::TransformerConv)(g, x, e, ps, st)
W1 = l.W1 === nothing ? nothing :
StatefulLuxLayer{true}(l.W1, ps.W1, _getstate(st, :W1))
W2 = StatefulLuxLayer{true}(l.W2, ps.W2, _getstate(st, :W2))
W3 = StatefulLuxLayer{true}(l.W3, ps.W3, _getstate(st, :W3))
W4 = StatefulLuxLayer{true}(l.W4, ps.W4, _getstate(st, :W4))
W5 = l.W5 === nothing ? nothing :
StatefulLuxLayer{true}(l.W5, ps.W5, _getstate(st, :W5))
W6 = l.W6 === nothing ? nothing :
StatefulLuxLayer{true}(l.W6, ps.W6, _getstate(st, :W6))
FF = l.FF === nothing ? nothing :
StatefulLuxLayer{true}(l.FF, ps.FF, _getstate(st, :FF))
BN1 = l.BN1 === nothing ? nothing :
StatefulLuxLayer{true}(l.BN1, ps.BN1, _getstate(st, :BN1))
BN2 = l.BN2 === nothing ? nothing :
StatefulLuxLayer{true}(l.BN2, ps.BN2, _getstate(st, :BN2))
m = (; W1, W2, W3, W4, W5, W6, FF, BN1, BN2, l.sqrt_out,
l.heads, l.concat, l.skip_connection, l.add_self_loops, l.in_dims, l.out_dims)
return GNNlib.transformer_conv(m, g, x, e), st
end

function LuxCore.parameterlength(l::TransformerConv)
n = parameterlength(l.W2) + parameterlength(l.W3) + parameterlength(l.W4)
n += l.W1 === nothing ? 0 : parameterlength(l.W1)
n += l.W5 === nothing ? 0 : parameterlength(l.W5)
n += l.W6 === nothing ? 0 : parameterlength(l.W6)
n += l.FF === nothing ? 0 : parameterlength(l.FF)
n += l.BN1 === nothing ? 0 : parameterlength(l.BN1)
n += l.BN2 === nothing ? 0 : parameterlength(l.BN2)
return n
end

function LuxCore.statelength(l::TransformerConv)
n = statelength(l.W2) + statelength(l.W3) + statelength(l.W4)
n += l.W1 === nothing ? 0 : statelength(l.W1)
n += l.W5 === nothing ? 0 : statelength(l.W5)
n += l.W6 === nothing ? 0 : statelength(l.W6)
n += l.FF === nothing ? 0 : statelength(l.FF)
n += l.BN1 === nothing ? 0 : statelength(l.BN1)
n += l.BN2 === nothing ? 0 : statelength(l.BN2)
return n
end

function Base.show(io::IO, l::TransformerConv)
(in, ein), out = (l.in_dims, l.out_dims)
print(io, "TransformerConv(($in, $ein) => $out, heads=$(l.heads))")
end
12 changes: 12 additions & 0 deletions GNNLux/test/layers/conv_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,18 @@
out_dims = 5
x = randn(rng, Float32, in_dims, 10)

@testset "TransformerConv" begin
x = randn(rng, Float32, 6, 10)
ein = 2
e = randn(rng, Float32, ein, g.num_edges)

l = TransformerConv((6, ein) => 8, heads = 2, gating = true, bias_qkv = true)
test_lux_layer(rng, l, g, x, outputsize = (16,), e = e, container = true)

# l = TransformerConv((16, ein) => 16, heads = 2, concat = false, skip_connection = true)
# test_lux_layer(rng, l, g, x, outputsize = (16,), e = e, container = true)
end

@testset "GCNConv" begin
l = GCNConv(in_dims => out_dims, tanh)
test_lux_layer(rng, l, g, x, outputsize=(out_dims,))
Expand Down
2 changes: 1 addition & 1 deletion GNNlib/src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -559,7 +559,7 @@ function transformer_conv(l, g::GNNGraph, x::AbstractMatrix, e::Union{AbstractM
g = add_self_loops(g)
end

out = l.channels[2]
out = l.out_dims
heads = l.heads
W1x = !isnothing(l.W1) ? l.W1(x) : nothing
W2x = reshape(l.W2(x), out, heads, :)
Expand Down
Loading