Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

@functor ->@layer #484

Merged
merged 2 commits into from
Aug 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion GNNlib/src/msgpass.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ struct GNNConv <: GNNLayer
σ
end

Flux.@functor GNNConv
Flux.@layer GNNConv

function GNNConv(ch::Pair{Int,Int}, σ=identity)
in, out = ch
Expand Down
2 changes: 1 addition & 1 deletion docs/src/messagepassing.md
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ struct GCN{A<:AbstractMatrix, B, F} <: GNNLayer
σ::F
end

Flux.@functor GCN # allow gpu movement, select trainable params etc...
Flux.@layer GCN # allow gpu movement, select trainable params etc...

function GCN(ch::Pair{Int,Int}, σ=identity)
in, out = ch
Expand Down
4 changes: 2 additions & 2 deletions docs/src/models.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ and the *implicit modeling* style based on [`GNNChain`](@ref), more concise but
In the explicit modeling style, the model is created according to the following steps:

1. Define a new type for your model (`GNN` in the example below). Layers and submodels are fields.
2. Apply `Flux.@functor` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...)
2. Apply `Flux.@layer` to the new type to make it Flux's compatible (parameters' collection, gpu movement, etc...)
3. Optionally define a convenience constructor for your model.
4. Define the forward pass by implementing the call method for your type.
5. Instantiate the model.
Expand All @@ -30,7 +30,7 @@ struct GNN # step 1
dense
end

Flux.@functor GNN # step 2
Flux.@layer GNN # step 2

