diff --git a/GNNLux/src/layers/conv.jl b/GNNLux/src/layers/conv.jl index f0b51066b..2c0b0fa99 100644 --- a/GNNLux/src/layers/conv.jl +++ b/GNNLux/src/layers/conv.jl @@ -5,7 +5,80 @@ _getstate(s::StatefulLuxLayer{Static.True}) = s.st _getstate(s::StatefulLuxLayer{false}) = s.st_any _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any - +@doc raw""" + GCNConv(in => out, σ=identity; [init_weight, init_bias, use_bias, add_self_loops, use_edge_weight]) + +Graph convolutional layer from paper [Semi-supervised Classification with Graph Convolutional Networks](https://arxiv.org/abs/1609.02907). + +Performs the operation +```math +\mathbf{x}'_i = \sum_{j\in N(i)} a_{ij} W \mathbf{x}_j +``` +where ``a_{ij} = 1 / \sqrt{|N(i)||N(j)|}`` is a normalization factor computed from the node degrees. + +If the input graph has weighted edges and `use_edge_weight=true`, than ``a_{ij}`` will be computed as +```math +a_{ij} = \frac{e_{j\to i}}{\sqrt{\sum_{j \in N(i)} e_{j\to i}} \sqrt{\sum_{i \in N(j)} e_{i\to j}}} +``` + +# Arguments + +- `in`: Number of input features. +- `out`: Number of output features. +- `σ`: Activation function. Default `identity`. +- `init_weight`: Weights' initializer. Default `glorot_uniform`. +- `init_bias`: Bias initializer. Default `zeros32`. +- `use_bias`: Add learnable bias. Default `true`. +- `add_self_loops`: Add self loops to the graph before performing the convolution. Default `false`. +- `use_edge_weight`: If `true`, consider the edge weights in the input graph (if available). + If `add_self_loops=true` the new weights will be set to 1. + This option is ignored if the `edge_weight` is explicitly provided in the forward pass. + Default `false`. + +# Forward + + (::GCNConv)(g, x, [edge_weight], ps, st; norm_fn = d -> 1 ./ sqrt.(d), conv_weight=nothing) + +Takes as input a graph `g`, a node feature matrix `x` of size `[in, num_nodes]`, optionally an edge weight vector and the parameter and state of the layer. Returns a node feature matrix of size +`[out, num_nodes]`. + +The `norm_fn` parameter allows for custom normalization of the graph convolution operation by passing a function as argument. +By default, it computes ``\frac{1}{\sqrt{d}}`` i.e the inverse square root of the degree (`d`) of each node in the graph. +If `conv_weight` is an `AbstractMatrix` of size `[out, in]`, then the convolution is performed using that weight matrix. + +# Examples + +```julia +using GNNLux, Lux, Random +# initialize random number generator +rng = Random.default_rng() +Random.seed!(rng, 0) +# create data +s = [1,1,2,3] +t = [2,3,1,1] +g = GNNGraph(s, t) +x = randn(Float32, 3, g.num_nodes) + +# create layer +l = GCNConv(3 => 5) + +# setup layer +ps, st = LuxCore.setup(rng, l) + +# forward pass +y = l(g, x, ps, st) # size of the output first entry: 5 × num_nodes + +# convolution with edge weights and custom normalization function +w = [1.1, 0.1, 2.3, 0.5] +custom_norm_fn(d) = 1 ./ sqrt.(d + 1) # Custom normalization function +y = l(g, x, w, ps, st; norm_fn = custom_norm_fn) + +# Edge weights can also be embedded in the graph. +g = GNNGraph(s, t, w) +l = GCNConv(3 => 5, use_edge_weight=true) +y = l(g, x, ps, st) # same as l(g, x, w) +``` +""" @concrete struct GCNConv <: GNNLayer in_dims::Int out_dims::Int @@ -18,7 +91,7 @@ _getstate(s::StatefulLuxLayer{Static.False}) = s.st_any end function GCNConv(ch::Pair{Int, Int}, σ = identity; - init_weight = glorot_uniform, + init_weight = glorot_uniform, init_bias = zeros32, use_bias::Bool = true, add_self_loops::Bool = true, @@ -55,7 +128,7 @@ end function (l::GCNConv)(g, x, edge_weight, ps, st; norm_fn = d -> 1 ./ sqrt.(d), - conv_weight=nothing, ) + conv_weight=nothing) m = (; ps.weight, bias = _getbias(ps), l.add_self_loops, l.use_edge_weight, l.σ)