function GNN(din::Int, d::Int, dout::Int) # step 3
GNN(GCNConv(din => d),
Expand Down
2 changes: 1 addition & 1 deletion docs/tutorials/introductory_tutorials/gnn_intro_pluto.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ begin
layers::NamedTuple
end

Flux.@functor GCN # provides parameter collection, gpu movement and more
Flux.@layer GCN # provides parameter collection, gpu movement and more

function GCN(num_features, num_classes)
layers = (conv1 = GCNConv(num_features => 4),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,7 @@ begin
layers::NamedTuple
end

Flux.@functor MLP
Flux.@layer :expand MLP

function MLP(num_features, num_classes, hidden_channels; drop_rate = 0.5)
layers = (hidden = Dense(num_features => hidden_channels),
Expand Down Expand Up @@ -235,7 +235,7 @@ begin
layers::NamedTuple
end

Flux.@functor GCN # provides parameter collection, gpu movement and more
Flux.@layer GCN # provides parameter collection, gpu movement and more

function GCN(num_features, num_classes, hidden_channels; drop_rate = 0.5)
layers = (conv1 = GCNConv(num_features => hidden_channels),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ begin
dense::Dense
end

Flux.@functor GenderPredictionModel
Flux.@layer GenderPredictionModel

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
Expand Down
2 changes: 1 addition & 1 deletion examples/graph_classification_temporalbrains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ struct GenderPredictionModel
dense::Dense
end

Flux.@functor GenderPredictionModel
Flux.@layer GenderPredictionModel

function GenderPredictionModel(; nfeatures = 103, nhidden = 128, activation = relu)
mlp = Chain(Dense(nfeatures, nhidden, activation), Dense(nhidden, nhidden, activation))
Expand Down
2 changes: 1 addition & 1 deletion notebooks/gnn_intro.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -354,7 +354,7 @@
" layers::NamedTuple\n",
"end\n",
"\n",
"Flux.@functor GCN # provides parameter collection, gpu movement and more\n",
"Flux.@layer :expand GCN # provides parameter collection, gpu movement and more\n",
"\n",
"function GCN(num_features, num_classes)\n",
" layers = (conv1 = GCNConv(num_features => 4),\n",
Expand Down
2 changes: 1 addition & 1 deletion notebooks/graph_classification_solved.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -857,7 +857,7 @@
"\tact::F\n",
"end\n",
"\n",
"Flux.@functor MyConv\n",
"Flux.@layer MyConv\n",
"\n",
"function MyConv((nin, nout)::Pair, act=identity)\n",
"\tW1 = Flux.glorot_uniform(nout, nin)\n",
Expand Down
2 changes: 1 addition & 1 deletion src/GraphNeuralNetworks.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ module GraphNeuralNetworks
using Statistics: mean
using LinearAlgebra, Random
using Flux
using Flux: glorot_uniform, leakyrelu, GRUCell, @functor, batch
using Flux: glorot_uniform, leakyrelu, GRUCell, batch
using MacroTools: @forward
using NNlib
using NNlib: scatter, gather
Expand Down
4 changes: 2 additions & 2 deletions src/layers/basic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ end

WithGraph(model, g::GNNGraph; traingraph = false) = WithGraph(model, g, traingraph)

@functor WithGraph
Flux.@layer :expand WithGraph
Flux.trainable(l::WithGraph) = l.traingraph ? (; l.model, l.g) : (; l.model)

(l::WithGraph)(g::GNNGraph, x...; kws...) = l.model(g, x...; kws...)
Expand Down Expand Up @@ -107,7 +107,7 @@ struct GNNChain{T <: Union{Tuple, NamedTuple, AbstractVector}} <: GNNLayer
layers::T
end

@functor GNNChain
Flux.@layer :expand GNNChain

GNNChain(xs...) = GNNChain(xs)

Expand Down
51 changes: 25 additions & 26 deletions src/layers/conv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ struct GCNConv{W <: AbstractMatrix, B, F} <: GNNLayer
use_edge_weight::Bool
end

@functor GCNConv
Flux.@layer GCNConv

function GCNConv(ch::Pair{Int, Int}, σ = identity;
init = glorot_uniform,
Expand Down Expand Up @@ -167,7 +167,7 @@ function ChebConv(ch::Pair{Int, Int}, k::Int;
ChebConv(W, b, k)
end

@functor ChebConv
Flux.@layer ChebConv

(l::ChebConv)(g, x) = GNNlib.cheb_conv(l, g, x)

Expand Down Expand Up @@ -225,7 +225,7 @@ struct GraphConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
aggr::A
end

@functor GraphConv
Flux.@layer GraphConv

function GraphConv(ch::Pair{Int, Int}, σ = identity; aggr = +,
init = glorot_uniform, bias::Bool = true)
Expand Down Expand Up @@ -300,8 +300,7 @@ l = GATConv(in_channel => out_channel, add_self_loops = false, bias = false; hea
y = l(g, x)
```
"""
struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMatrix, F, B} <:
GNNLayer
struct GATConv{DX<:Dense,DE<:Union{Dense, Nothing},DV,T,A<:AbstractMatrix,F,B} <: GNNLayer
dense_x::DX
dense_e::DE
bias::B
Expand All @@ -315,8 +314,8 @@ struct GATConv{DX <: Dense, DE <: Union{Dense, Nothing}, DV, T, A <: AbstractMat
dropout::DV
end

@functor GATConv
Flux.trainable(l::GATConv) = (dense_x = l.dense_x, dense_e = l.dense_e, bias = l.bias, a = l.a)
Flux.@layer GATConv
Flux.trainable(l::GATConv) = (; l.dense_x, l.dense_e, l.bias, l.a)

GATConv(ch::Pair{Int, Int}, args...; kws...) = GATConv((ch[1], 0) => ch[2], args...; kws...)

Expand Down Expand Up @@ -420,7 +419,7 @@ struct GATv2Conv{T, A1, A2, A3, DV, B, C <: AbstractMatrix, F} <: GNNLayer
dropout::DV
end

@functor GATv2Conv
Flux.@layer GATv2Conv
Flux.trainable(l::GATv2Conv) = (dense_i = l.dense_i, dense_j = l.dense_j, dense_e = l.dense_e, bias = l.bias, a = l.a)

function GATv2Conv(ch::Pair{Int, Int}, args...; kws...)
Expand Down Expand Up @@ -515,7 +514,7 @@ struct GatedGraphConv{W <: AbstractArray{<:Number, 3}, R, A} <: GNNLayer
aggr::A
end

@functor GatedGraphConv
Flux.@layer GatedGraphConv

function GatedGraphConv(dims::Int, num_layers::Int;
aggr = +, init = glorot_uniform)
Expand Down Expand Up @@ -572,7 +571,7 @@ struct EdgeConv{NN, A} <: GNNLayer
aggr::A
end

@functor EdgeConv
Flux.@layer :expand EdgeConv

EdgeConv(nn; aggr = max) = EdgeConv(nn, aggr)

Expand Down Expand Up @@ -626,7 +625,7 @@ struct GINConv{R <: Real, NN, A} <: GNNLayer
aggr::A
end

@functor GINConv
Flux.@layer :expand GINConv
Flux.trainable(l::GINConv) = (nn = l.nn,)

GINConv(nn, ϵ; aggr = +) = GINConv(nn, ϵ, aggr)
Expand Down Expand Up @@ -680,7 +679,7 @@ edim = 10
g = GNNGraph(s, t)

# create dense layer
nn = Dense(edim, out_channel * in_channel)
nn = Dense(edim => out_channel * in_channel)

# create layer
l = NNConv(in_channel => out_channel, nn, tanh, bias = true, aggr = +)
Expand All @@ -697,7 +696,7 @@ struct NNConv{W, B, NN, F, A} <: GNNLayer
aggr::A
end

@functor NNConv
Flux.@layer :expand NNConv

function NNConv(ch::Pair{Int, Int}, nn, σ = identity; aggr = +, bias = true,
init = glorot_uniform)
Expand Down Expand Up @@ -763,7 +762,7 @@ struct SAGEConv{W <: AbstractMatrix, B, F, A} <: GNNLayer
aggr::A
end

@functor SAGEConv
Flux.@layer SAGEConv

function SAGEConv(ch::Pair{Int, Int}, σ = identity; aggr = mean,
init = glorot_uniform, bias::Bool = true)
Expand Down Expand Up @@ -833,7 +832,7 @@ struct ResGatedGraphConv{W, B, F} <: GNNLayer
σ::F
end

@functor ResGatedGraphConv
Flux.@layer ResGatedGraphConv

function ResGatedGraphConv(ch::Pair{Int, Int}, σ = identity;
init = glorot_uniform, bias::Bool = true)
Expand Down Expand Up @@ -907,7 +906,7 @@ struct CGConv{D1, D2} <: GNNLayer
residual::Bool
end

@functor CGConv
Flux.@layer CGConv

CGConv(ch::Pair{Int, Int}, args...; kws...) = CGConv((ch[1], 0) => ch[2], args...; kws...)

Expand Down Expand Up @@ -980,7 +979,7 @@ struct AGNNConv{A <: AbstractVector} <: GNNLayer
trainable::Bool
end

@functor AGNNConv
Flux.@layer AGNNConv

Flux.trainable(l::AGNNConv) = l.trainable ? (; l.β) : (;)

Expand Down Expand Up @@ -1027,7 +1026,7 @@ struct MEGNetConv{TE, TV, A} <: GNNLayer
aggr::A
end

@functor MEGNetConv
Flux.@layer :expand MEGNetConv

MEGNetConv(ϕe, ϕv; aggr = mean) = MEGNetConv(ϕe, ϕv, aggr)

Expand Down Expand Up @@ -1108,7 +1107,7 @@ struct GMMConv{A <: AbstractMatrix, B, F} <: GNNLayer
residual::Bool
end

@functor GMMConv
Flux.@layer GMMConv

function GMMConv(ch::Pair{NTuple{2, Int}, Int},
σ = identity;
Expand Down Expand Up @@ -1191,7 +1190,7 @@ struct SGConv{A <: AbstractMatrix, B} <: GNNLayer
use_edge_weight::Bool
end

@functor SGConv
Flux.@layer SGConv

function SGConv(ch::Pair{Int, Int}, k = 1;
init = glorot_uniform,
Expand Down Expand Up @@ -1259,7 +1258,7 @@ struct TAGConv{A <: AbstractMatrix, B} <: GNNLayer
use_edge_weight::Bool
end

@functor TAGConv
Flux.@layer TAGConv

function TAGConv(ch::Pair{Int, Int}, k = 3;
init = glorot_uniform,
Expand All @@ -1269,7 +1268,7 @@ function TAGConv(ch::Pair{Int, Int}, k = 3;
in, out = ch
W = init(out, in)
b = bias ? Flux.create_bias(W, true, out) : false
TAGConv(W, b, k, add_self_loops, use_edge_weight)
return TAGConv(W, b, k, add_self_loops, use_edge_weight)
end

(l::TAGConv)(g, x, edge_weight = nothing) = GNNlib.tag_conv(l, g, x, edge_weight)
Expand Down Expand Up @@ -1343,10 +1342,10 @@ struct EGNNConv{TE, TX, TH, NF} <: GNNLayer
residual::Bool
end

@functor EGNNConv
Flux.@layer EGNNConv

function EGNNConv(ch::Pair{Int, Int}, hidden_size = 2 * ch[1]; residual = false)
EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
return EGNNConv((ch[1], 0) => ch[2]; hidden_size, residual)
end

#Follows reference implementation at https://github.com/vgsatorras/egnn/blob/main/models/egnn_clean/egnn_clean.py
Expand Down Expand Up @@ -1477,7 +1476,7 @@ struct TransformerConv{TW1, TW2, TW3, TW4, TW5, TW6, TFF, TBN1, TBN2} <: GNNLaye
sqrt_out::Float32
end

@functor TransformerConv
Flux.@layer TransformerConv

function Flux.trainable(l::TransformerConv)
(; l.W1, l.W2, l.W3, l.W4, l.W5, l.W6, l.FF, l.BN1, l.BN2)
Expand Down Expand Up @@ -1568,7 +1567,7 @@ struct DConv <: GNNLayer
k::Int
end

@functor DConv
Flux.@layer DConv

function DConv(ch::Pair{Int, Int}, k::Int; init = glorot_uniform, bias = true)
in, out = ch
Expand Down
2 changes: 1 addition & 1 deletion src/layers/heteroconv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ struct HeteroGraphConv
aggr::Function
end

Flux.@functor HeteroGraphConv
Flux.@layer HeteroGraphConv

HeteroGraphConv(itr::Dict; aggr = +) = HeteroGraphConv(pairs(itr); aggr)
HeteroGraphConv(itr::Pair...; aggr = +) = HeteroGraphConv(itr; aggr)
Expand Down
4 changes: 2 additions & 2 deletions src/layers/pool.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@ struct GlobalAttentionPool{G, F}
ffeat::F
end

@functor GlobalAttentionPool
Flux.@layer GlobalAttentionPool

GlobalAttentionPool(fgate) = GlobalAttentionPool(fgate, identity)

Expand Down Expand Up @@ -146,7 +146,7 @@ struct Set2Set{L} <: GNNLayer
num_iters::Int
end

@functor Set2Set
Flux.@layer Set2Set

function Set2Set(n_in::Int, n_iters::Int, n_layers::Int = 1)
@assert n_layers >= 1
Expand Down
Loading
Loading